From 103bfc31c6ea6e055af638454394734cac2ab26c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 11 Mar 2024 22:27:10 +0200 Subject: [PATCH] Update dependencies and lots of code * Bump minimum Go version to 1.21 * Add contexts everywhere * Switch database code to new dbutil patterns * Finish switching away from maulogger --- .github/workflows/go.yml | 3 +- .pre-commit-config.yaml | 2 +- CHANGELOG.md | 6 + analytics.go | 8 +- backfillqueue.go | 83 +- bridgestate.go | 2 +- commands.go | 154 ++- custompuppet.go | 21 +- database/backfill.go | 340 ----- database/backfillqueue.go | 253 ++++ database/backfillstate.go | 94 ++ database/database.go | 58 +- database/disappearingmessage.go | 81 +- database/historysync.go | 245 ++-- database/mediabackfillrequest.go | 123 +- database/message.go | 142 +-- database/polloption.go | 143 +-- database/portal.go | 264 ++-- database/puppet.go | 138 +-- database/reaction.go | 74 +- database/upgrades/upgrades.go | 3 +- database/user.go | 157 +-- database/userportal.go | 76 +- disappear.go | 58 +- formatting.go | 17 +- go.mod | 45 +- go.sum | 100 +- historysync.go | 497 +++++--- main.go | 62 +- matrix.go | 78 +- messagetracking.go | 133 +- metrics.go | 34 +- portal.go | 1991 ++++++++++++++++++------------ provisioning.go | 128 +- puppet.go | 137 +- urlpreview.go | 65 +- user.go | 526 +++++--- 37 files changed, 3423 insertions(+), 2918 deletions(-) delete mode 100644 database/backfill.go create mode 100644 database/backfillqueue.go create mode 100644 database/backfillstate.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index b2e717d..3d392af 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -8,7 +8,8 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.20", "1.21"] + go-version: ["1.21", "1.22"] + name: Lint ${{ matrix.go-version == '1.22' && '(latest)' || '(old)' }} steps: - uses: actions/checkout@v4 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 64a6eae..6da4e37 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,6 +15,6 @@ repos: - id: go-vet-repo-mod - repo: https://github.com/beeper/pre-commit-go - rev: v0.2.2 + rev: v0.3.1 hooks: - id: zerolog-ban-msgf diff --git a/CHANGELOG.md b/CHANGELOG.md index 949d7df..87ae8c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# v0.10.6 (unreleased) + +* Bumped minimum Go version to 1.21. +* Added 8-letter code pairing support to provisioning API. +* Added more bugs to fix later. + # v0.10.5 (2023-12-16) * Added support for sending media to channels. diff --git a/analytics.go b/analytics.go index 60ebd0a..167b8b1 100644 --- a/analytics.go +++ b/analytics.go @@ -22,7 +22,7 @@ import ( "fmt" "net/http" - log "maunium.net/go/maulogger/v2" + "github.com/rs/zerolog" "maunium.net/go/mautrix/id" ) @@ -30,7 +30,7 @@ type AnalyticsClient struct { url string key string userID string - log log.Logger + log zerolog.Logger client http.Client } @@ -88,9 +88,9 @@ func (sc *AnalyticsClient) Track(userID id.UserID, event string, properties ...m props["bridge"] = "whatsapp" err := sc.trackSync(userID, event, props) if err != nil { - sc.log.Errorfln("Error tracking %s: %v", event, err) + sc.log.Err(err).Str("event", event).Msg("Error tracking event") } else { - sc.log.Debugln("Tracked", event) + sc.log.Debug().Str("event", event).Msg("Tracked event") } }() } diff --git a/backfillqueue.go b/backfillqueue.go index ab79644..2ca4e7e 100644 --- a/backfillqueue.go +++ b/backfillqueue.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2021 Tulir Asokan, Sumner Evans +// Copyright (C) 2024 Tulir Asokan, Sumner Evans // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,22 +17,21 @@ package main import ( + "context" "time" - log "maunium.net/go/maulogger/v2" + "github.com/rs/zerolog" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix-whatsapp/database" ) type BackfillQueue struct { - BackfillQuery *database.BackfillQuery + BackfillQuery *database.BackfillTaskQuery reCheckChannels []chan bool - log log.Logger } func (bq *BackfillQueue) ReCheck() { - bq.log.Infofln("Sending re-checks to %d channels", len(bq.reCheckChannels)) for _, channel := range bq.reCheckChannels { go func(c chan bool) { c <- true @@ -40,12 +39,19 @@ func (bq *BackfillQueue) ReCheck() { } } -func (bq *BackfillQueue) GetNextBackfill(userID id.UserID, backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType, reCheckChannel chan bool) *database.Backfill { +func (bq *BackfillQueue) GetNextBackfill(ctx context.Context, userID id.UserID, backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType, reCheckChannel chan bool) *database.BackfillTask { for { - if !bq.BackfillQuery.HasUnstartedOrInFlightOfType(userID, waitForBackfillTypes) { + if !bq.BackfillQuery.HasUnstartedOrInFlightOfType(ctx, userID, waitForBackfillTypes) { // check for immediate when dealing with deferred - if backfill := bq.BackfillQuery.GetNext(userID, backfillTypes); backfill != nil { - backfill.MarkDispatched() + if backfill, err := bq.BackfillQuery.GetNext(ctx, userID, backfillTypes); err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get next backfill task") + } else if backfill != nil { + err = backfill.MarkDispatched(ctx) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Int("queue_id", backfill.QueueID). + Msg("Failed to mark backfill task as dispatched") + } return backfill } } @@ -58,38 +64,73 @@ func (bq *BackfillQueue) GetNextBackfill(userID id.UserID, backfillTypes []datab } func (user *User) HandleBackfillRequestsLoop(backfillTypes []database.BackfillType, waitForBackfillTypes []database.BackfillType) { + log := user.zlog.With(). + Str("action", "backfill request loop"). + Any("types", backfillTypes). + Logger() + ctx := log.WithContext(context.TODO()) reCheckChannel := make(chan bool) user.BackfillQueue.reCheckChannels = append(user.BackfillQueue.reCheckChannels, reCheckChannel) for { - req := user.BackfillQueue.GetNextBackfill(user.MXID, backfillTypes, waitForBackfillTypes, reCheckChannel) - user.log.Infofln("Handling backfill request %s", req) + req := user.BackfillQueue.GetNextBackfill(ctx, user.MXID, backfillTypes, waitForBackfillTypes, reCheckChannel) + log.Info().Any("backfill_request", req).Msg("Handling backfill request") + log := log.With(). + Int("queue_id", req.QueueID). + Stringer("portal_jid", req.Portal.JID). + Logger() + ctx := log.WithContext(ctx) - conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, *req.Portal) - if conv == nil { - user.log.Debugfln("Could not find history sync conversation data for %s", req.Portal.String()) - req.MarkDone() + conv, err := user.bridge.DB.HistorySync.GetConversation(ctx, user.MXID, req.Portal) + if err != nil { + log.Err(err).Msg("Failed to get conversation data for backfill request") + continue + } else if conv == nil { + log.Debug().Msg("Couldn't find conversation data for backfill request") + err = req.MarkDone(ctx) + if err != nil { + log.Err(err).Msg("Failed to mark backfill request as done after data was not found") + } continue } portal := user.GetPortalByJID(conv.PortalKey.JID) // Update the client store with basic chat settings. if conv.MuteEndTime.After(time.Now()) { - user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime) + err = user.Client.Store.ChatSettings.PutMutedUntil(conv.PortalKey.JID, conv.MuteEndTime) + if err != nil { + log.Err(err).Msg("Failed to save muted until time from conversation data") + } } if conv.Archived { - user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true) + err = user.Client.Store.ChatSettings.PutArchived(conv.PortalKey.JID, true) + if err != nil { + log.Err(err).Msg("Failed to save archived state from conversation data") + } } if conv.Pinned > 0 { - user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true) + err = user.Client.Store.ChatSettings.PutPinned(conv.PortalKey.JID, true) + if err != nil { + log.Err(err).Msg("Failed to save pinned state from conversation data") + } } if conv.EphemeralExpiration != nil && portal.ExpirationTime != *conv.EphemeralExpiration { + log.Debug(). + Uint32("old_time", portal.ExpirationTime). + Uint32("new_time", *conv.EphemeralExpiration). + Msg("Updating portal ephemeral expiration time") portal.ExpirationTime = *conv.EphemeralExpiration - portal.Update(nil) + err = portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after updating expiration time") + } } - user.backfillInChunks(req, conv, portal) - req.MarkDone() + user.backfillInChunks(ctx, req, conv, portal) + err = req.MarkDone(ctx) + if err != nil { + log.Err(err).Msg("Failed to mark backfill request as done after backfilling") + } } } diff --git a/bridgestate.go b/bridgestate.go index afd67bf..20fd581 100644 --- a/bridgestate.go +++ b/bridgestate.go @@ -93,7 +93,7 @@ func (prov *ProvisioningAPI) BridgeStatePing(w http.ResponseWriter, r *http.Requ remote = remote.Fill(user) resp.RemoteStates[remote.RemoteID] = remote } - user.log.Debugfln("Responding bridge state in bridge status endpoint: %+v", resp) + user.zlog.Debug().Any("response_data", &resp).Msg("Responding bridge state in bridge status endpoint") jsonResponse(w, http.StatusOK, &resp) if len(resp.RemoteStates) > 0 { user.BridgeState.SetPrev(remote) diff --git a/commands.go b/commands.go index 339a195..79f2913 100644 --- a/commands.go +++ b/commands.go @@ -29,6 +29,7 @@ import ( "strings" "time" + "github.com/rs/zerolog" "github.com/skip2/go-qrcode" "github.com/tidwall/gjson" @@ -37,7 +38,6 @@ import ( "go.mau.fi/whatsmeow/types" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/bridge/commands" "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/event" @@ -117,7 +117,10 @@ func fnSetRelay(ce *WrappedCommandEvent) { ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge") } else { ce.Portal.RelayUserID = ce.User.MXID - ce.Portal.Update(nil) + err := ce.Portal.Update(ce.Ctx) + if err != nil { + ce.ZLog.Err(err).Msg("Failed to save portal after setting relay user") + } ce.Reply("Messages from non-logged-in users in this room will now be bridged through your WhatsApp account") } } @@ -139,7 +142,10 @@ func fnUnsetRelay(ce *WrappedCommandEvent) { ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge") } else { ce.Portal.RelayUserID = "" - ce.Portal.Update(nil) + err := ce.Portal.Update(ce.Ctx) + if err != nil { + ce.ZLog.Err(err).Msg("Failed to save portal after clearing relay user") + } ce.Reply("Messages from non-logged-in users will no longer be bridged in this room") } } @@ -246,7 +252,7 @@ func fnJoin(ce *WrappedCommandEvent) { ce.Reply("Failed to join group: %v", err) return } - ce.Log.Debugln("%s successfully joined group %s", ce.User.MXID, jid) + ce.ZLog.Debug().Stringer("group_jid", jid).Msg("User successfully joined WhatsApp group with link") ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid) } else if strings.HasPrefix(ce.Args[0], whatsmeow.NewsletterLinkPrefix) { info, err := ce.User.Client.GetNewsletterInfoWithInvite(ce.Args[0]) @@ -259,14 +265,14 @@ func fnJoin(ce *WrappedCommandEvent) { ce.Reply("Failed to follow channel: %v", err) return } - ce.Log.Debugln("%s successfully followed channel %s", ce.User.MXID, info.ID) + ce.ZLog.Debug().Stringer("channel_jid", info.ID).Msg("User successfully followed WhatsApp channel with link") ce.Reply("Successfully followed channel `%s`, the portal should be created momentarily", info.ID) } else { ce.Reply("That doesn't look like a WhatsApp invite link") } } -func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, error) { +func tryDecryptEvent(ce *WrappedCommandEvent, evt *event.Event) (json.RawMessage, error) { var data json.RawMessage if evt.Type != event.EventEncrypted { data = evt.Content.VeryRaw @@ -275,7 +281,7 @@ func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, e if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { return nil, err } - decrypted, err := crypto.Decrypt(evt) + decrypted, err := ce.Bridge.Crypto.Decrypt(ce.Ctx, evt) if err != nil { return nil, err } @@ -311,11 +317,11 @@ var cmdAccept = &commands.FullHandler{ func fnAccept(ce *WrappedCommandEvent) { if len(ce.ReplyTo) == 0 { ce.Reply("You must reply to a group invite message when using this command.") - } else if evt, err := ce.Portal.MainIntent().GetEvent(ce.RoomID, ce.ReplyTo); err != nil { - ce.Log.Errorln("Failed to get event %s to handle !wa accept command: %v", ce.ReplyTo, err) + } else if evt, err := ce.Portal.MainIntent().GetEvent(ce.Ctx, ce.RoomID, ce.ReplyTo); err != nil { + ce.ZLog.Err(err).Stringer("reply_to_mxid", ce.ReplyTo).Msg("Failed to get reply target event to handle !wa accept command") ce.Reply("Failed to get reply event") - } else if rawContent, err := tryDecryptEvent(ce.Bridge.Crypto, evt); err != nil { - ce.Log.Errorln("Failed to decrypt event %s to handle !wa accept command: %v", ce.ReplyTo, err) + } else if rawContent, err := tryDecryptEvent(ce, evt); err != nil { + ce.ZLog.Err(err).Stringer("reply_to_mxid", ce.ReplyTo).Msg("Failed to decrypt reply target event to handle !wa accept command") ce.Reply("Failed to decrypt reply event") } else if meta, err := parseInviteMeta(rawContent); err != nil || meta == nil { ce.Reply("That doesn't look like a group invite message.") @@ -344,16 +350,16 @@ func fnCreate(ce *WrappedCommandEvent) { return } - members, err := ce.Bot.JoinedMembers(ce.RoomID) + members, err := ce.Bot.JoinedMembers(ce.Ctx, ce.RoomID) if err != nil { ce.Reply("Failed to get room members: %v", err) return } var roomNameEvent event.RoomNameEventContent - err = ce.Bot.StateEvent(ce.RoomID, event.StateRoomName, "", &roomNameEvent) + err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateRoomName, "", &roomNameEvent) if err != nil && !errors.Is(err, mautrix.MNotFound) { - ce.Log.Errorln("Failed to get room name to create group:", err) + ce.ZLog.Err(err).Msg("Failed to get room name to create group") ce.Reply("Failed to get room name") return } else if len(roomNameEvent.Name) == 0 { @@ -362,15 +368,17 @@ func fnCreate(ce *WrappedCommandEvent) { } var encryptionEvent event.EncryptionEventContent - err = ce.Bot.StateEvent(ce.RoomID, event.StateEncryption, "", &encryptionEvent) + err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateEncryption, "", &encryptionEvent) if err != nil && !errors.Is(err, mautrix.MNotFound) { + ce.ZLog.Err(err).Msg("Failed to get room encryption status to create group") ce.Reply("Failed to get room encryption status") return } var createEvent event.CreateEventContent - err = ce.Bot.StateEvent(ce.RoomID, event.StateCreate, "", &createEvent) + err = ce.Bot.StateEvent(ce.Ctx, ce.RoomID, event.StateCreate, "", &createEvent) if err != nil && !errors.Is(err, mautrix.MNotFound) { + ce.ZLog.Err(err).Msg("Failed to get room create event to create group") ce.Reply("Failed to get room create event") return } @@ -395,7 +403,11 @@ func fnCreate(ce *WrappedCommandEvent) { // TODO check m.space.parent to create rooms directly in communities messageID := ce.User.Client.GenerateMessageID() - ce.Log.Infofln("Creating group for %s with name %s and participants %+v (create key: %s)", ce.RoomID, roomNameEvent.Name, participants, messageID) + ce.ZLog.Info(). + Str("room_name", roomNameEvent.Name). + Any("participants", participants). + Str("create_key", messageID). + Msg("Creating WhatsApp group for Matrix room") ce.User.createKeyDedup = messageID resp, err := ce.User.Client.CreateGroup(whatsmeow.ReqCreateGroup{ CreateKey: messageID, @@ -409,21 +421,25 @@ func fnCreate(ce *WrappedCommandEvent) { ce.Reply("Failed to create group: %v", err) return } + ce.ZLog.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("group_jid", resp.JID.String()) + }) portal := ce.User.GetPortalByJID(resp.JID) portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() if len(portal.MXID) != 0 { - portal.log.Warnln("Detected race condition in room creation") + ce.ZLog.Warn().Msg("Detected race condition in room creation") // TODO race condition, clean up the old room } portal.MXID = ce.RoomID + portal.updateLogger() portal.Name = roomNameEvent.Name portal.IsParent = resp.IsParent portal.Encrypted = encryptionEvent.Algorithm == id.AlgorithmMegolmV1 if !portal.Encrypted && ce.Bridge.Config.Bridge.Encryption.Default { - _, err = portal.MainIntent().SendStateEvent(portal.MXID, event.StateEncryption, "", portal.GetEncryptionEventContent()) + _, err = portal.MainIntent().SendStateEvent(ce.Ctx, portal.MXID, event.StateEncryption, "", portal.GetEncryptionEventContent()) if err != nil { - portal.log.Warnln("Failed to enable encryption in room:", err) + ce.ZLog.Err(err).Msg("Failed to enable encryption in room") if errors.Is(err, mautrix.MForbidden) { ce.Reply("I don't seem to have permission to enable encryption in this room.") } else { @@ -433,8 +449,11 @@ func fnCreate(ce *WrappedCommandEvent) { portal.Encrypted = true } - portal.Update(nil) - portal.UpdateBridgeInfo() + err = portal.Update(ce.Ctx) + if err != nil { + ce.ZLog.Err(err).Msg("Failed to save portal after creating group") + } + portal.UpdateBridgeInfo(ce.Ctx) ce.User.createKeyDedup = "" ce.Reply("Successfully created WhatsApp group %s", portal.Key.JID) @@ -512,7 +531,7 @@ func fnLogin(ce *WrappedCommandEvent) { } } if qrEventID != "" { - _, _ = ce.Bot.RedactEvent(ce.RoomID, qrEventID) + _, _ = ce.Bot.RedactEvent(ce.Ctx, ce.RoomID, qrEventID) } } @@ -529,9 +548,9 @@ func (user *User) sendQR(ce *WrappedCommandEvent, code string, prevEvent id.Even if len(prevEvent) != 0 { content.SetEdit(prevEvent) } - resp, err := ce.Bot.SendMessageEvent(ce.RoomID, event.EventMessage, &content) + resp, err := ce.Bot.SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, &content) if err != nil { - user.log.Errorln("Failed to send edited QR code to user:", err) + ce.ZLog.Err(err).Msg("Failed to send edited QR code to user") } else if len(prevEvent) == 0 { prevEvent = resp.EventID } @@ -541,16 +560,16 @@ func (user *User) sendQR(ce *WrappedCommandEvent, code string, prevEvent id.Even func (user *User) uploadQR(ce *WrappedCommandEvent, code string) (id.ContentURI, bool) { qrCode, err := qrcode.Encode(code, qrcode.Low, 256) if err != nil { - user.log.Errorln("Failed to encode QR code:", err) + ce.ZLog.Err(err).Msg("Failed to encode QR code") ce.Reply("Failed to encode QR code: %v", err) return id.ContentURI{}, false } bot := user.bridge.AS.BotClient() - resp, err := bot.UploadBytes(qrCode, "image/png") + resp, err := bot.UploadBytes(ce.Ctx, qrCode, "image/png") if err != nil { - user.log.Errorln("Failed to upload QR code:", err) + ce.ZLog.Err(err).Msg("Failed to upload QR code") ce.Reply("Failed to upload QR code: %v", err) return id.ContentURI{}, false } @@ -578,14 +597,14 @@ func fnLogout(ce *WrappedCommandEvent) { puppet.ClearCustomMXID() err := ce.User.Client.Logout() if err != nil { - ce.User.log.Warnln("Error while logging out:", err) + ce.ZLog.Err(err).Msg("Unknown error while logging out") ce.Reply("Unknown error while logging out: %v", err) return } ce.User.Session = nil ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut}) ce.User.DeleteConnection() - ce.User.DeleteSession() + ce.User.DeleteSession(ce.Ctx) ce.Reply("Logged out successfully.") } @@ -620,10 +639,13 @@ func fnTogglePresence(ce *WrappedCommandEvent) { if ce.User.IsLoggedIn() { err := ce.User.Client.SendPresence(newPresence) if err != nil { - ce.User.log.Warnln("Failed to set presence:", err) + ce.ZLog.Err(err).Msg("Failed to send presence to WhatsApp") } } - customPuppet.Update() + err := customPuppet.Update(ce.Ctx) + if err != nil { + ce.ZLog.Err(err).Msg("Failed to save puppet after toggling presence") + } } var cmdDeleteSession = &commands.FullHandler{ @@ -642,7 +664,7 @@ func fnDeleteSession(ce *WrappedCommandEvent) { } ce.User.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut}) ce.User.DeleteConnection() - ce.User.DeleteSession() + ce.User.DeleteSession(ce.Ctx) ce.Reply("Session information purged") } @@ -716,19 +738,19 @@ func fnPing(ce *WrappedCommandEvent) { } } -func canDeletePortal(portal *Portal, userID id.UserID) bool { +func canDeletePortal(ce *WrappedCommandEvent, portal *Portal) bool { if len(portal.MXID) == 0 { return false } - members, err := portal.MainIntent().JoinedMembers(portal.MXID) + members, err := portal.MainIntent().JoinedMembers(ce.Ctx, portal.MXID) if err != nil { - portal.log.Errorfln("Failed to get joined members to check if portal can be deleted by %s: %v", userID, err) + ce.ZLog.Err(err).Stringer("portal_mxid", portal.MXID).Msg("Failed to get joined members to check if portal can be deleted by user") return false } for otherUser := range members.Joined { _, isPuppet := portal.bridge.ParsePuppetMXID(otherUser) - if isPuppet || otherUser == portal.bridge.Bot.UserID || otherUser == userID { + if isPuppet || otherUser == portal.bridge.Bot.UserID || otherUser == ce.User.MXID { continue } user := portal.bridge.GetUserByMXID(otherUser) @@ -750,14 +772,14 @@ var cmdDeletePortal = &commands.FullHandler{ } func fnDeletePortal(ce *WrappedCommandEvent) { - if !ce.User.Admin && !canDeletePortal(ce.Portal, ce.User.MXID) { + if !ce.User.Admin && !canDeletePortal(ce, ce.Portal) { ce.Reply("Only bridge admins can delete portals with other Matrix users") return } - ce.Portal.log.Infoln(ce.User.MXID, "requested deletion of portal.") - ce.Portal.Delete() - ce.Portal.Cleanup(false) + ce.ZLog.Info().Msg("User requested deletion of current portal") + ce.Portal.Delete(ce.Ctx) + ce.Portal.Cleanup(ce.Ctx, false) } var cmdDeleteAllPortals = &commands.FullHandler{ @@ -778,7 +800,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) { } else { portalsToDelete = portals[:0] for _, portal := range portals { - if canDeletePortal(portal, ce.User.MXID) { + if canDeletePortal(ce, portal) { portalsToDelete = append(portalsToDelete, portal) } } @@ -790,7 +812,7 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) { leave := func(portal *Portal) { if len(portal.MXID) > 0 { - _, _ = portal.MainIntent().KickUser(portal.MXID, &mautrix.ReqKickUser{ + _, _ = portal.MainIntent().KickUser(ce.Ctx, portal.MXID, &mautrix.ReqKickUser{ Reason: "Deleting portal", UserID: ce.User.MXID, }) @@ -801,21 +823,21 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) { intent := customPuppet.CustomIntent() leave = func(portal *Portal) { if len(portal.MXID) > 0 { - _, _ = intent.LeaveRoom(portal.MXID) - _, _ = intent.ForgetRoom(portal.MXID) + _, _ = intent.LeaveRoom(ce.Ctx, portal.MXID) + _, _ = intent.ForgetRoom(ce.Ctx, portal.MXID) } } } ce.Reply("Found %d portals, deleting...", len(portalsToDelete)) for _, portal := range portalsToDelete { - portal.Delete() + portal.Delete(ce.Ctx) leave(portal) } ce.Reply("Finished deleting portal info. Now cleaning up rooms in background.") go func() { for _, portal := range portalsToDelete { - portal.Cleanup(false) + portal.Cleanup(ce.Ctx, false) } ce.Reply("Finished background cleanup of deleted portal rooms.") }() @@ -882,7 +904,7 @@ func fnList(ce *WrappedCommandEvent) { } var err error page := 1 - max := 100 + maxPerPage := 100 if len(ce.Args) > 1 { page, err = strconv.Atoi(ce.Args[1]) if err != nil || page <= 0 { @@ -891,11 +913,11 @@ func fnList(ce *WrappedCommandEvent) { } } if len(ce.Args) > 2 { - max, err = strconv.Atoi(ce.Args[2]) - if err != nil || max <= 0 { + maxPerPage, err = strconv.Atoi(ce.Args[2]) + if err != nil || maxPerPage <= 0 { ce.Reply("\"%s\" isn't a valid number of items per page", ce.Args[2]) return - } else if max > 400 { + } else if maxPerPage > 400 { ce.Reply("Warning: a high number of items per page may fail to send a reply") } } @@ -924,8 +946,8 @@ func fnList(ce *WrappedCommandEvent) { ce.Reply("No %s found", strings.ToLower(typeName)) return } - pages := int(math.Ceil(float64(len(result)) / float64(max))) - if (page-1)*max >= len(result) { + pages := int(math.Ceil(float64(len(result)) / float64(maxPerPage))) + if (page-1)*maxPerPage >= len(result) { if pages == 1 { ce.Reply("There is only 1 page of %s", strings.ToLower(typeName)) } else { @@ -933,11 +955,11 @@ func fnList(ce *WrappedCommandEvent) { } return } - lastIndex := page * max + lastIndex := page * maxPerPage if lastIndex > len(result) { lastIndex = len(result) } - result = result[(page-1)*max : lastIndex] + result = result[(page-1)*maxPerPage : lastIndex] ce.Reply("### %s (page %d of %d)\n\n%s", typeName, page, pages, strings.Join(result, "\n")) } @@ -1036,13 +1058,13 @@ func fnOpen(ce *WrappedCommandEvent) { } jid = newsletterMetadata.ID } - ce.Log.Debugln("Importing", jid, "for", ce.User.MXID) + ce.ZLog.Debug().Stringer("chat_jid", jid).Msg("Importing chat for user") portal := ce.User.GetPortalByJID(jid) if len(portal.MXID) > 0 { - portal.UpdateMatrixRoom(ce.User, groupInfo, newsletterMetadata) + portal.UpdateMatrixRoom(ce.Ctx, ce.User, groupInfo, newsletterMetadata) ce.Reply("Portal room synced.") } else { - err = portal.CreateMatrixRoom(ce.User, groupInfo, newsletterMetadata, true, true) + err = portal.CreateMatrixRoom(ce.Ctx, ce.User, groupInfo, newsletterMetadata, true, true) if err != nil { ce.Reply("Failed to create room: %v", err) } else { @@ -1085,7 +1107,7 @@ func fnPM(ce *WrappedCommandEvent) { return } - portal, puppet, justCreated, err := user.StartPM(targetUser.JID, "manual PM command") + portal, puppet, justCreated, err := user.StartPM(ce.Ctx, targetUser.JID, "manual PM command") if err != nil { ce.Reply("Failed to create portal room: %v", err) } else if !justCreated { @@ -1154,11 +1176,16 @@ func fnSync(ce *WrappedCommandEvent) { ce.Reply("Personal filtering spaces are not enabled on this instance of the bridge") return } - keys := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.User.JID) + keys, err := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.Ctx, ce.User.JID) + if err != nil { + ce.ZLog.Err(err).Msg("Failed to get list of private chats not in space") + ce.Reply("Failed to get list of private chats not in space") + return + } count := 0 for _, key := range keys { portal := ce.Bridge.GetPortalByJID(key) - portal.addToPersonalSpace(ce.User) + portal.addToPersonalSpace(ce.Ctx, ce.User) count++ } plural := "s" @@ -1208,6 +1235,9 @@ func fnDisappearingTimer(ce *WrappedCommandEvent) { ce.Portal.ExpirationTime = prevExpirationTime return } - ce.Portal.Update(nil) + err = ce.Portal.Update(ce.Ctx) + if err != nil { + ce.ZLog.Err(err).Msg("Failed to save portal after setting disappearing timer") + } ce.React("✅") } diff --git a/custompuppet.go b/custompuppet.go index 47ae104..8fccc57 100644 --- a/custompuppet.go +++ b/custompuppet.go @@ -17,6 +17,9 @@ package main import ( + "context" + "fmt" + "maunium.net/go/mautrix/id" ) @@ -24,8 +27,11 @@ func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error puppet.CustomMXID = mxid puppet.AccessToken = accessToken puppet.EnablePresence = puppet.bridge.Config.Bridge.DefaultBridgePresence - puppet.Update() - err := puppet.StartCustomMXID(false) + err := puppet.Update(context.TODO()) + if err != nil { + return fmt.Errorf("failed to save access token: %w", err) + } + err = puppet.StartCustomMXID(false) if err != nil { return err } @@ -45,12 +51,15 @@ func (puppet *Puppet) ClearCustomMXID() { puppet.customIntent = nil puppet.customUser = nil if save { - puppet.Update() + err := puppet.Update(context.TODO()) + if err != nil { + puppet.zlog.Err(err).Msg("Failed to clear custom MXID") + } } } func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error { - newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(puppet.CustomMXID, puppet.AccessToken, reloginOnFail) + newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(context.TODO(), puppet.CustomMXID, puppet.AccessToken, reloginOnFail) if err != nil { puppet.ClearCustomMXID() return err @@ -60,11 +69,11 @@ func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error { puppet.bridge.puppetsLock.Unlock() if puppet.AccessToken != newAccessToken { puppet.AccessToken = newAccessToken - puppet.Update() + err = puppet.Update(context.TODO()) } puppet.customIntent = newIntent puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID) - return nil + return err } func (user *User) tryAutomaticDoublePuppeting() { diff --git a/database/backfill.go b/database/backfill.go deleted file mode 100644 index 273370b..0000000 --- a/database/backfill.go +++ /dev/null @@ -1,340 +0,0 @@ -// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2021 Tulir Asokan, Sumner Evans -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package database - -import ( - "database/sql" - "errors" - "fmt" - "strconv" - "strings" - "sync" - "time" - - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - - "maunium.net/go/mautrix/id" -) - -type BackfillType int - -const ( - BackfillImmediate BackfillType = 0 - BackfillForward BackfillType = 100 - BackfillDeferred BackfillType = 200 -) - -func (bt BackfillType) String() string { - switch bt { - case BackfillImmediate: - return "IMMEDIATE" - case BackfillForward: - return "FORWARD" - case BackfillDeferred: - return "DEFERRED" - } - return "UNKNOWN" -} - -type BackfillQuery struct { - db *Database - log log.Logger - - backfillQueryLock sync.Mutex -} - -func (bq *BackfillQuery) New() *Backfill { - return &Backfill{ - db: bq.db, - log: bq.log, - Portal: &PortalKey{}, - } -} - -func (bq *BackfillQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal *PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *Backfill { - return &Backfill{ - db: bq.db, - log: bq.log, - UserID: userID, - BackfillType: backfillType, - Priority: priority, - Portal: portal, - TimeStart: timeStart, - MaxBatchEvents: maxBatchEvents, - MaxTotalEvents: maxTotalEvents, - BatchDelay: batchDelay, - } -} - -const ( - getNextBackfillQuery = ` - SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay - FROM backfill_queue - WHERE user_mxid=$1 - AND type IN (%s) - AND ( - dispatch_time IS NULL - OR ( - dispatch_time < $2 - AND completed_at IS NULL - ) - ) - ORDER BY type, priority, queue_id - LIMIT 1 - ` - getUnstartedOrInFlightQuery = ` - SELECT 1 - FROM backfill_queue - WHERE user_mxid=$1 - AND type IN (%s) - AND (dispatch_time IS NULL OR completed_at IS NULL) - LIMIT 1 - ` -) - -// GetNext returns the next backfill to perform -func (bq *BackfillQuery) GetNext(userID id.UserID, backfillTypes []BackfillType) (backfill *Backfill) { - bq.backfillQueryLock.Lock() - defer bq.backfillQueryLock.Unlock() - - var types []string - for _, backfillType := range backfillTypes { - types = append(types, strconv.Itoa(int(backfillType))) - } - rows, err := bq.db.Query(fmt.Sprintf(getNextBackfillQuery, strings.Join(types, ",")), userID, time.Now().Add(-15*time.Minute)) - if err != nil || rows == nil { - bq.log.Errorfln("Failed to query next backfill queue job: %v", err) - return - } - defer rows.Close() - if rows.Next() { - backfill = bq.New().Scan(rows) - } - return -} - -func (bq *BackfillQuery) HasUnstartedOrInFlightOfType(userID id.UserID, backfillTypes []BackfillType) bool { - if len(backfillTypes) == 0 { - return false - } - - bq.backfillQueryLock.Lock() - defer bq.backfillQueryLock.Unlock() - - types := []string{} - for _, backfillType := range backfillTypes { - types = append(types, strconv.Itoa(int(backfillType))) - } - rows, err := bq.db.Query(fmt.Sprintf(getUnstartedOrInFlightQuery, strings.Join(types, ",")), userID) - if err != nil || rows == nil { - if err != nil && !errors.Is(err, sql.ErrNoRows) { - bq.log.Warnfln("Failed to query backfill queue jobs: %v", err) - } - // No rows means that there are no unstarted or in flight backfill - // requests. - return false - } - defer rows.Close() - return rows.Next() -} - -func (bq *BackfillQuery) DeleteAll(userID id.UserID) { - bq.backfillQueryLock.Lock() - defer bq.backfillQueryLock.Unlock() - _, err := bq.db.Exec("DELETE FROM backfill_queue WHERE user_mxid=$1", userID) - if err != nil { - bq.log.Warnfln("Failed to delete backfill queue items for %s: %v", userID, err) - } -} - -func (bq *BackfillQuery) DeleteAllForPortal(userID id.UserID, portalKey PortalKey) { - bq.backfillQueryLock.Lock() - defer bq.backfillQueryLock.Unlock() - _, err := bq.db.Exec(` - DELETE FROM backfill_queue - WHERE user_mxid=$1 - AND portal_jid=$2 - AND portal_receiver=$3 - `, userID, portalKey.JID, portalKey.Receiver) - if err != nil { - bq.log.Warnfln("Failed to delete backfill queue items for %s/%s: %v", userID, portalKey.JID, err) - } -} - -type Backfill struct { - db *Database - log log.Logger - - // Fields - QueueID int - UserID id.UserID - BackfillType BackfillType - Priority int - Portal *PortalKey - TimeStart *time.Time - MaxBatchEvents int - MaxTotalEvents int - BatchDelay int - DispatchTime *time.Time - CompletedAt *time.Time -} - -func (b *Backfill) String() string { - return fmt.Sprintf("Backfill{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}", - b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime, - ) -} - -func (b *Backfill) Scan(row dbutil.Scannable) *Backfill { - var maxTotalEvents, batchDelay sql.NullInt32 - err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &maxTotalEvents, &batchDelay) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - b.log.Errorln("Database scan failed:", err) - } - return nil - } - b.MaxTotalEvents = int(maxTotalEvents.Int32) - b.BatchDelay = int(batchDelay.Int32) - return b -} - -func (b *Backfill) Insert() { - b.db.Backfill.backfillQueryLock.Lock() - defer b.db.Backfill.backfillQueryLock.Unlock() - - rows, err := b.db.Query(` - INSERT INTO backfill_queue - (user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - RETURNING queue_id - `, b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt) - defer rows.Close() - if err != nil || !rows.Next() { - b.log.Warnfln("Failed to insert %v/%s with priority %d: %v", b.BackfillType, b.Portal.JID, b.Priority, err) - return - } - err = rows.Scan(&b.QueueID) - if err != nil { - b.log.Warnfln("Failed to insert %s/%s with priority %s: %v", b.BackfillType, b.Portal.JID, b.Priority, err) - } -} - -func (b *Backfill) MarkDispatched() { - b.db.Backfill.backfillQueryLock.Lock() - defer b.db.Backfill.backfillQueryLock.Unlock() - - if b.QueueID == 0 { - b.log.Errorfln("Cannot mark backfill as dispatched without queue_id. Maybe it wasn't actually inserted in the database?") - return - } - _, err := b.db.Exec("UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2", time.Now(), b.QueueID) - if err != nil { - b.log.Warnfln("Failed to mark %s/%s as dispatched: %v", b.BackfillType, b.Priority, err) - } -} - -func (b *Backfill) MarkDone() { - b.db.Backfill.backfillQueryLock.Lock() - defer b.db.Backfill.backfillQueryLock.Unlock() - - if b.QueueID == 0 { - b.log.Errorfln("Cannot mark backfill done without queue_id. Maybe it wasn't actually inserted in the database?") - return - } - _, err := b.db.Exec("UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2", time.Now(), b.QueueID) - if err != nil { - b.log.Warnfln("Failed to mark %s/%s as complete: %v", b.BackfillType, b.Priority, err) - } -} - -func (bq *BackfillQuery) NewBackfillState(userID id.UserID, portalKey *PortalKey) *BackfillState { - return &BackfillState{ - db: bq.db, - log: bq.log, - UserID: userID, - Portal: portalKey, - } -} - -const ( - getBackfillState = ` - SELECT user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts - FROM backfill_state - WHERE user_mxid=$1 - AND portal_jid=$2 - AND portal_receiver=$3 - ` -) - -type BackfillState struct { - db *Database - log log.Logger - - // Fields - UserID id.UserID - Portal *PortalKey - ProcessingBatch bool - BackfillComplete bool - FirstExpectedTimestamp uint64 -} - -func (b *BackfillState) Scan(row dbutil.Scannable) *BackfillState { - err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - b.log.Errorln("Database scan failed:", err) - } - return nil - } - return b -} - -func (b *BackfillState) Upsert() { - _, err := b.db.Exec(` - INSERT INTO backfill_state - (user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (user_mxid, portal_jid, portal_receiver) - DO UPDATE SET - processing_batch=EXCLUDED.processing_batch, - backfill_complete=EXCLUDED.backfill_complete, - first_expected_ts=EXCLUDED.first_expected_ts`, - b.UserID, b.Portal.JID, b.Portal.Receiver, b.ProcessingBatch, b.BackfillComplete, b.FirstExpectedTimestamp) - if err != nil { - b.log.Warnfln("Failed to insert backfill state for %s: %v", b.Portal.JID, err) - } -} - -func (b *BackfillState) SetProcessingBatch(processing bool) { - b.ProcessingBatch = processing - b.Upsert() -} - -func (bq *BackfillQuery) GetBackfillState(userID id.UserID, portalKey *PortalKey) (backfillState *BackfillState) { - rows, err := bq.db.Query(getBackfillState, userID, portalKey.JID, portalKey.Receiver) - if err != nil || rows == nil { - bq.log.Error(err) - return - } - defer rows.Close() - if rows.Next() { - backfillState = bq.NewBackfillState(userID, portalKey).Scan(rows) - } - return -} diff --git a/database/backfillqueue.go b/database/backfillqueue.go new file mode 100644 index 0000000..bf3bd99 --- /dev/null +++ b/database/backfillqueue.go @@ -0,0 +1,253 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2024 Tulir Asokan, Sumner Evans +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package database + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +type BackfillType int + +const ( + BackfillImmediate BackfillType = 0 + BackfillForward BackfillType = 100 + BackfillDeferred BackfillType = 200 +) + +func (bt BackfillType) String() string { + switch bt { + case BackfillImmediate: + return "IMMEDIATE" + case BackfillForward: + return "FORWARD" + case BackfillDeferred: + return "DEFERRED" + } + return "UNKNOWN" +} + +type BackfillTaskQuery struct { + *dbutil.QueryHelper[*BackfillTask] + + //backfillQueryLock sync.Mutex +} + +func newBackfillTask(qh *dbutil.QueryHelper[*BackfillTask]) *BackfillTask { + return &BackfillTask{qh: qh} +} + +func (bq *BackfillTaskQuery) NewWithValues(userID id.UserID, backfillType BackfillType, priority int, portal PortalKey, timeStart *time.Time, maxBatchEvents, maxTotalEvents, batchDelay int) *BackfillTask { + return &BackfillTask{ + qh: bq.QueryHelper, + + UserID: userID, + BackfillType: backfillType, + Priority: priority, + Portal: portal, + TimeStart: timeStart, + MaxBatchEvents: maxBatchEvents, + MaxTotalEvents: maxTotalEvents, + BatchDelay: batchDelay, + } +} + +const ( + getNextBackfillTaskQueryTemplate = ` + SELECT queue_id, user_mxid, type, priority, portal_jid, portal_receiver, time_start, max_batch_events, max_total_events, batch_delay + FROM backfill_queue + WHERE user_mxid=$1 + AND type IN (%s) + AND ( + dispatch_time IS NULL + OR ( + dispatch_time < $2 + AND completed_at IS NULL + ) + ) + ORDER BY type, priority, queue_id + LIMIT 1 + ` + getUnstartedOrInFlightBackfillTaskQueryTemplate = ` + SELECT 1 + FROM backfill_queue + WHERE user_mxid=$1 + AND type IN (%s) + AND (dispatch_time IS NULL OR completed_at IS NULL) + LIMIT 1 + ` + deleteBackfillQueueForUserQuery = "DELETE FROM backfill_queue WHERE user_mxid=$1" + deleteBackfillQueueForPortalQuery = ` + DELETE FROM backfill_queue + WHERE user_mxid=$1 + AND portal_jid=$2 + AND portal_receiver=$3 + ` + insertBackfillTaskQuery = ` + INSERT INTO backfill_queue ( + user_mxid, type, priority, portal_jid, portal_receiver, time_start, + max_batch_events, max_total_events, batch_delay, dispatch_time, completed_at + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + RETURNING queue_id + ` + markBackfillTaskDispatchedQuery = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2" + markBackfillTaskDoneQuery = "UPDATE backfill_queue SET completed_at=$1 WHERE queue_id=$2" +) + +func typesToString(backfillTypes []BackfillType) string { + types := make([]string, len(backfillTypes)) + for i, backfillType := range backfillTypes { + types[i] = strconv.Itoa(int(backfillType)) + } + return strings.Join(types, ",") +} + +// GetNext returns the next backfill to perform +func (bq *BackfillTaskQuery) GetNext(ctx context.Context, userID id.UserID, backfillTypes []BackfillType) (*BackfillTask, error) { + if len(backfillTypes) == 0 { + return nil, nil + } + //bq.backfillQueryLock.Lock() + //defer bq.backfillQueryLock.Unlock() + + query := fmt.Sprintf(getNextBackfillTaskQueryTemplate, typesToString(backfillTypes)) + return bq.QueryOne(ctx, query, userID, time.Now().Add(-15*time.Minute)) +} + +func (bq *BackfillTaskQuery) HasUnstartedOrInFlightOfType(ctx context.Context, userID id.UserID, backfillTypes []BackfillType) (has bool) { + if len(backfillTypes) == 0 { + return false + } + + //bq.backfillQueryLock.Lock() + //defer bq.backfillQueryLock.Unlock() + + query := fmt.Sprintf(getUnstartedOrInFlightBackfillTaskQueryTemplate, typesToString(backfillTypes)) + err := bq.GetDB().QueryRow(ctx, query, userID).Scan(&has) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + zerolog.Ctx(ctx).Err(err).Msg("Failed to check if backfill queue has jobs") + } + return +} + +func (bq *BackfillTaskQuery) DeleteAll(ctx context.Context, userID id.UserID) error { + //bq.backfillQueryLock.Lock() + //defer bq.backfillQueryLock.Unlock() + return bq.Exec(ctx, deleteBackfillQueueForUserQuery, userID) +} + +func (bq *BackfillTaskQuery) DeleteAllForPortal(ctx context.Context, userID id.UserID, portalKey PortalKey) error { + //bq.backfillQueryLock.Lock() + //defer bq.backfillQueryLock.Unlock() + return bq.Exec(ctx, deleteBackfillQueueForPortalQuery, userID, portalKey.JID, portalKey.Receiver) +} + +type BackfillTask struct { + qh *dbutil.QueryHelper[*BackfillTask] + + QueueID int + UserID id.UserID + BackfillType BackfillType + Priority int + Portal PortalKey + TimeStart *time.Time + MaxBatchEvents int + MaxTotalEvents int + BatchDelay int + DispatchTime *time.Time + CompletedAt *time.Time +} + +func (b *BackfillTask) MarshalZerologObject(evt *zerolog.Event) { + evt.Int("queue_id", b.QueueID). + Stringer("user_id", b.UserID). + Stringer("backfill_type", b.BackfillType). + Int("priority", b.Priority). + Stringer("portal_jid", b.Portal.JID). + Any("time_start", b.TimeStart). + Int("max_batch_events", b.MaxBatchEvents). + Int("max_total_events", b.MaxTotalEvents). + Int("batch_delay", b.BatchDelay). + Any("dispatch_time", b.DispatchTime). + Any("completed_at", b.CompletedAt) +} + +func (b *BackfillTask) String() string { + return fmt.Sprintf( + "BackfillTask{QueueID: %d, UserID: %s, BackfillType: %s, Priority: %d, Portal: %s, TimeStart: %s, MaxBatchEvents: %d, MaxTotalEvents: %d, BatchDelay: %d, DispatchTime: %s, CompletedAt: %s}", + b.QueueID, b.UserID, b.BackfillType, b.Priority, b.Portal, b.TimeStart, b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.CompletedAt, b.DispatchTime, + ) +} + +func (b *BackfillTask) Scan(row dbutil.Scannable) (*BackfillTask, error) { + var maxTotalEvents, batchDelay sql.NullInt32 + err := row.Scan( + &b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, + &b.MaxBatchEvents, &maxTotalEvents, &batchDelay, + ) + if err != nil { + return nil, err + } + b.MaxTotalEvents = int(maxTotalEvents.Int32) + b.BatchDelay = int(batchDelay.Int32) + return b, nil +} + +func (b *BackfillTask) sqlVariables() []any { + return []any{ + b.UserID, b.BackfillType, b.Priority, b.Portal.JID, b.Portal.Receiver, b.TimeStart, + b.MaxBatchEvents, b.MaxTotalEvents, b.BatchDelay, b.DispatchTime, b.CompletedAt, + } +} + +func (b *BackfillTask) Insert(ctx context.Context) error { + //b.db.Backfill.backfillQueryLock.Lock() + //defer b.db.Backfill.backfillQueryLock.Unlock() + + return b.qh.GetDB().QueryRow(ctx, insertBackfillTaskQuery, b.sqlVariables()...).Scan(&b.QueueID) +} + +func (b *BackfillTask) MarkDispatched(ctx context.Context) error { + //b.db.Backfill.backfillQueryLock.Lock() + //defer b.db.Backfill.backfillQueryLock.Unlock() + + if b.QueueID == 0 { + return fmt.Errorf("can't mark backfill as dispatched without queue_id") + } + return b.qh.Exec(ctx, markBackfillTaskDispatchedQuery, time.Now(), b.QueueID) +} + +func (b *BackfillTask) MarkDone(ctx context.Context) error { + //b.db.Backfill.backfillQueryLock.Lock() + //defer b.db.Backfill.backfillQueryLock.Unlock() + + if b.QueueID == 0 { + return fmt.Errorf("can't mark backfill as dispatched without queue_id") + } + return b.qh.Exec(ctx, markBackfillTaskDoneQuery, time.Now(), b.QueueID) +} diff --git a/database/backfillstate.go b/database/backfillstate.go new file mode 100644 index 0000000..6ad49b1 --- /dev/null +++ b/database/backfillstate.go @@ -0,0 +1,94 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2024 Tulir Asokan, Sumner Evans +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/id" +) + +type BackfillStateQuery struct { + *dbutil.QueryHelper[*BackfillState] +} + +func newBackfillState(qh *dbutil.QueryHelper[*BackfillState]) *BackfillState { + return &BackfillState{qh: qh} +} + +func (bq *BackfillStateQuery) NewBackfillState(userID id.UserID, portalKey PortalKey) *BackfillState { + return &BackfillState{ + qh: bq.QueryHelper, + + UserID: userID, + Portal: portalKey, + } +} + +const ( + getBackfillStateQuery = ` + SELECT user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts + FROM backfill_state + WHERE user_mxid=$1 + AND portal_jid=$2 + AND portal_receiver=$3 + ` + upsertBackfillStateQuery = ` + INSERT INTO backfill_state + (user_mxid, portal_jid, portal_receiver, processing_batch, backfill_complete, first_expected_ts) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (user_mxid, portal_jid, portal_receiver) + DO UPDATE SET + processing_batch=EXCLUDED.processing_batch, + backfill_complete=EXCLUDED.backfill_complete, + first_expected_ts=EXCLUDED.first_expected_ts + ` +) + +func (bq *BackfillStateQuery) GetBackfillState(ctx context.Context, userID id.UserID, portalKey PortalKey) (*BackfillState, error) { + return bq.QueryOne(ctx, getBackfillStateQuery, userID, portalKey.JID, portalKey.Receiver) +} + +type BackfillState struct { + qh *dbutil.QueryHelper[*BackfillState] + + UserID id.UserID + Portal PortalKey + ProcessingBatch bool + BackfillComplete bool + FirstExpectedTimestamp uint64 +} + +func (b *BackfillState) Scan(row dbutil.Scannable) (*BackfillState, error) { + return dbutil.ValueOrErr(b, row.Scan( + &b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp, + )) +} + +func (b *BackfillState) sqlVariables() []any { + return []any{b.UserID, b.Portal.JID, b.Portal.Receiver, b.ProcessingBatch, b.BackfillComplete, b.FirstExpectedTimestamp} +} + +func (b *BackfillState) Upsert(ctx context.Context) error { + return b.qh.Exec(ctx, upsertBackfillStateQuery, b.sqlVariables()...) +} + +func (b *BackfillState) SetProcessingBatch(ctx context.Context, processing bool) error { + b.ProcessingBatch = processing + return b.Upsert(ctx) +} diff --git a/database/database.go b/database/database.go index 4bc749a..5822c56 100644 --- a/database/database.go +++ b/database/database.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -26,7 +26,6 @@ import ( "go.mau.fi/util/dbutil" "go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/store/sqlstore" - "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix-whatsapp/database/upgrades" ) @@ -45,51 +44,28 @@ type Database struct { Reaction *ReactionQuery DisappearingMessage *DisappearingMessageQuery - Backfill *BackfillQuery + BackfillQueue *BackfillTaskQuery + BackfillState *BackfillStateQuery HistorySync *HistorySyncQuery MediaBackfillRequest *MediaBackfillRequestQuery } -func New(baseDB *dbutil.Database, log maulogger.Logger) *Database { - db := &Database{Database: baseDB} +func New(db *dbutil.Database) *Database { db.UpgradeTable = upgrades.Table - db.User = &UserQuery{ - db: db, - log: log.Sub("User"), + return &Database{ + Database: db, + User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)}, + Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)}, + Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)}, + Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)}, + Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)}, + + DisappearingMessage: &DisappearingMessageQuery{dbutil.MakeQueryHelper(db, newDisappearingMessage)}, + BackfillQueue: &BackfillTaskQuery{dbutil.MakeQueryHelper(db, newBackfillTask)}, + BackfillState: &BackfillStateQuery{dbutil.MakeQueryHelper(db, newBackfillState)}, + HistorySync: &HistorySyncQuery{dbutil.MakeQueryHelper(db, newHistorySyncConversation)}, + MediaBackfillRequest: &MediaBackfillRequestQuery{dbutil.MakeQueryHelper(db, newMediaBackfillRequest)}, } - db.Portal = &PortalQuery{ - db: db, - log: log.Sub("Portal"), - } - db.Puppet = &PuppetQuery{ - db: db, - log: log.Sub("Puppet"), - } - db.Message = &MessageQuery{ - db: db, - log: log.Sub("Message"), - } - db.Reaction = &ReactionQuery{ - db: db, - log: log.Sub("Reaction"), - } - db.DisappearingMessage = &DisappearingMessageQuery{ - db: db, - log: log.Sub("DisappearingMessage"), - } - db.Backfill = &BackfillQuery{ - db: db, - log: log.Sub("Backfill"), - } - db.HistorySync = &HistorySyncQuery{ - db: db, - log: log.Sub("HistorySync"), - } - db.MediaBackfillRequest = &MediaBackfillRequestQuery{ - db: db, - log: log.Sub("MediaBackfillRequest"), - } - return db } func isRetryableError(err error) bool { diff --git a/database/disappearingmessage.go b/database/disappearingmessage.go index bae11e0..31eef09 100644 --- a/database/disappearingmessage.go +++ b/database/disappearingmessage.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,32 +17,29 @@ package database import ( + "context" "database/sql" - "errors" "time" - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" + + "go.mau.fi/util/dbutil" ) type DisappearingMessageQuery struct { - db *Database - log log.Logger + *dbutil.QueryHelper[*DisappearingMessage] } -func (dmq *DisappearingMessageQuery) New() *DisappearingMessage { +func newDisappearingMessage(qh *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage { return &DisappearingMessage{ - db: dmq.db, - log: dmq.log, + qh: qh, } } func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireIn time.Duration, expireAt time.Time) *DisappearingMessage { dm := &DisappearingMessage{ - db: dmq.db, - log: dmq.log, + qh: dmq.QueryHelper, + RoomID: roomID, EventID: eventID, ExpireIn: expireIn, @@ -55,22 +52,17 @@ const ( getAllScheduledDisappearingMessagesQuery = ` SELECT room_id, event_id, expire_in, expire_at FROM disappearing_message WHERE expire_at IS NOT NULL AND expire_at <= $1 ` + insertDisappearingMessageQuery = `INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)` + updateDisappearingMessageExpiryQuery = "UPDATE disappearing_message SET expire_at=$1 WHERE room_id=$2 AND event_id=$3" + deleteDisappearingMessageQuery = "DELETE FROM disappearing_message WHERE room_id=$1 AND event_id=$2" ) -func (dmq *DisappearingMessageQuery) GetUpcomingScheduled(duration time.Duration) (messages []*DisappearingMessage) { - rows, err := dmq.db.Query(getAllScheduledDisappearingMessagesQuery, time.Now().Add(duration).UnixMilli()) - if err != nil || rows == nil { - return nil - } - for rows.Next() { - messages = append(messages, dmq.New().Scan(rows)) - } - return +func (dmq *DisappearingMessageQuery) GetUpcomingScheduled(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, getAllScheduledDisappearingMessagesQuery, time.Now().Add(duration).UnixMilli()) } type DisappearingMessage struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*DisappearingMessage] RoomID id.RoomID EventID id.EventID @@ -78,50 +70,33 @@ type DisappearingMessage struct { ExpireAt time.Time } -func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage { +func (msg *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) { var expireIn int64 var expireAt sql.NullInt64 err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt) if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - msg.log.Errorln("Database scan failed:", err) - } - return nil + return nil, err } msg.ExpireIn = time.Duration(expireIn) * time.Millisecond if expireAt.Valid { msg.ExpireAt = time.UnixMilli(expireAt.Int64) } - return msg + return msg, nil } -func (msg *DisappearingMessage) Insert(txn dbutil.Execable) { - if txn == nil { - txn = msg.db - } - var expireAt sql.NullInt64 - if !msg.ExpireAt.IsZero() { - expireAt.Valid = true - expireAt.Int64 = msg.ExpireAt.UnixMilli() - } - _, err := txn.Exec(`INSERT INTO disappearing_message (room_id, event_id, expire_in, expire_at) VALUES ($1, $2, $3, $4)`, - msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), expireAt) - if err != nil { - msg.log.Warnfln("Failed to insert %s/%s: %v", msg.RoomID, msg.EventID, err) - } +func (msg *DisappearingMessage) sqlVariables() []any { + return []any{msg.RoomID, msg.EventID, msg.ExpireIn.Milliseconds(), dbutil.UnixMilliPtr(msg.ExpireAt)} } -func (msg *DisappearingMessage) StartTimer() { +func (msg *DisappearingMessage) Insert(ctx context.Context) error { + return msg.qh.Exec(ctx, insertDisappearingMessageQuery, msg.sqlVariables()...) +} + +func (msg *DisappearingMessage) StartTimer(ctx context.Context) error { msg.ExpireAt = time.Now().Add(msg.ExpireIn * time.Second) - _, err := msg.db.Exec("UPDATE disappearing_message SET expire_at=$1 WHERE room_id=$2 AND event_id=$3", msg.ExpireAt.Unix(), msg.RoomID, msg.EventID) - if err != nil { - msg.log.Warnfln("Failed to update %s/%s: %v", msg.RoomID, msg.EventID, err) - } + return msg.qh.Exec(ctx, updateDisappearingMessageExpiryQuery, msg.ExpireAt.Unix(), msg.RoomID, msg.EventID) } -func (msg *DisappearingMessage) Delete() { - _, err := msg.db.Exec("DELETE FROM disappearing_message WHERE room_id=$1 AND event_id=$2", msg.RoomID, msg.EventID) - if err != nil { - msg.log.Warnfln("Failed to delete %s/%s: %v", msg.RoomID, msg.EventID, err) - } +func (msg *DisappearingMessage) Delete(ctx context.Context) error { + return msg.qh.Exec(ctx, deleteDisappearingMessageQuery, msg.RoomID, msg.EventID) } diff --git a/database/historysync.go b/database/historysync.go index d896bbc..a5b57e4 100644 --- a/database/historysync.go +++ b/database/historysync.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan, Sumner Evans +// Copyright (C) 2024 Tulir Asokan, Sumner Evans // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,8 +17,7 @@ package database import ( - "database/sql" - "errors" + "context" "fmt" "time" @@ -26,23 +25,19 @@ import ( "go.mau.fi/util/dbutil" waProto "go.mau.fi/whatsmeow/binary/proto" "google.golang.org/protobuf/proto" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" ) type HistorySyncQuery struct { - db *Database - log log.Logger + *dbutil.QueryHelper[*HistorySyncConversation] } type HistorySyncConversation struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*HistorySyncConversation] UserID id.UserID ConversationID string - PortalKey *PortalKey + PortalKey PortalKey LastMessageTimestamp time.Time MuteEndTime time.Time Archived bool @@ -54,18 +49,16 @@ type HistorySyncConversation struct { UnreadCount uint32 } -func (hsq *HistorySyncQuery) NewConversation() *HistorySyncConversation { +func newHistorySyncConversation(qh *dbutil.QueryHelper[*HistorySyncConversation]) *HistorySyncConversation { return &HistorySyncConversation{ - db: hsq.db, - log: hsq.log, - PortalKey: &PortalKey{}, + qh: qh, } } func (hsq *HistorySyncQuery) NewConversationWithValues( userID id.UserID, conversationID string, - portalKey *PortalKey, + portalKey PortalKey, lastMessageTimestamp, muteEndTime uint64, archived bool, @@ -74,10 +67,10 @@ func (hsq *HistorySyncQuery) NewConversationWithValues( endOfHistoryTransferType waProto.Conversation_EndOfHistoryTransferType, ephemeralExpiration *uint32, markedAsUnread bool, - unreadCount uint32) *HistorySyncConversation { + unreadCount uint32, +) *HistorySyncConversation { return &HistorySyncConversation{ - db: hsq.db, - log: hsq.log, + qh: hsq.QueryHelper, UserID: userID, ConversationID: conversationID, PortalKey: portalKey, @@ -94,6 +87,17 @@ func (hsq *HistorySyncQuery) NewConversationWithValues( } const ( + upsertHistorySyncConversationQuery = ` + INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ON CONFLICT (user_mxid, conversation_id) + DO UPDATE SET + last_message_timestamp=CASE + WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp + ELSE history_sync_conversation.last_message_timestamp + END, + end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type + ` getNMostRecentConversations = ` SELECT user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count FROM history_sync_conversation @@ -108,24 +112,19 @@ const ( AND portal_jid=$2 AND portal_receiver=$3 ` + deleteAllConversationsQuery = "DELETE FROM history_sync_conversation WHERE user_mxid=$1" + deleteHistorySyncConversationQuery = ` + DELETE FROM history_sync_conversation + WHERE user_mxid=$1 AND conversation_id=$2 + ` ) -func (hsc *HistorySyncConversation) Upsert() { - _, err := hsc.db.Exec(` - INSERT INTO history_sync_conversation (user_mxid, conversation_id, portal_jid, portal_receiver, last_message_timestamp, archived, pinned, mute_end_time, disappearing_mode, end_of_history_transfer_type, ephemeral_expiration, marked_as_unread, unread_count) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) - ON CONFLICT (user_mxid, conversation_id) - DO UPDATE SET - last_message_timestamp=CASE - WHEN EXCLUDED.last_message_timestamp > history_sync_conversation.last_message_timestamp THEN EXCLUDED.last_message_timestamp - ELSE history_sync_conversation.last_message_timestamp - END, - end_of_history_transfer_type=EXCLUDED.end_of_history_transfer_type - `, +func (hsc *HistorySyncConversation) sqlVariables() []any { + return []any{ hsc.UserID, hsc.ConversationID, - hsc.PortalKey.JID.String(), - hsc.PortalKey.Receiver.String(), + hsc.PortalKey.JID, + hsc.PortalKey.Receiver, hsc.LastMessageTimestamp, hsc.Archived, hsc.Pinned, @@ -134,14 +133,16 @@ func (hsc *HistorySyncConversation) Upsert() { hsc.EndOfHistoryTransferType, hsc.EphemeralExpiration, hsc.MarkedAsUnread, - hsc.UnreadCount) - if err != nil { - hsc.log.Warnfln("Failed to insert history sync conversation %s/%s: %v", hsc.UserID, hsc.ConversationID, err) + hsc.UnreadCount, } } -func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation { - err := row.Scan( +func (hsc *HistorySyncConversation) Upsert(ctx context.Context) error { + return hsc.qh.Exec(ctx, upsertHistorySyncConversationQuery, hsc.sqlVariables()...) +} + +func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) (*HistorySyncConversation, error) { + return dbutil.ValueOrErr(hsc, row.Scan( &hsc.UserID, &hsc.ConversationID, &hsc.PortalKey.JID, @@ -154,69 +155,59 @@ func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConve &hsc.EndOfHistoryTransferType, &hsc.EphemeralExpiration, &hsc.MarkedAsUnread, - &hsc.UnreadCount) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - hsc.log.Errorln("Database scan failed:", err) - } - return nil - } - return hsc + &hsc.UnreadCount, + )) } -func (hsq *HistorySyncQuery) GetRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) { +func (hsq *HistorySyncQuery) GetRecentConversations(ctx context.Context, userID id.UserID, n int) ([]*HistorySyncConversation, error) { nPtr := &n // Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit. - if n < 0 && hsq.db.Dialect == dbutil.Postgres { + if n < 0 && hsq.GetDB().Dialect == dbutil.Postgres { nPtr = nil } - rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr) - defer rows.Close() - if err != nil || rows == nil { - return nil - } - for rows.Next() { - conversations = append(conversations, hsq.NewConversation().Scan(rows)) - } - return + return hsq.QueryMany(ctx, getNMostRecentConversations, userID, nPtr) } -func (hsq *HistorySyncQuery) GetConversation(userID id.UserID, portalKey PortalKey) (conversation *HistorySyncConversation) { - rows, err := hsq.db.Query(getConversationByPortal, userID, portalKey.JID, portalKey.Receiver) - defer rows.Close() - if err != nil || rows == nil { - return nil - } - if rows.Next() { - conversation = hsq.NewConversation().Scan(rows) - } - return +func (hsq *HistorySyncQuery) GetConversation(ctx context.Context, userID id.UserID, portalKey PortalKey) (*HistorySyncConversation, error) { + return hsq.QueryOne(ctx, getConversationByPortal, userID, portalKey.JID, portalKey.Receiver) } -func (hsq *HistorySyncQuery) DeleteAllConversations(userID id.UserID) { - _, err := hsq.db.Exec("DELETE FROM history_sync_conversation WHERE user_mxid=$1", userID) - if err != nil { - hsq.log.Warnfln("Failed to delete historical chat info for %s/%s: %v", userID, err) - } +func (hsq *HistorySyncQuery) DeleteAllConversations(ctx context.Context, userID id.UserID) error { + return hsq.Exec(ctx, deleteAllConversationsQuery, userID) } const ( - getMessagesBetween = ` + insertHistorySyncMessageQuery = ` + INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING + ` + getHistorySyncMessagesBetweenQueryTemplate = ` SELECT data FROM history_sync_message WHERE user_mxid=$1 AND conversation_id=$2 %s ORDER BY timestamp DESC %s ` - deleteMessagesBetweenExclusive = ` + deleteHistorySyncMessagesBetweenExclusiveQuery = ` DELETE FROM history_sync_message WHERE user_mxid=$1 AND conversation_id=$2 AND timestamp<$3 AND timestamp>$4 ` + deleteAllHistorySyncMessagesQuery = "DELETE FROM history_sync_message WHERE user_mxid=$1" + deleteHistorySyncMessagesForPortalQuery = ` + DELETE FROM history_sync_message + WHERE user_mxid=$1 AND conversation_id=$2 + ` + conversationHasHistorySyncMessagesQuery = ` + SELECT EXISTS( + SELECT 1 FROM history_sync_message + WHERE user_mxid=$1 AND conversation_id=$2 + ) + ` ) type HistorySyncMessage struct { - db *Database - log log.Logger + hsq *HistorySyncQuery UserID id.UserID ConversationID string @@ -231,8 +222,8 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation return nil, err } return &HistorySyncMessage{ - db: hsq.db, - log: hsq.log, + hsq: hsq, + UserID: userID, ConversationID: conversationID, MessageID: messageID, @@ -241,18 +232,27 @@ func (hsq *HistorySyncQuery) NewMessageWithValues(userID id.UserID, conversation }, nil } -func (hsm *HistorySyncMessage) Insert() error { - _, err := hsm.db.Exec(` - INSERT INTO history_sync_message (user_mxid, conversation_id, message_id, timestamp, data, inserted_time) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (user_mxid, conversation_id, message_id) DO NOTHING - `, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now()) - return err +func (hsm *HistorySyncMessage) Insert(ctx context.Context) error { + return hsm.hsq.Exec(ctx, insertHistorySyncMessageQuery, hsm.UserID, hsm.ConversationID, hsm.MessageID, hsm.Timestamp, hsm.Data, time.Now()) } -func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) (messages []*waProto.WebMessageInfo) { +func scanWebMessageInfo(rows dbutil.Scannable) (*waProto.WebMessageInfo, error) { + var msgData []byte + err := rows.Scan(&msgData) + if err != nil { + return nil, err + } + var historySyncMsg waProto.HistorySyncMsg + err = proto.Unmarshal(msgData, &historySyncMsg) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal message: %w", err) + } + return historySyncMsg.GetMessage(), nil +} + +func (hsq *HistorySyncQuery) GetMessagesBetween(ctx context.Context, userID id.UserID, conversationID string, startTime, endTime *time.Time, limit int) ([]*waProto.WebMessageInfo, error) { whereClauses := "" - args := []interface{}{userID, conversationID} + args := []any{userID, conversationID} argNum := 3 if startTime != nil { whereClauses += fmt.Sprintf(" AND timestamp >= $%d", argNum) @@ -268,80 +268,35 @@ func (hsq *HistorySyncQuery) GetMessagesBetween(userID id.UserID, conversationID if limit > 0 { limitClause = fmt.Sprintf("LIMIT %d", limit) } + query := fmt.Sprintf(getHistorySyncMessagesBetweenQueryTemplate, whereClauses, limitClause) - rows, err := hsq.db.Query(fmt.Sprintf(getMessagesBetween, whereClauses, limitClause), args...) - defer rows.Close() - if err != nil || rows == nil { - if err != nil && !errors.Is(err, sql.ErrNoRows) { - hsq.log.Warnfln("Failed to query messages between range: %v", err) - } - return nil - } - - var msgData []byte - for rows.Next() { - err = rows.Scan(&msgData) - if err != nil { - hsq.log.Errorfln("Database scan failed: %v", err) - continue - } - var historySyncMsg waProto.HistorySyncMsg - err = proto.Unmarshal(msgData, &historySyncMsg) - if err != nil { - hsq.log.Errorfln("Failed to unmarshal history sync message: %v", err) - continue - } - messages = append(messages, historySyncMsg.Message) - } - return + return dbutil.ConvertRowFn[*waProto.WebMessageInfo](scanWebMessageInfo). + NewRowIter(hsq.GetDB().Query(ctx, query, args...)). + AsList() } -func (hsq *HistorySyncQuery) DeleteMessages(userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error { +func (hsq *HistorySyncQuery) DeleteMessages(ctx context.Context, userID id.UserID, conversationID string, messages []*waProto.WebMessageInfo) error { newest := messages[0] beforeTS := time.Unix(int64(newest.GetMessageTimestamp())+1, 0) oldest := messages[len(messages)-1] afterTS := time.Unix(int64(oldest.GetMessageTimestamp())-1, 0) - _, err := hsq.db.Exec(deleteMessagesBetweenExclusive, userID, conversationID, beforeTS, afterTS) - return err + return hsq.Exec(ctx, deleteHistorySyncMessagesBetweenExclusiveQuery, userID, conversationID, beforeTS, afterTS) } -func (hsq *HistorySyncQuery) DeleteAllMessages(userID id.UserID) { - _, err := hsq.db.Exec("DELETE FROM history_sync_message WHERE user_mxid=$1", userID) - if err != nil { - hsq.log.Warnfln("Failed to delete historical messages for %s: %v", userID, err) - } +func (hsq *HistorySyncQuery) DeleteAllMessages(ctx context.Context, userID id.UserID) error { + return hsq.Exec(ctx, deleteAllHistorySyncMessagesQuery, userID) } -func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(userID id.UserID, portalKey PortalKey) { - _, err := hsq.db.Exec(` - DELETE FROM history_sync_message - WHERE user_mxid=$1 AND conversation_id=$2 - `, userID, portalKey.JID) - if err != nil { - hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, portalKey.JID, err) - } +func (hsq *HistorySyncQuery) DeleteAllMessagesForPortal(ctx context.Context, userID id.UserID, portalKey PortalKey) error { + return hsq.Exec(ctx, deleteHistorySyncMessagesForPortalQuery, userID, portalKey.JID) } -func (hsq *HistorySyncQuery) ConversationHasMessages(userID id.UserID, portalKey PortalKey) (exists bool) { - err := hsq.db.QueryRow(` - SELECT EXISTS( - SELECT 1 FROM history_sync_message - WHERE user_mxid=$1 AND conversation_id=$2 - ) - `, userID, portalKey.JID).Scan(&exists) - if err != nil { - hsq.log.Warnfln("Failed to check if any messages are stored for %s/%s: %v", userID, portalKey.JID, err) - } +func (hsq *HistorySyncQuery) ConversationHasMessages(ctx context.Context, userID id.UserID, portalKey PortalKey) (exists bool, err error) { + err = hsq.GetDB().QueryRow(ctx, conversationHasHistorySyncMessagesQuery, userID, portalKey.JID).Scan(&exists) return } -func (hsq *HistorySyncQuery) DeleteConversation(userID id.UserID, jid string) { +func (hsq *HistorySyncQuery) DeleteConversation(ctx context.Context, userID id.UserID, jid string) error { // This will also clear history_sync_message as there's a foreign key constraint - _, err := hsq.db.Exec(` - DELETE FROM history_sync_conversation - WHERE user_mxid=$1 AND conversation_id=$2 - `, userID, jid) - if err != nil { - hsq.log.Warnfln("Failed to delete historical messages for %s/%s: %v", userID, jid, err) - } + return hsq.Exec(ctx, deleteHistorySyncConversationQuery, userID, jid) } diff --git a/database/mediabackfillrequest.go b/database/mediabackfillrequest.go index e71b986..aa74b9e 100644 --- a/database/mediabackfillrequest.go +++ b/database/mediabackfillrequest.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan, Sumner Evans +// Copyright (C) 2024 Tulir Asokan, Sumner Evans // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,14 +17,12 @@ package database import ( - "database/sql" - "errors" + "context" _ "github.com/mattn/go-sqlite3" - "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" + + "go.mau.fi/util/dbutil" ) type MediaBackfillRequestStatus int @@ -36,34 +34,46 @@ const ( ) type MediaBackfillRequestQuery struct { - db *Database - log log.Logger + *dbutil.QueryHelper[*MediaBackfillRequest] } -type MediaBackfillRequest struct { - db *Database - log log.Logger +const ( + getAllMediaBackfillRequestsForUserQuery = ` + SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error + FROM media_backfill_requests + WHERE user_mxid=$1 + AND status=0 + ` + deleteAllMediaBackfillRequestsForUserQuery = "DELETE FROM media_backfill_requests WHERE user_mxid=$1" + upsertMediaBackfillRequestQuery = ` + INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id) + DO UPDATE SET + media_key=excluded.media_key, + status=excluded.status, + error=excluded.error + ` +) - UserID id.UserID - PortalKey *PortalKey - EventID id.EventID - MediaKey []byte - Status MediaBackfillRequestStatus - Error string +func (mbrq *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(ctx context.Context, userID id.UserID) ([]*MediaBackfillRequest, error) { + return mbrq.QueryMany(ctx, getAllMediaBackfillRequestsForUserQuery, userID) } -func (mbrq *MediaBackfillRequestQuery) newMediaBackfillRequest() *MediaBackfillRequest { +func (mbrq *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(ctx context.Context, userID id.UserID) error { + return mbrq.Exec(ctx, deleteAllMediaBackfillRequestsForUserQuery, userID) +} + +func newMediaBackfillRequest(qh *dbutil.QueryHelper[*MediaBackfillRequest]) *MediaBackfillRequest { return &MediaBackfillRequest{ - db: mbrq.db, - log: mbrq.log, - PortalKey: &PortalKey{}, + qh: qh, } } -func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey *PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest { +func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID id.UserID, portalKey PortalKey, eventID id.EventID, mediaKey []byte) *MediaBackfillRequest { return &MediaBackfillRequest{ - db: mbrq.db, - log: mbrq.log, + qh: mbrq.QueryHelper, + UserID: userID, PortalKey: portalKey, EventID: eventID, @@ -72,62 +82,25 @@ func (mbrq *MediaBackfillRequestQuery) NewMediaBackfillRequestWithValues(userID } } -const ( - getMediaBackfillRequestsForUser = ` - SELECT user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error - FROM media_backfill_requests - WHERE user_mxid=$1 - AND status=0 - ` -) +type MediaBackfillRequest struct { + qh *dbutil.QueryHelper[*MediaBackfillRequest] -func (mbr *MediaBackfillRequest) Upsert() { - _, err := mbr.db.Exec(` - INSERT INTO media_backfill_requests (user_mxid, portal_jid, portal_receiver, event_id, media_key, status, error) - VALUES ($1, $2, $3, $4, $5, $6, $7) - ON CONFLICT (user_mxid, portal_jid, portal_receiver, event_id) - DO UPDATE SET - media_key=EXCLUDED.media_key, - status=EXCLUDED.status, - error=EXCLUDED.error`, - mbr.UserID, - mbr.PortalKey.JID.String(), - mbr.PortalKey.Receiver.String(), - mbr.EventID, - mbr.MediaKey, - mbr.Status, - mbr.Error) - if err != nil { - mbr.log.Warnfln("Failed to insert media backfill request %s/%s/%s: %v", mbr.UserID, mbr.PortalKey.String(), mbr.EventID, err) - } + UserID id.UserID + PortalKey PortalKey + EventID id.EventID + MediaKey []byte + Status MediaBackfillRequestStatus + Error string } -func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) *MediaBackfillRequest { - err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - mbr.log.Errorln("Database scan failed:", err) - } - return nil - } - return mbr +func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) (*MediaBackfillRequest, error) { + return dbutil.ValueOrErr(mbr, row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error)) } -func (mbrq *MediaBackfillRequestQuery) GetMediaBackfillRequestsForUser(userID id.UserID) (requests []*MediaBackfillRequest) { - rows, err := mbrq.db.Query(getMediaBackfillRequestsForUser, userID) - defer rows.Close() - if err != nil || rows == nil { - return nil - } - for rows.Next() { - requests = append(requests, mbrq.newMediaBackfillRequest().Scan(rows)) - } - return +func (mbr *MediaBackfillRequest) sqlVariables() []any { + return []any{mbr.UserID, mbr.PortalKey.JID, mbr.PortalKey.Receiver, mbr.EventID, mbr.MediaKey, mbr.Status, mbr.Error} } -func (mbrq *MediaBackfillRequestQuery) DeleteAllMediaBackfillRequests(userID id.UserID) { - _, err := mbrq.db.Exec("DELETE FROM media_backfill_requests WHERE user_mxid=$1", userID) - if err != nil { - mbrq.log.Warnfln("Failed to delete media backfill requests for %s: %v", userID, err) - } +func (mbr *MediaBackfillRequest) Upsert(ctx context.Context) error { + return mbr.qh.Exec(ctx, upsertMediaBackfillRequestQuery, mbr.sqlVariables()...) } diff --git a/database/message.go b/database/message.go index dbd7e9c..63e9727 100644 --- a/database/message.go +++ b/database/message.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2021 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,29 +17,22 @@ package database import ( - "database/sql" - "errors" + "context" "fmt" "strings" "time" "go.mau.fi/util/dbutil" "go.mau.fi/whatsmeow/types" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" ) type MessageQuery struct { - db *Database - log log.Logger + *dbutil.QueryHelper[*Message] } -func (mq *MessageQuery) New() *Message { - return &Message{ - db: mq.db, - log: mq.log, - } +func newMessage(qh *dbutil.QueryHelper[*Message]) *Message { + return &Message{qh: qh} } const ( @@ -67,60 +60,47 @@ const ( SELECT chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND timestamp>$3 AND timestamp<=$4 AND sent=true AND error='' ORDER BY timestamp ASC ` + insertMessageQuery = ` + INSERT INTO message + (chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + ` + markMessageSentQuery = "UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4" + updateMessageMXIDQuery = "UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6" + deleteMessageQuery = "DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3" ) -func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) { - rows, err := mq.db.Query(getAllMessagesQuery, chat.JID, chat.Receiver) - if err != nil || rows == nil { - return nil - } - for rows.Next() { - messages = append(messages, mq.New().Scan(rows)) - } - return +func (mq *MessageQuery) GetAll(ctx context.Context, chat PortalKey) ([]*Message, error) { + return mq.QueryMany(ctx, getAllMessagesQuery, chat.JID, chat.Receiver) } -func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.MessageID) *Message { - return mq.maybeScan(mq.db.QueryRow(getMessageByJIDQuery, chat.JID, chat.Receiver, jid)) +func (mq *MessageQuery) GetByJID(ctx context.Context, chat PortalKey, jid types.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getMessageByJIDQuery, chat.JID, chat.Receiver, jid) } -func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message { - return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid)) +func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) { + return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid) } -func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message { - return mq.GetLastInChatBefore(chat, time.Now().Add(60*time.Second)) +func (mq *MessageQuery) GetLastInChat(ctx context.Context, chat PortalKey) (*Message, error) { + return mq.GetLastInChatBefore(ctx, chat, time.Now().Add(60*time.Second)) } -func (mq *MessageQuery) GetLastInChatBefore(chat PortalKey, maxTimestamp time.Time) *Message { - msg := mq.maybeScan(mq.db.QueryRow(getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix())) - if msg == nil || msg.Timestamp.IsZero() { +func (mq *MessageQuery) GetLastInChatBefore(ctx context.Context, chat PortalKey, maxTimestamp time.Time) (*Message, error) { + msg, err := mq.QueryOne(ctx, getLastMessageInChatQuery, chat.JID, chat.Receiver, maxTimestamp.Unix()) + if msg != nil && msg.Timestamp.IsZero() { // Old db, we don't know what the last message is. - return nil + msg = nil } - return msg + return msg, err } -func (mq *MessageQuery) GetFirstInChat(chat PortalKey) *Message { - return mq.maybeScan(mq.db.QueryRow(getFirstMessageInChatQuery, chat.JID, chat.Receiver)) +func (mq *MessageQuery) GetFirstInChat(ctx context.Context, chat PortalKey) (*Message, error) { + return mq.QueryOne(ctx, getFirstMessageInChatQuery, chat.JID, chat.Receiver) } -func (mq *MessageQuery) GetMessagesBetween(chat PortalKey, minTimestamp, maxTimestamp time.Time) (messages []*Message) { - rows, err := mq.db.Query(getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix()) - if err != nil || rows == nil { - return nil - } - for rows.Next() { - messages = append(messages, mq.New().Scan(rows)) - } - return -} - -func (mq *MessageQuery) maybeScan(row *sql.Row) *Message { - if row == nil { - return nil - } - return mq.New().Scan(row) +func (mq *MessageQuery) GetMessagesBetween(ctx context.Context, chat PortalKey, minTimestamp, maxTimestamp time.Time) ([]*Message, error) { + return mq.QueryMany(ctx, getMessagesBetweenQuery, chat.JID, chat.Receiver, minTimestamp.Unix(), maxTimestamp.Unix()) } type MessageErrorType string @@ -144,8 +124,7 @@ const ( ) type Message struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*Message] Chat PortalKey JID types.MessageID @@ -172,76 +151,49 @@ func (msg *Message) IsFakeJID() bool { const fakeGalleryMXIDFormat = "com.beeper.gallery::%d:%s" -func (msg *Message) Scan(row dbutil.Scannable) *Message { +func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) { var ts int64 err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.SenderMXID, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID) if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - msg.log.Errorln("Database scan failed:", err) - } - return nil + return nil, err } if strings.HasPrefix(msg.MXID.String(), "com.beeper.gallery::") { _, err = fmt.Sscanf(msg.MXID.String(), fakeGalleryMXIDFormat, &msg.GalleryPart, &msg.MXID) if err != nil { - msg.log.Errorln("Parsing gallery MXID failed:", err) + return nil, fmt.Errorf("failed to parse gallery MXID: %w", err) } } if ts != 0 { msg.Timestamp = time.Unix(ts, 0) } - return msg + return msg, nil } -func (msg *Message) Insert(txn dbutil.Execable) { - if txn == nil { - txn = msg.db - } - var sender interface{} = msg.Sender - // Slightly hacky hack to allow inserting empty senders (used for post-backfill dummy events) - if msg.Sender.IsEmpty() { - sender = "" - } +func (msg *Message) sqlVariables() []any { mxid := msg.MXID.String() if msg.GalleryPart != 0 { mxid = fmt.Sprintf(fakeGalleryMXIDFormat, msg.GalleryPart, mxid) } - _, err := txn.Exec(` - INSERT INTO message - (chat_jid, chat_receiver, jid, mxid, sender, sender_mxid, timestamp, sent, type, error, broadcast_list_jid) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - `, msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID) - if err != nil { - msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err) - } + return []any{msg.Chat.JID, msg.Chat.Receiver, msg.JID, mxid, msg.Sender, msg.SenderMXID, msg.Timestamp.Unix(), msg.Sent, msg.Type, msg.Error, msg.BroadcastListJID} } -func (msg *Message) MarkSent(ts time.Time) { +func (msg *Message) Insert(ctx context.Context) error { + return msg.qh.Exec(ctx, insertMessageQuery, msg.sqlVariables()...) +} + +func (msg *Message) MarkSent(ctx context.Context, ts time.Time) error { msg.Sent = true msg.Timestamp = ts - _, err := msg.db.Exec("UPDATE message SET sent=true, timestamp=$1 WHERE chat_jid=$2 AND chat_receiver=$3 AND jid=$4", ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID) - if err != nil { - msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err) - } + return msg.qh.Exec(ctx, markMessageSentQuery, ts.Unix(), msg.Chat.JID, msg.Chat.Receiver, msg.JID) } -func (msg *Message) UpdateMXID(txn dbutil.Execable, mxid id.EventID, newType MessageType, newError MessageErrorType) { - if txn == nil { - txn = msg.db - } +func (msg *Message) UpdateMXID(ctx context.Context, mxid id.EventID, newType MessageType, newError MessageErrorType) error { msg.MXID = mxid msg.Type = newType msg.Error = newError - _, err := txn.Exec("UPDATE message SET mxid=$1, type=$2, error=$3 WHERE chat_jid=$4 AND chat_receiver=$5 AND jid=$6", - mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID) - if err != nil { - msg.log.Warnfln("Failed to update %s@%s: %v", msg.Chat, msg.JID, err) - } + return msg.qh.Exec(ctx, updateMessageMXIDQuery, mxid, newType, newError, msg.Chat.JID, msg.Chat.Receiver, msg.JID) } -func (msg *Message) Delete() { - _, err := msg.db.Exec("DELETE FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", msg.Chat.JID, msg.Chat.Receiver, msg.JID) - if err != nil { - msg.log.Warnfln("Failed to delete %s@%s: %v", msg.Chat, msg.JID, err) - } +func (msg *Message) Delete(ctx context.Context) error { + return msg.qh.Exec(ctx, deleteMessageQuery, msg.Chat.JID, msg.Chat.Receiver, msg.JID) } diff --git a/database/polloption.go b/database/polloption.go index 4af576f..03d8566 100644 --- a/database/polloption.go +++ b/database/polloption.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,28 +17,56 @@ package database import ( + "context" "fmt" "strings" "github.com/lib/pq" + "go.mau.fi/util/dbutil" ) -func scanPollOptionMapping(rows dbutil.Rows) (id string, hashArr [32]byte, err error) { - var hash []byte - err = rows.Scan(&id, &hash) - if err != nil { - // return below - } else if len(hash) != 32 { - err = fmt.Errorf("unexpected hash length %d", len(hash)) - } else { - hashArr = *(*[32]byte)(hash) +const ( + bulkPutPollOptionsQuery = "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)" + bulkPutPollOptionsQueryTemplate = "($1, $%d, $%d)" + bulkPutPollOptionsQueryPlaceholder = "($1, $2, $3)" + getPollOptionIDsByHashesQuery = "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)" + getPollOptionHashesByIDsQuery = "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)" + getPollOptionQuerySQLiteArrayTemplate = " IN (%s)" + getPollOptionQueryArrayPlaceholder = " = ANY($2)" +) + +func init() { + if strings.ReplaceAll(bulkPutPollOptionsQuery, bulkPutPollOptionsQueryPlaceholder, "meow") == bulkPutPollOptionsQuery { + panic("Bulk insert query placeholder not found") + } + if strings.ReplaceAll(getPollOptionIDsByHashesQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery { + panic("Array select query placeholder not found") + } + if strings.ReplaceAll(getPollOptionHashesByIDsQuery, getPollOptionQueryArrayPlaceholder, "meow") == getPollOptionIDsByHashesQuery { + panic("Array select query placeholder not found") } - return } -func (msg *Message) PutPollOptions(opts map[[32]byte]string) { - query := "INSERT INTO poll_option_id (msg_mxid, opt_id, opt_hash) VALUES ($1, $2, $3)" +type pollOption struct { + id string + hash [32]byte +} + +func scanPollOption(rows dbutil.Scannable) (*pollOption, error) { + var hash []byte + var id string + err := rows.Scan(&id, &hash) + if err != nil { + return nil, err + } else if len(hash) != 32 { + return nil, fmt.Errorf("unexpected hash length %d", len(hash)) + } else { + return &pollOption{id: id, hash: [32]byte(hash)}, nil + } +} + +func (msg *Message) PutPollOptions(ctx context.Context, opts map[[32]byte]string) error { args := make([]any, len(opts)*2+1) placeholders := make([]string, len(opts)) args[0] = msg.MXID @@ -47,72 +75,47 @@ func (msg *Message) PutPollOptions(opts map[[32]byte]string) { args[i*2+1] = id hashCopy := hash args[i*2+2] = hashCopy[:] - placeholders[i] = fmt.Sprintf("($1, $%d, $%d)", i*2+2, i*2+3) + placeholders[i] = fmt.Sprintf(bulkPutPollOptionsQueryTemplate, i*2+2, i*2+3) i++ } - query = strings.ReplaceAll(query, "($1, $2, $3)", strings.Join(placeholders, ",")) - _, err := msg.db.Exec(query, args...) - if err != nil { - msg.log.Errorfln("Failed to save poll options for %s: %v", msg.MXID, err) - } + query := strings.ReplaceAll(bulkPutPollOptionsQuery, bulkPutPollOptionsQueryPlaceholder, strings.Join(placeholders, ",")) + return msg.qh.Exec(ctx, query, args...) } -func (msg *Message) GetPollOptionIDs(hashes [][]byte) map[[32]byte]string { - query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_hash = ANY($2)" +func getPollOptions[LookupKey any, Key comparable, Value any]( + ctx context.Context, + msg *Message, + query string, + things []LookupKey, + getKeyValue func(option *pollOption) (Key, Value), +) (map[Key]Value, error) { var args []any - if msg.db.Dialect == dbutil.Postgres { - args = []any{msg.MXID, pq.Array(hashes)} + if msg.qh.GetDB().Dialect == dbutil.Postgres { + args = []any{msg.MXID, pq.Array(things)} } else { - query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(hashes)), ","))) - args = make([]any, len(hashes)+1) + query = strings.ReplaceAll(query, getPollOptionQueryArrayPlaceholder, fmt.Sprintf(getPollOptionQuerySQLiteArrayTemplate, strings.TrimSuffix(strings.Repeat("?,", len(things)), ","))) + args = make([]any, len(things)+1) args[0] = msg.MXID - for i, hash := range hashes { - args[i+1] = hash + for i, thing := range things { + args[i+1] = thing } } - ids := make(map[[32]byte]string, len(hashes)) - rows, err := msg.db.Query(query, args...) - if err != nil { - msg.log.Errorfln("Failed to query poll option IDs for %s: %v", msg.MXID, err) - } else { - for rows.Next() { - id, hash, err := scanPollOptionMapping(rows) - if err != nil { - msg.log.Errorfln("Failed to scan poll option ID for %s: %v", msg.MXID, err) - break - } - ids[hash] = id - } - } - return ids + return dbutil.RowIterAsMap( + dbutil.ConvertRowFn[*pollOption](scanPollOption).NewRowIter(msg.qh.GetDB().Query(ctx, query, args...)), + getKeyValue, + ) } -func (msg *Message) GetPollOptionHashes(ids []string) map[string][32]byte { - query := "SELECT opt_id, opt_hash FROM poll_option_id WHERE msg_mxid=$1 AND opt_id = ANY($2)" - var args []any - if msg.db.Dialect == dbutil.Postgres { - args = []any{msg.MXID, pq.Array(ids)} - } else { - query = strings.ReplaceAll(query, " = ANY($2)", fmt.Sprintf(" IN (%s)", strings.TrimSuffix(strings.Repeat("?,", len(ids)), ","))) - args = make([]any, len(ids)+1) - args[0] = msg.MXID - for i, id := range ids { - args[i+1] = id - } - } - hashes := make(map[string][32]byte, len(ids)) - rows, err := msg.db.Query(query, args...) - if err != nil { - msg.log.Errorfln("Failed to query poll option hashes for %s: %v", msg.MXID, err) - } else { - for rows.Next() { - id, hash, err := scanPollOptionMapping(rows) - if err != nil { - msg.log.Errorfln("Failed to scan poll option hash for %s: %v", msg.MXID, err) - break - } - hashes[id] = hash - } - } - return hashes +func (msg *Message) GetPollOptionIDs(ctx context.Context, hashes [][]byte) (map[[32]byte]string, error) { + return getPollOptions( + ctx, msg, getPollOptionIDsByHashesQuery, hashes, + func(t *pollOption) ([32]byte, string) { return t.hash, t.id }, + ) +} + +func (msg *Message) GetPollOptionHashes(ctx context.Context, ids []string) (map[string][32]byte, error) { + return getPollOptions( + ctx, msg, getPollOptionHashesByIDsQuery, ids, + func(t *pollOption) (string, [32]byte) { return t.id, t.hash }, + ) } diff --git a/database/portal.go b/database/portal.go index 1b3eb00..72156db 100644 --- a/database/portal.go +++ b/database/portal.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2021 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,15 +17,14 @@ package database import ( + "context" "database/sql" - "fmt" "time" - "go.mau.fi/util/dbutil" "go.mau.fi/whatsmeow/types" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" + + "go.mau.fi/util/dbutil" ) type PortalKey struct { @@ -53,90 +52,89 @@ func (key PortalKey) String() string { } type PortalQuery struct { - db *Database - log log.Logger + *dbutil.QueryHelper[*Portal] } -func (pq *PortalQuery) New() *Portal { +func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal { return &Portal{ - db: pq.db, - log: pq.log, + qh: qh, } } -const portalColumns = "jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id, relay_user_id, expiration_time" - -func (pq *PortalQuery) GetAll() []*Portal { - return pq.getAll(fmt.Sprintf("SELECT %s FROM portal", portalColumns)) -} - -func (pq *PortalQuery) GetByJID(key PortalKey) *Portal { - return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1 AND receiver=$2", portalColumns), key.JID, key.Receiver) -} - -func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { - return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns), mxid) -} - -func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal { - return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1", portalColumns), jid.ToNonAD()) -} - -func (pq *PortalQuery) GetAllByParentGroup(jid types.JID) []*Portal { - return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE parent_group=$1", portalColumns), jid) -} - -func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal { - return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1 AND jid LIKE '%%@s.whatsapp.net'", portalColumns), receiver.ToNonAD()) -} - -func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) { - receiver = receiver.ToNonAD() - rows, err := pq.db.Query(` +const ( + getAllPortalsQuery = ` + SELECT jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, + encrypted, last_sync, is_parent, parent_group, in_space, + first_event_id, next_batch_id, relay_user_id, expiration_time + FROM portal + ` + getPortalByJIDQuery = getAllPortalsQuery + " WHERE jid=$1 AND receiver=$2" + getPortalByMXIDQuery = getAllPortalsQuery + " WHERE mxid=$1" + getPrivateChatsWithQuery = getAllPortalsQuery + " WHERE jid=$1" + getPrivateChatsOfQuery = getAllPortalsQuery + " WHERE receiver=$1" + getAllPortalsByParentGroupQuery = getAllPortalsQuery + " WHERE parent_group=$1" + findPrivateChatPortalsNotInSpaceQuery = ` SELECT jid FROM portal LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver WHERE mxid<>'' AND receiver=$1 AND (user_portal.in_space=false OR user_portal.in_space IS NULL) - `, receiver) - if err != nil { - pq.log.Errorfln("Failed to find private chats not in space for %s: %v", receiver, err) - return - } else if rows == nil { - return - } - for rows.Next() { - var key PortalKey + ` + + insertPortalQuery = ` + INSERT INTO portal ( + jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, + encrypted, last_sync, is_parent, parent_group, in_space, + first_event_id, next_batch_id, relay_user_id, expiration_time + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19) + ` + updatePortalQuery = ` + UPDATE portal + SET mxid=$3, name=$4, name_set=$5, topic=$6, topic_set=$7, avatar=$8, avatar_url=$9, avatar_set=$10, + encrypted=$11, last_sync=$12, is_parent=$13, parent_group=$14, in_space=$15, + first_event_id=$16, next_batch_id=$17, relay_user_id=$18, expiration_time=$19 + WHERE jid=$1 AND receiver=$2 + ` + clearPortalInSpaceQuery = "UPDATE portal SET in_space=false WHERE parent_group=$1" + deletePortalQuery = "DELETE FROM portal WHERE jid=$1 AND receiver=$2" +) + +func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllPortalsQuery) +} + +func (pq *PortalQuery) GetByJID(ctx context.Context, key PortalKey) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByJIDQuery, key.JID, key.Receiver) +} + +func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByMXIDQuery, mxid) +} + +func (pq *PortalQuery) GetAllByJID(ctx context.Context, jid types.JID) ([]*Portal, error) { + return pq.QueryMany(ctx, getPrivateChatsWithQuery, jid.ToNonAD()) +} + +func (pq *PortalQuery) FindPrivateChats(ctx context.Context, receiver types.JID) ([]*Portal, error) { + return pq.QueryMany(ctx, getPrivateChatsOfQuery, receiver.ToNonAD()) +} + +func (pq *PortalQuery) GetAllByParentGroup(ctx context.Context, jid types.JID) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllPortalsByParentGroupQuery, jid) +} + +func (pq *PortalQuery) FindPrivateChatsNotInSpace(ctx context.Context, receiver types.JID) (keys []PortalKey, err error) { + receiver = receiver.ToNonAD() + scanFn := func(rows dbutil.Scannable) (key PortalKey, err error) { key.Receiver = receiver err = rows.Scan(&key.JID) - if err == nil { - keys = append(keys, key) - } + return } - return -} - -func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) { - rows, err := pq.db.Query(query, args...) - if err != nil || rows == nil { - return nil - } - defer rows.Close() - for rows.Next() { - portals = append(portals, pq.New().Scan(rows)) - } - return -} - -func (pq *PortalQuery) get(query string, args ...interface{}) *Portal { - row := pq.db.QueryRow(query, args...) - if row == nil { - return nil - } - return pq.New().Scan(row) + return dbutil.ConvertRowFn[PortalKey](scanFn). + NewRowIter(pq.GetDB().Query(ctx, findPrivateChatPortalsNotInSpaceQuery, receiver)). + AsList() } type Portal struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*Portal] Key PortalKey MXID id.RoomID @@ -161,15 +159,17 @@ type Portal struct { ExpirationTime uint32 } -func (portal *Portal) Scan(row dbutil.Scannable) *Portal { +func (portal *Portal) Scan(row dbutil.Scannable) (*Portal, error) { var mxid, avatarURL, firstEventID, nextBatchID, relayUserID, parentGroupJID sql.NullString var lastSyncTs int64 - err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet, &portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted, &lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime) + err := row.Scan( + &portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.NameSet, + &portal.Topic, &portal.TopicSet, &portal.Avatar, &avatarURL, &portal.AvatarSet, &portal.Encrypted, + &lastSyncTs, &portal.IsParent, &parentGroupJID, &portal.InSpace, + &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime, + ) if err != nil { - if err != sql.ErrNoRows { - portal.log.Errorln("Database scan failed:", err) - } - return nil + return nil, err } if lastSyncTs > 0 { portal.LastSync = time.Unix(lastSyncTs, 0) @@ -182,96 +182,36 @@ func (portal *Portal) Scan(row dbutil.Scannable) *Portal { portal.FirstEventID = id.EventID(firstEventID.String) portal.NextBatchID = id.BatchID(nextBatchID.String) portal.RelayUserID = id.UserID(relayUserID.String) - return portal + return portal, nil } -func (portal *Portal) mxidPtr() *id.RoomID { - if len(portal.MXID) > 0 { - return &portal.MXID +func (portal *Portal) sqlVariables() []any { + var lastSyncTS int64 + if !portal.LastSync.IsZero() { + lastSyncTS = portal.LastSync.Unix() } - return nil -} - -func (portal *Portal) relayUserPtr() *id.UserID { - if len(portal.RelayUserID) > 0 { - return &portal.RelayUserID - } - return nil -} - -func (portal *Portal) parentGroupPtr() *string { - if !portal.ParentGroup.IsEmpty() { - val := portal.ParentGroup.String() - return &val - } - return nil -} - -func (portal *Portal) lastSyncTs() int64 { - if portal.LastSync.IsZero() { - return 0 - } - return portal.LastSync.Unix() -} - -func (portal *Portal) Insert() { - _, err := portal.db.Exec(` - INSERT INTO portal (jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, - encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id, - relay_user_id, expiration_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19) - `, - portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet, - portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(), - portal.IsParent, portal.parentGroupPtr(), portal.InSpace, portal.FirstEventID.String(), portal.NextBatchID.String(), - portal.relayUserPtr(), portal.ExpirationTime) - if err != nil { - portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err) + return []any{ + portal.Key.JID, portal.Key.Receiver, dbutil.StrPtr(portal.MXID), portal.Name, portal.NameSet, + portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(), portal.AvatarSet, portal.Encrypted, + lastSyncTS, portal.IsParent, dbutil.StrPtr(portal.ParentGroup.String()), portal.InSpace, + portal.FirstEventID.String(), portal.NextBatchID.String(), dbutil.StrPtr(portal.RelayUserID), portal.ExpirationTime, } } -func (portal *Portal) Update(txn dbutil.Execable) { - if txn == nil { - txn = portal.db - } - _, err := txn.Exec(` - UPDATE portal - SET mxid=$1, name=$2, name_set=$3, topic=$4, topic_set=$5, avatar=$6, avatar_url=$7, avatar_set=$8, - encrypted=$9, last_sync=$10, is_parent=$11, parent_group=$12, in_space=$13, - first_event_id=$14, next_batch_id=$15, relay_user_id=$16, expiration_time=$17 - WHERE jid=$18 AND receiver=$19 - `, portal.mxidPtr(), portal.Name, portal.NameSet, portal.Topic, portal.TopicSet, portal.Avatar, portal.AvatarURL.String(), - portal.AvatarSet, portal.Encrypted, portal.lastSyncTs(), portal.IsParent, portal.parentGroupPtr(), portal.InSpace, - portal.FirstEventID.String(), portal.NextBatchID.String(), portal.relayUserPtr(), portal.ExpirationTime, - portal.Key.JID, portal.Key.Receiver) - if err != nil { - portal.log.Warnfln("Failed to update %s: %v", portal.Key, err) - } +func (portal *Portal) Insert(ctx context.Context) error { + return portal.qh.Exec(ctx, insertPortalQuery, portal.sqlVariables()...) } -func (portal *Portal) Delete() { - txn, err := portal.db.Begin() - if err != nil { - portal.log.Errorfln("Failed to begin transaction to delete portal %v: %v", portal.Key, err) - return - } - defer func() { +func (portal *Portal) Update(ctx context.Context) error { + return portal.qh.Exec(ctx, updatePortalQuery, portal.sqlVariables()...) +} + +func (portal *Portal) Delete(ctx context.Context) error { + return portal.qh.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { + err := portal.qh.Exec(ctx, clearPortalInSpaceQuery, portal.Key.JID) if err != nil { - err = txn.Rollback() - if err != nil { - portal.log.Warnfln("Failed to rollback failed portal delete transaction: %v", err) - } - } else if err = txn.Commit(); err != nil { - portal.log.Warnfln("Failed to commit portal delete transaction: %v", err) + return err } - }() - _, err = txn.Exec("UPDATE portal SET in_space=false WHERE parent_group=$1", portal.Key.JID) - if err != nil { - portal.log.Warnfln("Failed to mark child groups of %v as not in space: %v", portal.Key.JID, err) - return - } - _, err = txn.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver) - if err != nil { - portal.log.Warnfln("Failed to delete %v: %v", portal.Key, err) - } + return portal.qh.Exec(ctx, deletePortalQuery, portal.Key.JID, portal.Key.Receiver) + }) } diff --git a/database/puppet.go b/database/puppet.go index c490e7e..8b160b8 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2021 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,74 +17,70 @@ package database import ( + "context" "database/sql" "time" - "go.mau.fi/util/dbutil" + "github.com/rs/zerolog" "go.mau.fi/whatsmeow/types" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" + + "go.mau.fi/util/dbutil" ) type PuppetQuery struct { - db *Database - log log.Logger + *dbutil.QueryHelper[*Puppet] } -func (pq *PuppetQuery) New() *Puppet { +func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet { return &Puppet{ - db: pq.db, - log: pq.log, + qh: qh, EnablePresence: true, EnableReceipts: true, } } -func (pq *PuppetQuery) GetAll() (puppets []*Puppet) { - rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet") - if err != nil || rows == nil { - return nil - } - defer rows.Close() - for rows.Next() { - puppets = append(puppets, pq.New().Scan(rows)) - } - return +const ( + getAllPuppetsQuery = ` + SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, + last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts + FROM puppet + ` + getPuppetByJIDQuery = getAllPuppetsQuery + " WHERE username=$1" + getPuppetByCustomMXIDQuery = getAllPuppetsQuery + " WHERE custom_mxid=$1" + getAllPuppetsWithCustomMXIDQuery = getAllPuppetsQuery + " WHERE custom_mxid<>''" + insertPuppetQuery = ` + INSERT INTO puppet (username, avatar, avatar_url, avatar_set, displayname, name_quality, name_set, contact_info_set, + last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + ` + updatePuppetQuery = ` + UPDATE puppet + SET avatar=$2, avatar_url=$3, avatar_set=$4, displayname=$5, name_quality=$6, name_set=$7, contact_info_set=$8, + last_sync=$9, custom_mxid=$10, access_token=$11, next_batch=$12, enable_presence=$13, enable_receipts=$14 + WHERE username=$1 + ` +) + +func (pq *PuppetQuery) GetAll(ctx context.Context) ([]*Puppet, error) { + return pq.QueryMany(ctx, getAllPuppetsQuery) } -func (pq *PuppetQuery) Get(jid types.JID) *Puppet { - row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE username=$1", jid.User) - if row == nil { - return nil - } - return pq.New().Scan(row) +func (pq *PuppetQuery) Get(ctx context.Context, jid types.JID) (*Puppet, error) { + return pq.QueryOne(ctx, getPuppetByJIDQuery, jid.User) } -func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet { - row := pq.db.QueryRow("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid=$1", mxid) - if row == nil { - return nil - } - return pq.New().Scan(row) +func (pq *PuppetQuery) GetByCustomMXID(ctx context.Context, mxid id.UserID) (*Puppet, error) { + return pq.QueryOne(ctx, getPuppetByCustomMXIDQuery, mxid) } -func (pq *PuppetQuery) GetAllWithCustomMXID() (puppets []*Puppet) { - rows, err := pq.db.Query("SELECT username, avatar, avatar_url, displayname, name_quality, name_set, avatar_set, contact_info_set, last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts FROM puppet WHERE custom_mxid<>''") - if err != nil || rows == nil { - return nil - } - defer rows.Close() - for rows.Next() { - puppets = append(puppets, pq.New().Scan(rows)) - } - return +func (pq *PuppetQuery) GetAllWithCustomMXID(ctx context.Context) ([]*Puppet, error) { + return pq.QueryMany(ctx, getAllPuppetsWithCustomMXIDQuery) } type Puppet struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*Puppet] JID types.JID Avatar string @@ -103,17 +99,14 @@ type Puppet struct { EnableReceipts bool } -func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet { +func (puppet *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) { var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString var quality, lastSync sql.NullInt64 var enablePresence, enableReceipts, nameSet, avatarSet, contactInfoSet sql.NullBool var username string err := row.Scan(&username, &avatar, &avatarURL, &displayname, &quality, &nameSet, &avatarSet, &contactInfoSet, &lastSync, &customMXID, &accessToken, &nextBatch, &enablePresence, &enableReceipts) if err != nil { - if err != sql.ErrNoRows { - puppet.log.Errorln("Database scan failed:", err) - } - return nil + return nil, err } puppet.JID = types.NewJID(username, types.DefaultUserServer) puppet.Displayname = displayname.String @@ -131,45 +124,30 @@ func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet { puppet.NextBatch = nextBatch.String puppet.EnablePresence = enablePresence.Bool puppet.EnableReceipts = enableReceipts.Bool - return puppet + return puppet, nil } -func (puppet *Puppet) Insert() { - if puppet.JID.Server != types.DefaultUserServer { - puppet.log.Warnfln("Not inserting %s: not a user", puppet.JID) - return - } - var lastSyncTs int64 +func (puppet *Puppet) sqlVariables() []any { + var lastSyncTS int64 if !puppet.LastSync.IsZero() { - lastSyncTs = puppet.LastSync.Unix() + lastSyncTS = puppet.LastSync.Unix() } - _, err := puppet.db.Exec(` - INSERT INTO puppet (username, avatar, avatar_url, avatar_set, displayname, name_quality, name_set, contact_info_set, - last_sync, custom_mxid, access_token, next_batch, enable_presence, enable_receipts) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) - `, puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.Displayname, - puppet.NameQuality, puppet.NameSet, puppet.ContactInfoSet, lastSyncTs, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, + return []any{ + puppet.JID.User, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.Displayname, + puppet.NameQuality, puppet.NameSet, puppet.ContactInfoSet, lastSyncTS, + puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.EnablePresence, puppet.EnableReceipts, - ) - if err != nil { - puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err) } } -func (puppet *Puppet) Update() { - var lastSyncTs int64 - if !puppet.LastSync.IsZero() { - lastSyncTs = puppet.LastSync.Unix() - } - _, err := puppet.db.Exec(` - UPDATE puppet - SET displayname=$1, name_quality=$2, name_set=$3, avatar=$4, avatar_url=$5, avatar_set=$6, contact_info_set=$7, last_sync=$8, - custom_mxid=$9, access_token=$10, next_batch=$11, enable_presence=$12, enable_receipts=$13 - WHERE username=$14 - `, puppet.Displayname, puppet.NameQuality, puppet.NameSet, puppet.Avatar, puppet.AvatarURL.String(), puppet.AvatarSet, puppet.ContactInfoSet, - lastSyncTs, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.EnablePresence, puppet.EnableReceipts, - puppet.JID.User) - if err != nil { - puppet.log.Warnfln("Failed to update %s: %v", puppet.JID, err) +func (puppet *Puppet) Insert(ctx context.Context) error { + if puppet.JID.Server != types.DefaultUserServer { + zerolog.Ctx(ctx).Warn().Stringer("jid", puppet.JID).Msg("Not inserting puppet: not a user") + return nil } + return puppet.qh.Exec(ctx, insertPuppetQuery, puppet.sqlVariables()...) +} + +func (puppet *Puppet) Update(ctx context.Context) error { + return puppet.qh.Exec(ctx, updatePuppetQuery, puppet.sqlVariables()...) } diff --git a/database/reaction.go b/database/reaction.go index 1769358..b74ead5 100644 --- a/database/reaction.go +++ b/database/reaction.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,26 +17,20 @@ package database import ( - "database/sql" - "errors" + "context" + + "go.mau.fi/whatsmeow/types" + "maunium.net/go/mautrix/id" "go.mau.fi/util/dbutil" - "go.mau.fi/whatsmeow/types" - log "maunium.net/go/maulogger/v2" - - "maunium.net/go/mautrix/id" ) type ReactionQuery struct { - db *Database - log log.Logger + *dbutil.QueryHelper[*Reaction] } -func (rq *ReactionQuery) New() *Reaction { - return &Reaction{ - db: rq.db, - log: rq.log, - } +func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction { + return &Reaction{qh: qh} } const ( @@ -55,28 +49,20 @@ const ( DO UPDATE SET mxid=excluded.mxid, jid=excluded.jid ` deleteReactionQuery = ` - DELETE FROM reaction WHERE chat_jid=$1 AND chat_receiver=$2 AND target_jid=$3 AND sender=$4 AND mxid=$5 + DELETE FROM reaction WHERE chat_jid=$1 AND chat_receiver=$2 AND target_jid=$3 AND sender=$4 ` ) -func (rq *ReactionQuery) GetByTargetJID(chat PortalKey, jid types.MessageID, sender types.JID) *Reaction { - return rq.maybeScan(rq.db.QueryRow(getReactionByTargetJIDQuery, chat.JID, chat.Receiver, jid, sender.ToNonAD())) +func (rq *ReactionQuery) GetByTargetJID(ctx context.Context, chat PortalKey, jid types.MessageID, sender types.JID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByTargetJIDQuery, chat.JID, chat.Receiver, jid, sender.ToNonAD()) } -func (rq *ReactionQuery) GetByMXID(mxid id.EventID) *Reaction { - return rq.maybeScan(rq.db.QueryRow(getReactionByMXIDQuery, mxid)) -} - -func (rq *ReactionQuery) maybeScan(row *sql.Row) *Reaction { - if row == nil { - return nil - } - return rq.New().Scan(row) +func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid) } type Reaction struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*Reaction] Chat PortalKey TargetJID types.MessageID @@ -85,35 +71,19 @@ type Reaction struct { JID types.MessageID } -func (reaction *Reaction) Scan(row dbutil.Scannable) *Reaction { - err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - reaction.log.Errorln("Database scan failed:", err) - } - return nil - } - return reaction +func (reaction *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { + return dbutil.ValueOrErr(reaction, row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID)) } -func (reaction *Reaction) Upsert(txn dbutil.Execable) { +func (reaction *Reaction) sqlVariables() []any { reaction.Sender = reaction.Sender.ToNonAD() - if txn == nil { - txn = reaction.db - } - _, err := txn.Exec(upsertReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID, reaction.JID) - if err != nil { - reaction.log.Warnfln("Failed to upsert reaction to %s@%s by %s: %v", reaction.Chat, reaction.TargetJID, reaction.Sender, err) - } + return []any{reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID, reaction.JID} } -func (reaction *Reaction) GetTarget() *Message { - return reaction.db.Message.GetByJID(reaction.Chat, reaction.TargetJID) +func (reaction *Reaction) Upsert(ctx context.Context) error { + return reaction.qh.Exec(ctx, upsertReactionQuery, reaction.sqlVariables()...) } -func (reaction *Reaction) Delete() { - _, err := reaction.db.Exec(deleteReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender, reaction.MXID) - if err != nil { - reaction.log.Warnfln("Failed to delete reaction %s: %v", reaction.MXID, err) - } +func (reaction *Reaction) Delete(ctx context.Context) error { + return reaction.qh.Exec(ctx, deleteReactionQuery, reaction.Chat.JID, reaction.Chat.Receiver, reaction.TargetJID, reaction.Sender) } diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index 990f496..f8015a4 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -17,6 +17,7 @@ package upgrades import ( + "context" "embed" "errors" @@ -29,7 +30,7 @@ var Table dbutil.UpgradeTable var rawUpgrades embed.FS func init() { - Table.Register(-1, 35, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error { + Table.Register(-1, 35, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error { return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version") }) Table.RegisterFS(rawUpgrades) diff --git a/database/user.go b/database/user.go index a4ebc35..1eb1cb7 100644 --- a/database/user.go +++ b/database/user.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,63 +17,65 @@ package database import ( + "context" "database/sql" "sync" "time" - "go.mau.fi/util/dbutil" "go.mau.fi/whatsmeow/types" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" + + "go.mau.fi/util/dbutil" ) type UserQuery struct { - db *Database - log log.Logger + *dbutil.QueryHelper[*User] } -func (uq *UserQuery) New() *User { +func newUser(qh *dbutil.QueryHelper[*User]) *User { return &User{ - db: uq.db, - log: uq.log, + qh: qh, lastReadCache: make(map[PortalKey]time.Time), inSpaceCache: make(map[PortalKey]bool), } } -func (uq *UserQuery) GetAll() (users []*User) { - rows, err := uq.db.Query(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"`) - if err != nil || rows == nil { - return nil - } - defer rows.Close() - for rows.Next() { - users = append(users, uq.New().Scan(rows)) - } - return +const ( + getAllUsersQuery = `SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user"` + getUserByMXIDQuery = getAllUsersQuery + ` WHERE mxid=$1` + getUserByUsernameQuery = getAllUsersQuery + ` WHERE username=$1` + insertUserQuery = ` + INSERT INTO "user" ( + mxid, username, agent, device, + management_room, space_room, + phone_last_seen, phone_last_pinged, timezone + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ` + updateUserQuery = ` + UPDATE "user" + SET username=$1, agent=$2, device=$3, + management_room=$4, space_room=$5, + phone_last_seen=$6, phone_last_pinged=$7, timezone=$8 + WHERE mxid=$9 + ` + getUserLastAppStateKeyIDQuery = "SELECT key_id FROM whatsmeow_app_state_sync_keys WHERE jid=$1 ORDER BY timestamp DESC LIMIT 1" +) + +func (uq *UserQuery) GetAll(ctx context.Context) ([]*User, error) { + return uq.QueryMany(ctx, getAllUsersQuery) } -func (uq *UserQuery) GetByMXID(userID id.UserID) *User { - row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE mxid=$1`, userID) - if row == nil { - return nil - } - return uq.New().Scan(row) +func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) { + return uq.QueryOne(ctx, getUserByMXIDQuery, userID) } -func (uq *UserQuery) GetByUsername(username string) *User { - row := uq.db.QueryRow(`SELECT mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone FROM "user" WHERE username=$1`, username) - if row == nil { - return nil - } - return uq.New().Scan(row) +func (uq *UserQuery) GetByUsername(ctx context.Context, username string) (*User, error) { + return uq.QueryOne(ctx, getUserByUsernameQuery, username) } type User struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*User] MXID id.UserID JID types.JID @@ -89,20 +91,21 @@ type User struct { inSpaceCacheLock sync.Mutex } -func (user *User) Scan(row dbutil.Scannable) *User { +func (user *User) Scan(row dbutil.Scannable) (*User, error) { var username, timezone sql.NullString - var device, agent sql.NullByte + var device, agent sql.NullInt16 var phoneLastSeen, phoneLastPinged sql.NullInt64 err := row.Scan(&user.MXID, &username, &agent, &device, &user.ManagementRoom, &user.SpaceRoom, &phoneLastSeen, &phoneLastPinged, &timezone) if err != nil { - if err != sql.ErrNoRows { - user.log.Errorln("Database scan failed:", err) - } - return nil + return nil, err } user.Timezone = timezone.String if len(username.String) > 0 { - user.JID = types.NewADJID(username.String, agent.Byte, device.Byte) + user.JID = types.JID{ + User: username.String, + Device: uint16(device.Int16), + Server: types.DefaultUserServer, + } } if phoneLastSeen.Valid { user.PhoneLastSeen = time.Unix(phoneLastSeen.Int64, 0) @@ -110,66 +113,34 @@ func (user *User) Scan(row dbutil.Scannable) *User { if phoneLastPinged.Valid { user.PhoneLastPinged = time.Unix(phoneLastPinged.Int64, 0) } - return user + return user, nil } -func (user *User) usernamePtr() *string { +func (user *User) sqlVariables() []any { + var username *string + var agent, device *uint16 if !user.JID.IsEmpty() { - return &user.JID.User + username = dbutil.StrPtr(user.JID.User) + var zero uint16 + agent = &zero + device = dbutil.NumPtr(user.JID.Device) } - return nil -} - -func (user *User) agentPtr() *uint8 { - if !user.JID.IsEmpty() { - zero := uint8(0) - return &zero - } - return nil -} - -func (user *User) devicePtr() *uint8 { - if !user.JID.IsEmpty() { - device := uint8(user.JID.Device) - return &device - } - return nil -} - -func (user *User) phoneLastSeenPtr() *int64 { - if user.PhoneLastSeen.IsZero() { - return nil - } - ts := user.PhoneLastSeen.Unix() - return &ts -} - -func (user *User) phoneLastPingedPtr() *int64 { - if user.PhoneLastPinged.IsZero() { - return nil - } - ts := user.PhoneLastPinged.Unix() - return &ts -} - -func (user *User) Insert() { - _, err := user.db.Exec(`INSERT INTO "user" (mxid, username, agent, device, management_room, space_room, phone_last_seen, phone_last_pinged, timezone) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`, - user.MXID, user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone) - if err != nil { - user.log.Warnfln("Failed to insert %s: %v", user.MXID, err) + return []any{ + username, agent, device, user.ManagementRoom, user.SpaceRoom, + dbutil.UnixPtr(user.PhoneLastSeen), dbutil.UnixPtr(user.PhoneLastPinged), + user.Timezone, user.MXID, } } -func (user *User) Update() { - _, err := user.db.Exec(`UPDATE "user" SET username=$1, agent=$2, device=$3, management_room=$4, space_room=$5, phone_last_seen=$6, phone_last_pinged=$7, timezone=$8 WHERE mxid=$9`, - user.usernamePtr(), user.agentPtr(), user.devicePtr(), user.ManagementRoom, user.SpaceRoom, user.phoneLastSeenPtr(), user.phoneLastPingedPtr(), user.Timezone, user.MXID) - if err != nil { - user.log.Warnfln("Failed to update %s: %v", user.MXID, err) - } +func (user *User) Insert(ctx context.Context) error { + return user.qh.Exec(ctx, insertUserQuery, user.sqlVariables()...) } -func (user *User) GetLastAppStateKeyID() ([]byte, error) { - var keyID []byte - err := user.db.QueryRow("SELECT key_id FROM whatsmeow_app_state_sync_keys ORDER BY timestamp DESC LIMIT 1").Scan(&keyID) - return keyID, err +func (user *User) Update(ctx context.Context) error { + return user.qh.Exec(ctx, updateUserQuery, user.sqlVariables()...) +} + +func (user *User) GetLastAppStateKeyID(ctx context.Context) (keyID []byte, err error) { + err = user.qh.GetDB().QueryRow(ctx, getUserLastAppStateKeyIDQuery, user.JID).Scan(&keyID) + return } diff --git a/database/userportal.go b/database/userportal.go index be339df..ff1a96a 100644 --- a/database/userportal.go +++ b/database/userportal.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2021 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,69 +17,97 @@ package database import ( + "context" "database/sql" "errors" "time" + + "github.com/rs/zerolog" ) -func (user *User) GetLastReadTS(portal PortalKey) time.Time { +const ( + getLastReadTSQuery = "SELECT last_read_ts FROM user_portal WHERE user_mxid=$1 AND portal_jid=$2 AND portal_receiver=$3" + setLastReadTSQuery = ` + INSERT INTO user_portal (user_mxid, portal_jid, portal_receiver, last_read_ts) VALUES ($1, $2, $3, $4) + ON CONFLICT (user_mxid, portal_jid, portal_receiver) DO UPDATE SET last_read_ts=excluded.last_read_ts WHERE user_portal.last_read_ts 0 { + member, err := formatter.bridge.StateStore.GetMember(ctx, roomID, user.MXID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", roomID). + Stringer("user_id", user.MXID). + Msg("Failed to get member profile from state store") + } else if len(member.Displayname) > 0 { displayname = member.Displayname } } return } -func (formatter *Formatter) ParseWhatsApp(roomID id.RoomID, content *event.MessageEventContent, mentionedJIDs []string, allowInlineURL, forceHTML bool) { +func (formatter *Formatter) ParseWhatsApp(ctx context.Context, roomID id.RoomID, content *event.MessageEventContent, mentionedJIDs []string, allowInlineURL, forceHTML bool) { output := html.EscapeString(content.Body) for regex, replacement := range formatter.waReplString { output = regex.ReplaceAllString(output, replacement) @@ -145,7 +152,7 @@ func (formatter *Formatter) ParseWhatsApp(roomID id.RoomID, content *event.Messa // TODO lid support? continue } - mxid, displayname := formatter.getMatrixInfoByJID(roomID, jid) + mxid, displayname := formatter.getMatrixInfoByJID(ctx, roomID, jid) number := "@" + jid.User output = strings.ReplaceAll(output, number, fmt.Sprintf(`%s`, mxid, displayname)) content.Body = strings.ReplaceAll(content.Body, number, displayname) diff --git a/go.mod b/go.mod index 6135da5..fe05c3e 100644 --- a/go.mod +++ b/go.mod @@ -1,25 +1,25 @@ module maunium.net/go/mautrix-whatsapp -go 1.20 +go 1.21 require ( + github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.19 - github.com/prometheus/client_golang v1.17.0 - github.com/rs/zerolog v1.31.0 + github.com/mattn/go-sqlite3 v1.14.22 + github.com/prometheus/client_golang v1.19.0 + github.com/rs/zerolog v1.32.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/tidwall/gjson v1.17.0 - go.mau.fi/util v0.2.1 + github.com/tidwall/gjson v1.17.1 + go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e go.mau.fi/webp v0.1.0 - go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735 - golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611 - golang.org/x/image v0.14.0 - golang.org/x/net v0.19.0 - google.golang.org/protobuf v1.31.0 - maunium.net/go/maulogger/v2 v2.4.1 - maunium.net/go/mautrix v0.16.2 + go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462 + golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 + golang.org/x/image v0.15.0 + golang.org/x/net v0.22.0 + google.golang.org/protobuf v1.33.0 + maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa ) require ( @@ -27,24 +27,27 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect - github.com/golang/protobuf v1.5.3 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect - github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect - github.com/prometheus/common v0.44.0 // indirect - github.com/prometheus/procfs v0.11.1 // indirect + github.com/prometheus/client_model v0.5.0 // indirect + github.com/prometheus/common v0.48.0 // indirect + github.com/prometheus/procfs v0.12.0 // indirect + github.com/rs/xid v1.5.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/sjson v1.2.5 // indirect - github.com/yuin/goldmark v1.6.0 // indirect + github.com/yuin/goldmark v1.7.0 // indirect go.mau.fi/libsignal v0.1.0 // indirect go.mau.fi/zeroconfig v0.1.2 // indirect - golang.org/x/crypto v0.16.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect maunium.net/go/mauflag v1.0.0 // indirect ) + +//replace go.mau.fi/util => ../../Go/go-util +//replace maunium.net/go/mautrix => ../mautrix-go diff --git a/go.sum b/go.sum index f568af8..7ec33bd 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,9 @@ filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek= filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= -github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c h1:WqjRVgUO039eiISCjsZC4F9onOEV93DJAk6v33rsZzY= +github.com/beeper/libserv v0.0.0-20231231202820-c7303abfc32c/go.mod h1:b9FFm9y4mEm36G8ytVmS1vkNzJa0KepmcdVY+qf7qRU= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= @@ -9,18 +12,18 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8 github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= @@ -30,78 +33,75 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= -github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= -github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= -github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q= -github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY= -github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM= -github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU= -github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY= -github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY= -github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI= -github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= +github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= +github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= +github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= +github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= +github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= +github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= -github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= +github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= -github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= +github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= -github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA= +github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= go.mau.fi/libsignal v0.1.0 h1:vAKI/nJ5tMhdzke4cTK1fb0idJzz1JuEIpmjprueC+c= go.mau.fi/libsignal v0.1.0/go.mod h1:R8ovrTezxtUNzCQE5PH30StOQWWeBskBsWE55vMfY9I= -go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw= -go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c= +go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e h1:e1jDj/MjleSS5r9DMRbuCZYKy5Rr+sbsu8eWjtLqrGk= +go.mau.fi/util v0.4.1-0.20240311141448-53cb04950f7e/go.mod h1:jOAREC/go8T6rGic01cu6WRa90xi9U4z3QmDjRf8xpo= go.mau.fi/webp v0.1.0 h1:BHObH/DcFntT9KYun5pDr0Ot4eUZO8k2C7eP7vF4ueA= go.mau.fi/webp v0.1.0/go.mod h1:e42Z+VMFrUMS9cpEwGRIor+lQWO8oUAyPyMtcL+NMt8= -go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735 h1:+teJYCOK6M4Kn2TYCj29levhHVwnJTmgCtEXLtgwQtM= -go.mau.fi/whatsmeow v0.0.0-20231216213200-9d803dd92735/go.mod h1:5xTtHNaZpGni6z6aE1iEopjW7wNgsKcolZxZrOujK9M= +go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462 h1:QOGjCIh2WEfkgX/38KLjnNof79GWx0T+KLrhTHiws3s= +go.mau.fi/whatsmeow v0.0.0-20240311200223-e9bca1903462/go.mod h1:lQHbhaG/fI+6hfGqz5Vzn2OBJBEZ05H0kCP6iJXriN4= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= -golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611 h1:qCEDpW1G+vcj3Y7Fy52pEM1AWm3abj8WimGYejI3SC4= -golang.org/x/exp v0.0.0-20231214170342-aacd6d4b4611/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= -golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= -golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= +golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= +golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= -maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.16.2 h1:a6GUJXNWsTEOO8VE4dROBfCIfPp50mqaqzv7KPzChvg= -maunium.net/go/mautrix v0.16.2/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4= +maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa h1:TLSWIAWKIWxLghgzWfp7o92pVCcFR3yLsArc0s/tsMs= +maunium.net/go/mautrix v0.18.0-beta.1.0.20240311183606-94246ffc85aa/go.mod h1:0sfLB2ejW+lhgio4UlZMmn5i9SuZ8mxFkonFSamrfTE= diff --git a/historysync.go b/historysync.go index 4adb14e..5189bde 100644 --- a/historysync.go +++ b/historysync.go @@ -17,6 +17,7 @@ package main import ( + "context" "crypto/sha256" "encoding/base64" "fmt" @@ -24,11 +25,11 @@ import ( "time" "github.com/rs/zerolog" - "go.mau.fi/util/dbutil" - "go.mau.fi/util/variationselector" waProto "go.mau.fi/whatsmeow/binary/proto" "go.mau.fi/whatsmeow/types" + "go.mau.fi/util/variationselector" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/event" @@ -64,9 +65,8 @@ func (user *User) handleHistorySyncsLoop() { if batchSend { // Start the backfill queue. user.BackfillQueue = &BackfillQueue{ - BackfillQuery: user.bridge.DB.Backfill, + BackfillQuery: user.bridge.DB.BackfillQueue, reCheckChannels: []chan bool{}, - log: user.log.Sub("BackfillQueue"), } forwardAndImmediate := []database.BackfillType{database.BackfillImmediate, database.BackfillForward} @@ -109,80 +109,113 @@ func (user *User) handleHistorySyncsLoop() { const EnqueueBackfillsDelay = 30 * time.Second func (user *User) enqueueAllBackfills() { - nMostRecent := user.bridge.DB.HistorySync.GetRecentConversations(user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations) - if len(nMostRecent) > 0 { - user.log.Infofln("%v has passed since the last history sync blob, enqueueing backfills for %d chats", EnqueueBackfillsDelay, len(nMostRecent)) - // Find the portals for all the conversations. - portals := []*Portal{} - for _, conv := range nMostRecent { - jid, err := types.ParseJID(conv.ConversationID) - if err != nil { - user.log.Warnfln("Failed to parse chat JID '%s' in history sync: %v", conv.ConversationID, err) - continue - } - portals = append(portals, user.GetPortalByJID(jid)) - } - - user.EnqueueImmediateBackfills(portals) - user.EnqueueForwardBackfills(portals) - user.EnqueueDeferredBackfills(portals) - - // Tell the queue to check for new backfill requests. - user.BackfillQueue.ReCheck() + log := user.zlog.With(). + Str("method", "User.enqueueAllBackfills"). + Logger() + ctx := log.WithContext(context.TODO()) + nMostRecent, err := user.bridge.DB.HistorySync.GetRecentConversations(ctx, user.MXID, user.bridge.Config.Bridge.HistorySync.MaxInitialConversations) + if err != nil { + log.Err(err).Msg("Failed to get recent history sync conversations from database") + return + } else if len(nMostRecent) == 0 { + return } + log.Info(). + Int("chat_count", len(nMostRecent)). + Msg("Enqueueing backfills for recent chats in history sync") + // Find the portals for all the conversations. + portals := make([]*Portal, 0, len(nMostRecent)) + for _, conv := range nMostRecent { + jid, err := types.ParseJID(conv.ConversationID) + if err != nil { + log.Err(err).Str("conversation_id", conv.ConversationID).Msg("Failed to parse chat JID in history sync") + continue + } + portals = append(portals, user.GetPortalByJID(jid)) + } + + user.EnqueueImmediateBackfills(ctx, portals) + user.EnqueueForwardBackfills(ctx, portals) + user.EnqueueDeferredBackfills(ctx, portals) + + // Tell the queue to check for new backfill requests. + user.BackfillQueue.ReCheck() } func (user *User) backfillAll() { - conversations := user.bridge.DB.HistorySync.GetRecentConversations(user.MXID, -1) - if len(conversations) > 0 { - user.zlog.Info(). - Int("conversation_count", len(conversations)). - Msg("Probably received all history sync blobs, now backfilling conversations") - limit := user.bridge.Config.Bridge.HistorySync.MaxInitialConversations - bridgedCount := 0 - // Find the portals for all the conversations. - for _, conv := range conversations { - jid, err := types.ParseJID(conv.ConversationID) + log := user.zlog.With(). + Str("method", "User.backfillAll"). + Logger() + ctx := log.WithContext(context.TODO()) + conversations, err := user.bridge.DB.HistorySync.GetRecentConversations(ctx, user.MXID, -1) + if err != nil { + log.Err(err).Msg("Failed to get history sync conversations from database") + return + } else if len(conversations) == 0 { + return + } + log.Info(). + Int("conversation_count", len(conversations)). + Msg("Probably received all history sync blobs, now backfilling conversations") + limit := user.bridge.Config.Bridge.HistorySync.MaxInitialConversations + bridgedCount := 0 + // Find the portals for all the conversations. + for _, conv := range conversations { + jid, err := types.ParseJID(conv.ConversationID) + if err != nil { + log.Err(err). + Str("conversation_id", conv.ConversationID). + Msg("Failed to parse chat JID in history sync") + continue + } + portal := user.GetPortalByJID(jid) + if portal.MXID != "" { + log.Debug(). + Str("portal_jid", portal.Key.JID.String()). + Msg("Chat already has a room, deleting messages from database") + err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String()) if err != nil { - user.zlog.Warn().Err(err). - Str("conversation_id", conv.ConversationID). - Msg("Failed to parse chat JID in history sync") - continue + log.Err(err).Str("portal_jid", portal.Key.JID.String()). + Msg("Failed to delete history sync conversation with existing portal from database") } - portal := user.GetPortalByJID(jid) - if portal.MXID != "" { - user.zlog.Debug(). - Str("portal_jid", portal.Key.JID.String()). - Msg("Chat already has a room, deleting messages from database") - user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String()) - bridgedCount++ - } else if !user.bridge.DB.HistorySync.ConversationHasMessages(user.MXID, portal.Key) { - user.zlog.Debug().Str("portal_jid", portal.Key.JID.String()).Msg("Skipping chat with no messages in history sync") - user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String()) - } else if limit < 0 || bridgedCount < limit { - bridgedCount++ - err = portal.CreateMatrixRoom(user, nil, nil, true, true) - if err != nil { - user.zlog.Err(err).Msg("Failed to create Matrix room for backfill") - } + bridgedCount++ + } else if hasMessages, err := user.bridge.DB.HistorySync.ConversationHasMessages(ctx, user.MXID, portal.Key); err != nil { + log.Err(err).Str("portal_jid", portal.Key.JID.String()).Msg("Failed to check if chat has messages in history sync") + } else if !hasMessages { + log.Debug().Str("portal_jid", portal.Key.JID.String()).Msg("Skipping chat with no messages in history sync") + err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String()) + if err != nil { + log.Err(err).Str("portal_jid", portal.Key.JID.String()). + Msg("Failed to delete history sync conversation with no messages from database") + } + } else if limit < 0 || bridgedCount < limit { + bridgedCount++ + err = portal.CreateMatrixRoom(ctx, user, nil, nil, true, true) + if err != nil { + log.Err(err).Msg("Failed to create Matrix room for backfill") } } } } -func (portal *Portal) legacyBackfill(user *User) { +func (portal *Portal) legacyBackfill(ctx context.Context, user *User) { defer portal.latestEventBackfillLock.Unlock() // This should only be called from CreateMatrixRoom which locks latestEventBackfillLock before creating the room. if portal.latestEventBackfillLock.TryLock() { panic("legacyBackfill() called without locking latestEventBackfillLock") } - // TODO use portal.zlog instead of user.zlog - log := user.zlog.With(). - Str("portal_jid", portal.Key.JID.String()). - Str("action", "legacy backfill"). - Logger() - conv := user.bridge.DB.HistorySync.GetConversation(user.MXID, portal.Key) - messages := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, portal.Key.JID.String(), nil, nil, portal.bridge.Config.Bridge.HistorySync.MessageCount) + log := zerolog.Ctx(ctx).With().Str("action", "legacy backfill").Logger() + ctx = log.WithContext(ctx) + conv, err := user.bridge.DB.HistorySync.GetConversation(ctx, user.MXID, portal.Key) + if err != nil { + log.Err(err).Msg("Failed to get history sync conversation data for backfill") + return + } + messages, err := user.bridge.DB.HistorySync.GetMessagesBetween(ctx, user.MXID, portal.Key.JID.String(), nil, nil, portal.bridge.Config.Bridge.HistorySync.MessageCount) + if err != nil { + log.Err(err).Msg("Failed to get history sync messages for backfill") + return + } log.Debug().Int("message_count", len(messages)).Msg("Got messages to backfill from database") for i := len(messages) - 1; i >= 0; i-- { msgEvt, err := user.Client.ParseWebMessage(portal.Key.JID, messages[i]) @@ -194,18 +227,26 @@ func (portal *Portal) legacyBackfill(user *User) { Msg("Dropping historical message due to parse error") continue } - portal.handleMessage(user, msgEvt, true) + ctx := log.With(). + Str("message_id", msgEvt.Info.ID). + Stringer("message_sender", msgEvt.Info.Sender). + Logger(). + WithContext(ctx) + portal.handleMessage(ctx, user, msgEvt, true) } if conv != nil { isUnread := conv.MarkedAsUnread || conv.UnreadCount > 0 isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour)) shouldMarkAsRead := !isUnread || isTooOld if shouldMarkAsRead { - user.markSelfReadFull(portal) + user.markSelfReadFull(ctx, portal) } } - log.Debug().Msg("Backfill complete, deleting leftover messages from database") - user.bridge.DB.HistorySync.DeleteConversation(user.MXID, portal.Key.JID.String()) + log.Info().Msg("Backfill complete, deleting leftover messages from database") + err = user.bridge.DB.HistorySync.DeleteConversation(ctx, user.MXID, portal.Key.JID.String()) + if err != nil { + log.Err(err).Msg("Failed to delete history sync conversation from database after backfill") + } } func (user *User) dailyMediaRequestLoop() { @@ -224,29 +265,49 @@ func (user *User) dailyMediaRequestLoop() { if requestStartTime.Before(now) { requestStartTime = requestStartTime.AddDate(0, 0, 1) } + log := user.zlog.With(). + Str("action", "daily media request loop"). + Logger() + ctx := log.WithContext(context.Background()) // Wait to start the loop - user.log.Infof("Waiting until %s to do media retry requests", requestStartTime) + log.Info().Time("start_loop_at", requestStartTime).Msg("Waiting until start time to do media retry requests") time.Sleep(time.Until(requestStartTime)) for { - mediaBackfillRequests := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(user.MXID) - user.log.Infof("Sending %d media retry requests", len(mediaBackfillRequests)) + mediaBackfillRequests, err := user.bridge.DB.MediaBackfillRequest.GetMediaBackfillRequestsForUser(ctx, user.MXID) + if err != nil { + log.Err(err).Msg("Failed to get media retry requests") + } else if len(mediaBackfillRequests) > 0 { + log.Info().Int("media_request_count", len(mediaBackfillRequests)).Msg("Sending media retry requests") - // Send all of the media backfill requests for the user at once - for _, req := range mediaBackfillRequests { - portal := user.GetPortalByJID(req.PortalKey.JID) - _, err := portal.requestMediaRetry(user, req.EventID, req.MediaKey) - if err != nil { - user.log.Warnf("Failed to send media retry request for %s / %s", req.PortalKey.String(), req.EventID) - req.Status = database.MediaBackfillRequestStatusRequestFailed - req.Error = err.Error() - } else { - user.log.Debugfln("Sent media retry request for %s / %s", req.PortalKey.String(), req.EventID) - req.Status = database.MediaBackfillRequestStatusRequested + // Send all the media backfill requests for the user at once + for _, req := range mediaBackfillRequests { + portal := user.GetPortalByJID(req.PortalKey.JID) + _, err = portal.requestMediaRetry(ctx, user, req.EventID, req.MediaKey) + if err != nil { + log.Err(err). + Stringer("portal_key", req.PortalKey). + Stringer("event_id", req.EventID). + Msg("Failed to send media retry request") + req.Status = database.MediaBackfillRequestStatusRequestFailed + req.Error = err.Error() + } else { + log.Debug(). + Stringer("portal_key", req.PortalKey). + Stringer("event_id", req.EventID). + Msg("Sent media retry request") + req.Status = database.MediaBackfillRequestStatusRequested + } + req.MediaKey = nil + err = req.Upsert(ctx) + if err != nil { + log.Err(err). + Stringer("portal_key", req.PortalKey). + Stringer("event_id", req.EventID). + Msg("Failed to save status of media retry request") + } } - req.MediaKey = nil - req.Upsert() } // Wait for 24 hours before making requests again @@ -254,20 +315,29 @@ func (user *User) dailyMediaRequestLoop() { } } -func (user *User) backfillInChunks(req *database.Backfill, conv *database.HistorySyncConversation, portal *Portal) { +func (user *User) backfillInChunks(ctx context.Context, req *database.BackfillTask, conv *database.HistorySyncConversation, portal *Portal) { portal.backfillLock.Lock() defer portal.backfillLock.Unlock() + log := zerolog.Ctx(ctx) - if len(portal.MXID) > 0 && !user.bridge.AS.StateStore.IsInRoom(portal.MXID, user.MXID) { - portal.ensureUserInvited(user) + if len(portal.MXID) > 0 && !user.bridge.AS.StateStore.IsInRoom(ctx, portal.MXID, user.MXID) { + portal.ensureUserInvited(ctx, user) } - backfillState := user.bridge.DB.Backfill.GetBackfillState(user.MXID, &portal.Key) + backfillState, err := user.bridge.DB.BackfillState.GetBackfillState(ctx, user.MXID, portal.Key) if backfillState == nil { - backfillState = user.bridge.DB.Backfill.NewBackfillState(user.MXID, &portal.Key) + backfillState = user.bridge.DB.BackfillState.NewBackfillState(user.MXID, portal.Key) } - backfillState.SetProcessingBatch(true) - defer backfillState.SetProcessingBatch(false) + err = backfillState.SetProcessingBatch(ctx, true) + if err != nil { + log.Err(err).Msg("Failed to mark batch as being processed") + } + defer func() { + err = backfillState.SetProcessingBatch(ctx, false) + if err != nil { + log.Err(err).Msg("Failed to mark batch as no longer being processed") + } + }() var timeEnd *time.Time var forward, shouldMarkAsRead bool @@ -275,17 +345,27 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor if req.BackfillType == database.BackfillForward { // TODO this overrides the TimeStart set when enqueuing the backfill // maybe the enqueue should instead include the prev event ID - lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key) + lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key) + if err != nil { + log.Err(err).Msg("Failed to get newest message in chat") + return + } start := lastMessage.Timestamp.Add(1 * time.Second) req.TimeStart = &start // Sending events at the end of the room (= latest events) forward = true } else { - firstMessage := portal.bridge.DB.Message.GetFirstInChat(portal.Key) + firstMessage, err := portal.bridge.DB.Message.GetFirstInChat(ctx, portal.Key) + if err != nil { + log.Err(err).Msg("Failed to get oldest message in chat") + return + } if firstMessage != nil { end := firstMessage.Timestamp.Add(-1 * time.Second) timeEnd = &end - user.log.Debugfln("Limiting backfill to end at %v", end) + log.Debug(). + Time("oldest_message_ts", firstMessage.Timestamp). + Msg("Limiting backfill to messages older than oldest message") } else { // Portal is empty -> events are latest forward = true @@ -303,45 +383,48 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor isTooOld := user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold > 0 && conv.LastMessageTimestamp.Before(time.Now().Add(time.Duration(-user.bridge.Config.Bridge.HistorySync.UnreadHoursThreshold)*time.Hour)) shouldMarkAsRead = !isUnread || isTooOld } - allMsgs := user.bridge.DB.HistorySync.GetMessagesBetween(user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents) + allMsgs, err := user.bridge.DB.HistorySync.GetMessagesBetween(ctx, user.MXID, conv.ConversationID, req.TimeStart, timeEnd, req.MaxTotalEvents) sendDisappearedNotice := false // If expired messages are on, and a notice has not been sent to this chat // about it having disappeared messages at the conversation timestamp, send // a notice indicating so. if len(allMsgs) == 0 && conv.EphemeralExpiration != nil && *conv.EphemeralExpiration > 0 { - lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key) + lastMessage, err := portal.bridge.DB.Message.GetLastInChat(ctx, portal.Key) + if err != nil { + log.Err(err).Msg("Failed to get last message in chat to check if disappeared notice should be sent") + } if lastMessage == nil || conv.LastMessageTimestamp.After(lastMessage.Timestamp) { sendDisappearedNotice = true } } if !sendDisappearedNotice && len(allMsgs) == 0 { - user.log.Debugfln("Not backfilling %s: no bridgeable messages found", portal.Key.JID) + log.Debug().Msg("Not backfilling chat: no bridgeable messages found") return } if len(portal.MXID) == 0 { - user.log.Debugln("Creating portal for", portal.Key.JID, "as part of history sync handling") - err := portal.CreateMatrixRoom(user, nil, nil, true, false) + log.Debug().Msg("Creating portal for chat as part of history sync handling") + err = portal.CreateMatrixRoom(ctx, user, nil, nil, true, false) if err != nil { - user.log.Errorfln("Failed to create room for %s during backfill: %v", portal.Key.JID, err) + log.Err(err).Msg("Failed to create room for chat during backfill") return } } // Update the backfill status here after the room has been created. - portal.updateBackfillStatus(backfillState) + portal.updateBackfillStatus(ctx, backfillState) if sendDisappearedNotice { - user.log.Debugfln("Sending notice to %s that there are disappeared messages ending at %v", portal.Key.JID, conv.LastMessageTimestamp) - resp, err := portal.sendMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{ + log.Debug().Time("last_message_time", conv.LastMessageTimestamp). + Msg("Sending notice that there are disappeared messages in the chat") + resp, err := portal.sendMessage(ctx, portal.MainIntent(), event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgNotice, Body: portal.formatDisappearingMessageNotice(), }, nil, conv.LastMessageTimestamp.UnixMilli()) - if err != nil { - portal.log.Errorln("Error sending disappearing messages notice event") + log.Err(err).Msg("Failed to send disappeared messages notice event") return } @@ -353,12 +436,18 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor msg.SenderMXID = portal.MainIntent().UserID msg.Sent = true msg.Type = database.MsgFake - msg.Insert(nil) - user.markSelfReadFull(portal) + err = msg.Insert(ctx) + if err != nil { + log.Err(err).Msg("Failed to save fake message entry for disappearing message timer in backfill") + } + user.markSelfReadFull(ctx, portal) return } - user.log.Infofln("Backfilling %d messages in %s, %d messages at a time (queue ID: %d)", len(allMsgs), portal.Key.JID, req.MaxBatchEvents, req.QueueID) + log.Info(). + Int("message_count", len(allMsgs)). + Int("max_batch_events", req.MaxBatchEvents). + Msg("Backfilling messages") toBackfill := allMsgs[0:] for len(toBackfill) > 0 { var msgs []*waProto.WebMessageInfo @@ -372,14 +461,14 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor if len(msgs) > 0 { time.Sleep(time.Duration(req.BatchDelay) * time.Second) - user.log.Debugfln("Backfilling %d messages in %s (queue ID: %d)", len(msgs), portal.Key.JID, req.QueueID) - portal.backfill(user, msgs, forward, shouldMarkAsRead) + log.Debug().Int("batch_message_count", len(msgs)).Msg("Backfilling message batch") + portal.backfill(ctx, user, msgs, forward, shouldMarkAsRead) } } - user.log.Debugfln("Finished backfilling %d messages in %s (queue ID: %d)", len(allMsgs), portal.Key.JID, req.QueueID) - err := user.bridge.DB.HistorySync.DeleteMessages(user.MXID, conv.ConversationID, allMsgs) + log.Debug().Int("message_count", len(allMsgs)).Msg("Finished backfilling messages in queue entry") + err = user.bridge.DB.HistorySync.DeleteMessages(ctx, user.MXID, conv.ConversationID, allMsgs) if err != nil { - user.log.Warnfln("Failed to delete %d history sync messages after backfilling (queue ID: %d): %v", len(allMsgs), req.QueueID, err) + log.Err(err).Msg("Failed to delete history sync messages after backfilling") } if req.TimeStart == nil { @@ -399,8 +488,11 @@ func (user *User) backfillInChunks(req *database.Backfill, conv *database.Histor // beginning of time. backfillState.FirstExpectedTimestamp = 0 } - backfillState.Upsert() - portal.updateBackfillStatus(backfillState) + err = backfillState.Upsert(ctx) + if err != nil { + log.Err(err).Msg("Failed to mark backfill state as completed in database") + } + portal.updateBackfillStatus(ctx, backfillState) } } @@ -408,13 +500,13 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) { if evt == nil || evt.SyncType == nil { return } - log := user.bridge.ZLog.With(). + log := user.zlog.With(). Str("method", "User.storeHistorySync"). - Str("user_id", user.MXID.String()). Str("sync_type", evt.GetSyncType().String()). Uint32("chunk_order", evt.GetChunkOrder()). Uint32("progress", evt.GetProgress()). Logger() + ctx := log.WithContext(context.TODO()) if evt.GetGlobalSettings() != nil { log.Debug().Interface("global_settings", evt.GetGlobalSettings()).Msg("Got global settings in history sync") } @@ -466,7 +558,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) { historySyncConversation := user.bridge.DB.HistorySync.NewConversationWithValues( user.MXID, conv.GetId(), - &portal.Key, + portal.Key, getConversationTimestamp(conv), conv.GetMuteEndTime(), conv.GetArchived(), @@ -476,7 +568,10 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) { conv.EphemeralExpiration, conv.GetMarkedAsUnread(), conv.GetUnreadCount()) - historySyncConversation.Upsert() + err := historySyncConversation.Upsert(ctx) + if err != nil { + log.Err(err).Msg("Failed to insert history sync conversation into database") + } } var minTime, maxTime time.Time @@ -521,7 +616,7 @@ func (user *User) storeHistorySync(evt *waProto.HistorySync) { Msg("Failed to save historical message") continue } - err = message.Insert() + err = message.Insert(ctx) if err != nil { log.Error().Err(err). Int("msg_index", i). @@ -570,15 +665,20 @@ func getConversationTimestamp(conv *waProto.Conversation) uint64 { return convTs } -func (user *User) EnqueueImmediateBackfills(portals []*Portal) { +func (user *User) EnqueueImmediateBackfills(ctx context.Context, portals []*Portal) { for priority, portal := range portals { maxMessages := user.bridge.Config.Bridge.HistorySync.Immediate.MaxEvents - initialBackfill := user.bridge.DB.Backfill.NewWithValues(user.MXID, database.BackfillImmediate, priority, &portal.Key, nil, maxMessages, maxMessages, 0) - initialBackfill.Insert() + initialBackfill := user.bridge.DB.BackfillQueue.NewWithValues(user.MXID, database.BackfillImmediate, priority, portal.Key, nil, maxMessages, maxMessages, 0) + err := initialBackfill.Insert(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("portal_key", portal.Key). + Msg("Failed to insert immediate backfill into database") + } } } -func (user *User) EnqueueDeferredBackfills(portals []*Portal) { +func (user *User) EnqueueDeferredBackfills(ctx context.Context, portals []*Portal) { numPortals := len(portals) for stageIdx, backfillStage := range user.bridge.Config.Bridge.HistorySync.Deferred { for portalIdx, portal := range portals { @@ -587,22 +687,36 @@ func (user *User) EnqueueDeferredBackfills(portals []*Portal) { startDaysAgo := time.Now().AddDate(0, 0, -backfillStage.StartDaysAgo) startDate = &startDaysAgo } - backfillMessages := user.bridge.DB.Backfill.NewWithValues( - user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, &portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay) - backfillMessages.Insert() + backfillMessages := user.bridge.DB.BackfillQueue.NewWithValues( + user.MXID, database.BackfillDeferred, stageIdx*numPortals+portalIdx, portal.Key, startDate, backfillStage.MaxBatchEvents, -1, backfillStage.BatchDelay) + err := backfillMessages.Insert(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("portal_key", portal.Key). + Msg("Failed to insert deferred backfill into database") + } } } } -func (user *User) EnqueueForwardBackfills(portals []*Portal) { +func (user *User) EnqueueForwardBackfills(ctx context.Context, portals []*Portal) { for priority, portal := range portals { - lastMsg := user.bridge.DB.Message.GetLastInChat(portal.Key) - if lastMsg == nil { + lastMsg, err := user.bridge.DB.Message.GetLastInChat(ctx, portal.Key) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("portal_key", portal.Key). + Msg("Failed to get last message in chat to enqueue forward backfill") + } else if lastMsg == nil { continue } - backfill := user.bridge.DB.Backfill.NewWithValues( - user.MXID, database.BackfillForward, priority, &portal.Key, &lastMsg.Timestamp, -1, -1, 0) - backfill.Insert() + backfill := user.bridge.DB.BackfillQueue.NewWithValues( + user.MXID, database.BackfillForward, priority, portal.Key, &lastMsg.Timestamp, -1, -1, 0) + err = backfill.Insert(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("portal_key", portal.Key). + Msg("Failed to insert forward backfill into database") + } } } @@ -619,12 +733,11 @@ func (portal *Portal) deterministicEventID(sender types.JID, messageID types.Mes } var ( - PortalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType} - BackfillStatusEvent = event.Type{Type: "com.beeper.backfill_status", Class: event.StateEventType} ) -func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo, isForward, atomicMarkAsRead bool) *mautrix.RespBeeperBatchSend { +func (portal *Portal) backfill(ctx context.Context, source *User, messages []*waProto.WebMessageInfo, isForward, atomicMarkAsRead bool) { + log := zerolog.Ctx(ctx) var req mautrix.ReqBeeperBatchSend var infos []*wrappedInfo @@ -633,7 +746,10 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo, req.MarkReadBy = source.MXID } - portal.log.Infofln("Processing history sync with %d messages (forward: %t)", len(messages), isForward) + log.Info(). + Bool("forward", isForward). + Int("message_count", len(messages)). + Msg("Processing history sync message batch") // The messages are ordered newest to oldest, so iterate them in reverse order. for i := len(messages) - 1; i >= 0; i-- { webMsg := messages[i] @@ -641,11 +757,16 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo, if err != nil { continue } + log := log.With(). + Str("message_id", msgEvt.Info.ID). + Stringer("message_sender", msgEvt.Info.Sender). + Logger() + ctx := log.WithContext(ctx) msgType := getMessageType(msgEvt.Message) if msgType == "unknown" || msgType == "ignore" || msgType == "unknown_protocol" { if msgType != "ignore" { - portal.log.Debugfln("Skipping message %s with unknown type in backfill", msgEvt.Info.ID) + log.Debug().Msg("Skipping message with unknown type in backfill") } continue } @@ -654,85 +775,83 @@ func (portal *Portal) backfill(source *User, messages []*waProto.WebMessageInfo, if !existingContact.Found || existingContact.PushName == "" { changed, _, err := source.Client.Store.Contacts.PutPushName(msgEvt.Info.Sender, webMsg.GetPushName()) if err != nil { - source.log.Errorfln("Failed to save push name of %s from historical message in device store: %v", msgEvt.Info.Sender, err) + log.Err(err).Msg("Failed to save push name from historical message to device store") } else if changed { - source.log.Debugfln("Got push name %s for %s from historical message", webMsg.GetPushName(), msgEvt.Info.Sender) + log.Debug().Str("push_name", webMsg.GetPushName()).Msg("Got push name from historical message") } } } - puppet := portal.getMessagePuppet(source, &msgEvt.Info) + puppet := portal.getMessagePuppet(ctx, source, &msgEvt.Info) if puppet == nil { continue } - converted := portal.convertMessage(puppet.IntentFor(portal), source, &msgEvt.Info, msgEvt.Message, true) + converted := portal.convertMessage(ctx, puppet.IntentFor(portal), source, &msgEvt.Info, msgEvt.Message, true) if converted == nil { - portal.log.Debugfln("Skipping unsupported message %s in backfill", msgEvt.Info.ID) + log.Debug().Msg("Skipping unsupported message in backfill") continue } if converted.ReplyTo != nil { - portal.SetReply(msgEvt.Info.ID, converted.Content, converted.ReplyTo, true) + portal.SetReply(ctx, converted.Content, converted.ReplyTo, true) } - err = portal.appendBatchEvents(source, converted, &msgEvt.Info, webMsg, &req.Events, &infos) + err = portal.appendBatchEvents(ctx, source, converted, &msgEvt.Info, webMsg, &req.Events, &infos) if err != nil { - portal.log.Errorfln("Error handling message %s during backfill: %v", msgEvt.Info.ID, err) + log.Err(err).Msg("Failed to handle message in backfill") } } - portal.log.Infofln("Made %d Matrix events from messages in batch", len(req.Events)) + log.Info().Int("event_count", len(req.Events)).Msg("Made Matrix events from messages in batch") if len(req.Events) == 0 { - return nil + return } - resp, err := portal.MainIntent().BeeperBatchSend(portal.MXID, &req) + resp, err := portal.MainIntent().BeeperBatchSend(ctx, portal.MXID, &req) if err != nil { - portal.log.Errorln("Error batch sending messages:", err) - return nil - } else { - txn, err := portal.bridge.DB.Begin() - if err != nil { - portal.log.Errorln("Failed to start transaction to save batch messages:", err) - return nil - } - - portal.finishBatch(txn, resp.EventIDs, infos) - - err = txn.Commit() - if err != nil { - portal.log.Errorln("Failed to commit transaction to save batch messages:", err) - return nil - } - if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia { - go portal.requestMediaRetries(source, resp.EventIDs, infos) - } - return resp + log.Err(err).Msg("Failed to send batch of messages") + return + } + err = portal.bridge.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + return portal.finishBatch(ctx, resp.EventIDs, infos) + }) + if err != nil { + log.Err(err).Msg("Failed to save message batch to database") + return + } + log.Info().Msg("Successfully sent backfill batch") + if portal.bridge.Config.Bridge.HistorySync.MediaRequests.AutoRequestMedia { + go portal.requestMediaRetries(context.TODO(), source, resp.EventIDs, infos) } } -func (portal *Portal) requestMediaRetries(source *User, eventIDs []id.EventID, infos []*wrappedInfo) { +func (portal *Portal) requestMediaRetries(ctx context.Context, source *User, eventIDs []id.EventID, infos []*wrappedInfo) { for i, info := range infos { if info != nil && info.Error == database.MsgErrMediaNotFound && info.MediaKey != nil { switch portal.bridge.Config.Bridge.HistorySync.MediaRequests.RequestMethod { case config.MediaRequestMethodImmediate: err := source.Client.SendMediaRetryReceipt(info.MessageInfo, info.MediaKey) if err != nil { - portal.log.Warnfln("Failed to send post-backfill media retry request for %s: %v", info.ID, err) + portal.zlog.Err(err).Str("message_id", info.ID).Msg("Failed to send post-backfill media retry request") } else { - portal.log.Debugfln("Sent post-backfill media retry request for %s", info.ID) + portal.zlog.Debug().Str("message_id", info.ID).Msg("Sent post-backfill media retry request") } case config.MediaRequestMethodLocalTime: - req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, &portal.Key, eventIDs[i], info.MediaKey) - req.Upsert() + req := portal.bridge.DB.MediaBackfillRequest.NewMediaBackfillRequestWithValues(source.MXID, portal.Key, eventIDs[i], info.MediaKey) + err := req.Upsert(ctx) + if err != nil { + portal.zlog.Err(err). + Stringer("event_id", eventIDs[i]). + Msg("Failed to upsert media backfill request") + } } } } } -func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessage, info *types.MessageInfo, raw *waProto.WebMessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error { +func (portal *Portal) appendBatchEvents(ctx context.Context, source *User, converted *ConvertedMessage, info *types.MessageInfo, raw *waProto.WebMessageInfo, eventsArray *[]*event.Event, infoArray *[]*wrappedInfo) error { if portal.bridge.Config.Bridge.CaptionInMessage { converted.MergeCaption() } - mainEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Content, converted.Extra, "") + mainEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, converted.Content, converted.Extra, "") if err != nil { return err } @@ -750,7 +869,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag ExpiresIn: converted.ExpiresIn, } if converted.Caption != nil { - captionEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, converted.Caption, nil, "caption") + captionEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, converted.Caption, nil, "caption") if err != nil { return err } @@ -762,7 +881,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag } if converted.MultiEvent != nil { for i, subEvtContent := range converted.MultiEvent { - subEvt, err := portal.wrapBatchEvent(info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i)) + subEvt, err := portal.wrapBatchEvent(ctx, info, converted.Intent, converted.Type, subEvtContent, nil, fmt.Sprintf("multi-%d", i)) if err != nil { return err } @@ -771,7 +890,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag } } for _, reaction := range raw.GetReactions() { - reactionEvent, reactionInfo := portal.wrapBatchReaction(source, reaction, mainEvt.ID, info.Timestamp) + reactionEvent, reactionInfo := portal.wrapBatchReaction(ctx, source, reaction, mainEvt.ID, info.Timestamp) if reactionEvent != nil { *eventsArray = append(*eventsArray, reactionEvent) *infoArray = append(*infoArray, &wrappedInfo{ @@ -785,7 +904,7 @@ func (portal *Portal) appendBatchEvents(source *User, converted *ConvertedMessag return nil } -func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction, mainEventID id.EventID, mainEventTS time.Time) (reactionEvent *event.Event, reactionInfo *types.MessageInfo) { +func (portal *Portal) wrapBatchReaction(ctx context.Context, source *User, reaction *waProto.Reaction, mainEventID id.EventID, mainEventTS time.Time) (reactionEvent *event.Event, reactionInfo *types.MessageInfo) { var senderJID types.JID if reaction.GetKey().GetFromMe() { senderJID = source.JID.ToNonAD() @@ -807,7 +926,7 @@ func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction ID: reaction.GetKey().GetId(), Timestamp: mainEventTS, } - puppet := portal.getMessagePuppet(source, reactionInfo) + puppet := portal.getMessagePuppet(ctx, source, reactionInfo) if puppet == nil { return } @@ -834,12 +953,12 @@ func (portal *Portal) wrapBatchReaction(source *User, reaction *waProto.Reaction return } -func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) { +func (portal *Portal) wrapBatchEvent(ctx context.Context, info *types.MessageInfo, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, partName string) (*event.Event, error) { wrappedContent := event.Content{ Parsed: content, Raw: extraContent, } - newEventType, err := portal.encrypt(intent, &wrappedContent, eventType) + newEventType, err := portal.encrypt(ctx, intent, &wrappedContent, eventType) if err != nil { return nil, err } @@ -853,37 +972,37 @@ func (portal *Portal) wrapBatchEvent(info *types.MessageInfo, intent *appservice }, nil } -func (portal *Portal) finishBatch(txn dbutil.Transaction, eventIDs []id.EventID, infos []*wrappedInfo) { +func (portal *Portal) finishBatch(ctx context.Context, eventIDs []id.EventID, infos []*wrappedInfo) error { for i, info := range infos { if info == nil { continue } eventID := eventIDs[i] - portal.markHandled(txn, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error) + portal.markHandled(ctx, nil, info.MessageInfo, eventID, info.SenderMXID, true, false, info.Type, 0, info.Error) if info.Type == database.MsgReaction { - portal.upsertReaction(txn, nil, info.ReactionTarget, info.Sender, eventID, info.ID) + portal.upsertReaction(ctx, nil, info.ReactionTarget, info.Sender, eventID, info.ID) } if info.ExpiresIn > 0 { - portal.MarkDisappearing(txn, eventID, info.ExpiresIn, info.ExpirationStart) + portal.MarkDisappearing(ctx, eventID, info.ExpiresIn, info.ExpirationStart) } } - portal.log.Infofln("Successfully sent %d events", len(eventIDs)) + return nil } -func (portal *Portal) updateBackfillStatus(backfillState *database.BackfillState) { +func (portal *Portal) updateBackfillStatus(ctx context.Context, backfillState *database.BackfillState) { backfillStatus := "backfilling" if backfillState.BackfillComplete { backfillStatus = "complete" } - _, err := portal.bridge.Bot.SendStateEvent(portal.MXID, BackfillStatusEvent, "", map[string]interface{}{ + _, err := portal.bridge.Bot.SendStateEvent(ctx, portal.MXID, BackfillStatusEvent, "", map[string]interface{}{ "status": backfillStatus, "first_timestamp": backfillState.FirstExpectedTimestamp * 1000, }) if err != nil { - portal.log.Errorln("Error sending backfill status event:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to send backfill status event to room") } } diff --git a/main.go b/main.go index c8958af..6f75fcf 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ package main import ( + "context" _ "embed" "net/http" "net/url" @@ -26,15 +27,18 @@ import ( "sync" "time" + "github.com/rs/zerolog" + waLog "go.mau.fi/whatsmeow/util/log" "google.golang.org/protobuf/proto" - "go.mau.fi/util/configupgrade" "go.mau.fi/whatsmeow" waProto "go.mau.fi/whatsmeow/binary/proto" "go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/store/sqlstore" "go.mau.fi/whatsmeow/types" + "go.mau.fi/util/configupgrade" + "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/bridge/commands" "maunium.net/go/mautrix/bridge/status" @@ -91,7 +95,7 @@ func (br *WABridge) Init() { br.EventProcessor.On(TypeMSC3381PollResponse, br.MatrixHandler.HandleMessage) br.EventProcessor.On(TypeMSC3381V2PollResponse, br.MatrixHandler.HandleMessage) - Analytics.log = br.Log.Sub("Analytics") + Analytics.log = br.ZLog.With().Str("component", "analytics").Logger() Analytics.url = (&url.URL{ Scheme: "https", Host: br.Config.Analytics.Host, @@ -100,23 +104,20 @@ func (br *WABridge) Init() { Analytics.key = br.Config.Analytics.Token Analytics.userID = br.Config.Analytics.UserID if Analytics.IsEnabled() { - Analytics.log.Infoln("Analytics metrics are enabled") - if Analytics.userID != "" { - Analytics.log.Infoln("Overriding analytics user_id with %v", Analytics.userID) - } + Analytics.log.Info().Str("override_user_id", Analytics.userID).Msg("Analytics metrics are enabled") } - br.DB = database.New(br.Bridge.DB, br.Log.Sub("Database")) - br.WAContainer = sqlstore.NewWithDB(br.DB.RawDB, br.DB.Dialect.String(), &waLogger{br.Log.Sub("Database").Sub("WhatsApp")}) + br.DB = database.New(br.Bridge.DB) + br.WAContainer = sqlstore.NewWithDB(br.DB.RawDB, br.DB.Dialect.String(), waLog.Zerolog(br.ZLog.With().Str("db_section", "whatsmeow").Logger())) br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError ss := br.Config.Bridge.Provisioning.SharedSecret if len(ss) > 0 && ss != "disable" { - br.Provisioning = &ProvisioningAPI{bridge: br} + br.Provisioning = &ProvisioningAPI{bridge: br, log: br.ZLog.With().Str("component", "provisioning").Logger()} } br.Formatter = NewFormatter(br) - br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB) + br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.ZLog.With().Str("component", "metrics").Logger(), br.DB) br.MatrixHandler.TrackEventDuration = br.Metrics.TrackMatrixEvent store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion) @@ -148,11 +149,10 @@ func (br *WABridge) Init() { func (br *WABridge) Start() { err := br.WAContainer.Upgrade() if err != nil { - br.Log.Fatalln("Failed to upgrade whatsmeow database: %v", err) + br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to upgrade whatsmeow database") os.Exit(15) } if br.Provisioning != nil { - br.Log.Debugln("Initializing provisioning API") br.Provisioning.Init() } go br.CheckWhatsAppUpdate() @@ -166,30 +166,40 @@ func (br *WABridge) Start() { } func (br *WABridge) CheckWhatsAppUpdate() { - br.Log.Debugfln("Checking for WhatsApp web update") + br.ZLog.Debug().Msg("Checking for WhatsApp web update") resp, err := whatsmeow.CheckUpdate(http.DefaultClient) if err != nil { - br.Log.Warnfln("Failed to check for WhatsApp web update: %v", err) + br.ZLog.Warn().Err(err).Msg("Failed to check for WhatsApp web update") return } if store.GetWAVersion() == resp.ParsedVersion { - br.Log.Debugfln("Bridge is using latest WhatsApp web protocol") + br.ZLog.Debug().Msg("Bridge is using latest WhatsApp web protocol") } else if store.GetWAVersion().LessThan(resp.ParsedVersion) { if resp.IsBelowHard || resp.IsBroken { - br.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) + br.ZLog.Warn(). + Stringer("latest_version", resp.ParsedVersion). + Stringer("current_version", store.GetWAVersion()). + Msg("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore") } else if resp.IsBelowSoft { - br.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) + br.ZLog.Info(). + Stringer("latest_version", resp.ParsedVersion). + Stringer("current_version", store.GetWAVersion()). + Msg("Bridge is using outdated WhatsApp web protocol") } else { - br.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) + br.ZLog.Debug(). + Stringer("latest_version", resp.ParsedVersion). + Stringer("current_version", store.GetWAVersion()). + Msg("Bridge is using outdated WhatsApp web protocol") } } else { - br.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol") + br.ZLog.Debug().Msg("Bridge is using newer than latest WhatsApp web protocol") } } func (br *WABridge) Loop() { + ctx := br.ZLog.With().Str("action", "background loop").Logger().WithContext(context.TODO()) for { - br.SleepAndDeleteUpcoming() + br.SleepAndDeleteUpcoming(ctx) time.Sleep(1 * time.Hour) br.WarnUsersAboutDisconnection() } @@ -199,14 +209,14 @@ func (br *WABridge) WarnUsersAboutDisconnection() { br.usersLock.Lock() for _, user := range br.usersByUsername { if user.IsConnected() && !user.PhoneRecentlySeen(true) { - go user.sendPhoneOfflineWarning() + go user.sendPhoneOfflineWarning(context.TODO()) } } br.usersLock.Unlock() } func (br *WABridge) StartUsers() { - br.Log.Debugln("Starting users") + br.ZLog.Debug().Msg("Starting users") foundAnySessions := false for _, user := range br.GetAllUsers() { if !user.JID.IsEmpty() { @@ -217,13 +227,13 @@ func (br *WABridge) StartUsers() { if !foundAnySessions { br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}.Fill(nil)) } - br.Log.Debugln("Starting custom puppets") + br.ZLog.Debug().Msg("Starting custom puppets") for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() { go func(puppet *Puppet) { - puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID) + puppet.zlog.Debug().Stringer("custom_mxid", puppet.CustomMXID).Msg("Starting double puppet") err := puppet.StartCustomMXID(true) if err != nil { - puppet.log.Errorln("Failed to start custom puppet:", err) + puppet.zlog.Err(err).Stringer("custom_mxid", puppet.CustomMXID).Msg("Failed to start double puppet") } }(loopuppet) } @@ -235,7 +245,7 @@ func (br *WABridge) Stop() { if user.Client == nil { continue } - br.Log.Debugln("Disconnecting", user.MXID) + user.zlog.Debug().Msg("Disconnecting user") user.Client.Disconnect() close(user.historySyncs) } diff --git a/matrix.go b/matrix.go index 57552ad..38fafb8 100644 --- a/matrix.go +++ b/matrix.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,8 +17,10 @@ package main import ( + "context" "fmt" + "github.com/rs/zerolog" "go.mau.fi/whatsmeow/types" "maunium.net/go/mautrix" @@ -35,77 +37,89 @@ func (br *WABridge) CreatePrivatePortal(roomID id.RoomID, brInviter bridge.User, puppet := brGhost.(*Puppet) key := database.NewPortalKey(puppet.JID, inviter.JID) portal := br.GetPortalByJID(key) + log := br.ZLog.With(). + Str("action", "create private portal"). + Stringer("target_room_id", roomID). + Stringer("inviter_mxid", inviter.MXID). + Stringer("invitee_jid", puppet.JID). + Logger() + ctx := log.WithContext(context.TODO()) if len(portal.MXID) == 0 { - br.createPrivatePortalFromInvite(roomID, inviter, puppet, portal) + br.createPrivatePortalFromInvite(ctx, roomID, inviter, puppet, portal) return } - ok := portal.ensureUserInvited(inviter) + ok := portal.ensureUserInvited(ctx, inviter) if !ok { - br.Log.Warnfln("Failed to invite %s to existing private chat portal %s with %s. Redirecting portal to new room...", inviter.MXID, portal.MXID, puppet.JID) - br.createPrivatePortalFromInvite(roomID, inviter, puppet, portal) + log.Warn().Msg("Failed to invite user to existing private chat portal. Redirecting portal to new room...") + br.createPrivatePortalFromInvite(ctx, roomID, inviter, puppet, portal) return } intent := puppet.DefaultIntent() - errorMessage := fmt.Sprintf("You already have a private chat portal with me at [%[1]s](https://matrix.to/#/%[1]s)", portal.MXID) + errorMessage := fmt.Sprintf("You already have a private chat portal with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Config.Homeserver.Domain).MatrixToURL()) errorContent := format.RenderMarkdown(errorMessage, true, false) - _, _ = intent.SendMessageEvent(roomID, event.EventMessage, errorContent) - br.Log.Debugfln("Leaving private chat room %s as %s after accepting invite from %s as we already have chat with the user", roomID, puppet.MXID, inviter.MXID) - _, _ = intent.LeaveRoom(roomID) + _, _ = intent.SendMessageEvent(ctx, roomID, event.EventMessage, errorContent) + log.Debug().Msg("Leaving private chat room from invite as we already have chat with the user") + _, _ = intent.LeaveRoom(ctx, roomID) } -func (br *WABridge) createPrivatePortalFromInvite(roomID id.RoomID, inviter *User, puppet *Puppet, portal *Portal) { +func (br *WABridge) createPrivatePortalFromInvite(ctx context.Context, roomID id.RoomID, inviter *User, puppet *Puppet, portal *Portal) { + log := zerolog.Ctx(ctx) // TODO check if room is already encrypted var existingEncryption event.EncryptionEventContent var encryptionEnabled bool - err := portal.MainIntent().StateEvent(roomID, event.StateEncryption, "", &existingEncryption) + err := portal.MainIntent().StateEvent(ctx, roomID, event.StateEncryption, "", &existingEncryption) if err != nil { - portal.log.Warnfln("Failed to check if encryption is enabled in private chat room %s", roomID) + log.Err(err).Msg("Failed to check if encryption is enabled") } else { encryptionEnabled = existingEncryption.Algorithm == id.AlgorithmMegolmV1 } portal.MXID = roomID + portal.updateLogger() portal.Topic = PrivateChatTopic portal.Name = puppet.Displayname portal.AvatarURL = puppet.AvatarURL portal.Avatar = puppet.Avatar - portal.log.Infofln("Created private chat portal in %s after invite from %s", roomID, inviter.MXID) + log.Info().Msg("Created private chat portal from invite") intent := puppet.DefaultIntent() if br.Config.Bridge.Encryption.Default || encryptionEnabled { - _, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{UserID: br.Bot.UserID}) + _, err = intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{UserID: br.Bot.UserID}) if err != nil { - portal.log.Warnln("Failed to invite bridge bot to enable e2be:", err) + log.Err(err).Msg("Failed to invite bridge bot to enable e2be") } - err = br.Bot.EnsureJoined(roomID) + err = br.Bot.EnsureJoined(ctx, roomID) if err != nil { - portal.log.Warnln("Failed to join as bridge bot to enable e2be:", err) + log.Err(err).Msg("Failed to join as bridge bot to enable e2be") } if !encryptionEnabled { - _, err = intent.SendStateEvent(roomID, event.StateEncryption, "", portal.GetEncryptionEventContent()) + _, err = intent.SendStateEvent(ctx, roomID, event.StateEncryption, "", portal.GetEncryptionEventContent()) if err != nil { - portal.log.Warnln("Failed to enable e2be:", err) + log.Err(err).Msg("Failed to enable e2be") } } - br.AS.StateStore.SetMembership(roomID, inviter.MXID, event.MembershipJoin) - br.AS.StateStore.SetMembership(roomID, puppet.MXID, event.MembershipJoin) - br.AS.StateStore.SetMembership(roomID, br.Bot.UserID, event.MembershipJoin) + br.AS.StateStore.SetMembership(ctx, roomID, inviter.MXID, event.MembershipJoin) + br.AS.StateStore.SetMembership(ctx, roomID, puppet.MXID, event.MembershipJoin) + br.AS.StateStore.SetMembership(ctx, roomID, br.Bot.UserID, event.MembershipJoin) portal.Encrypted = true } - _, _ = portal.MainIntent().SetRoomTopic(portal.MXID, portal.Topic) + _, _ = portal.MainIntent().SetRoomTopic(ctx, portal.MXID, portal.Topic) if portal.shouldSetDMRoomMetadata() { - _, err = portal.MainIntent().SetRoomName(portal.MXID, portal.Name) + _, err = portal.MainIntent().SetRoomName(ctx, portal.MXID, portal.Name) portal.NameSet = err == nil - _, err = portal.MainIntent().SetRoomAvatar(portal.MXID, portal.AvatarURL) + _, err = portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, portal.AvatarURL) portal.AvatarSet = err == nil } - portal.Update(nil) - portal.UpdateBridgeInfo() - _, _ = intent.SendNotice(roomID, "Private chat portal created") + err = portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal to database after creating from invite") + } + portal.UpdateBridgeInfo(ctx) + _, _ = intent.SendNotice(ctx, roomID, "Private chat portal created") } -func (br *WABridge) HandlePresence(evt *event.Event) { +func (br *WABridge) HandlePresence(ctx context.Context, evt *event.Event) { user := br.GetUserByMXIDIfExists(evt.Sender) if user == nil || !user.IsLoggedIn() { return @@ -119,15 +133,15 @@ func (br *WABridge) HandlePresence(evt *event.Event) { presence := types.PresenceAvailable if evt.Content.AsPresence().Presence != event.PresenceOnline { presence = types.PresenceUnavailable - user.log.Debugln("Marking offline") + user.zlog.Debug().Msg("Marking offline") } else { - user.log.Debugln("Marking online") + user.zlog.Debug().Msg("Marking online") } user.lastPresence = presence if user.Client.Store.PushName != "" { err := user.Client.SendPresence(presence) if err != nil { - user.log.Warnln("Failed to set presence:", err) + user.zlog.Err(err).Msg("Failed to set presence") } } } diff --git a/messagetracking.go b/messagetracking.go index 9e79f66..9bac41e 100644 --- a/messagetracking.go +++ b/messagetracking.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -23,7 +23,7 @@ import ( "sync" "time" - log "maunium.net/go/maulogger/v2" + "github.com/rs/zerolog" "go.mau.fi/whatsmeow" @@ -123,7 +123,7 @@ func errorToStatusReason(err error) (reason event.MessageStatusReason, status ev } } -func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType string, confirmed bool, editID id.EventID) id.EventID { +func (portal *Portal) sendErrorMessage(ctx context.Context, evt *event.Event, err error, confirmed bool, editID id.EventID) id.EventID { if !portal.bridge.Config.Bridge.MessageErrorNotices { return "" } @@ -131,6 +131,21 @@ func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType stri if confirmed { certainty = "was not" } + var msgType string + switch evt.Type { + case event.EventMessage: + msgType = "message" + case event.EventReaction: + msgType = "reaction" + case event.EventRedaction: + msgType = "redaction" + case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse: + msgType = "poll response" + case TypeMSC3381PollStart: + msgType = "poll start" + default: + msgType = "unknown event" + } msg := fmt.Sprintf("\u26a0 Your %s %s bridged: %v", msgType, certainty, err) if errors.Is(err, errMessageTakingLong) { msg = fmt.Sprintf("\u26a0 Bridging your %s is taking longer than usual", msgType) @@ -144,15 +159,15 @@ func (portal *Portal) sendErrorMessage(evt *event.Event, err error, msgType stri } else { content.SetReply(evt) } - resp, err := portal.sendMainIntentMessage(content) + resp, err := portal.sendMainIntentMessage(ctx, content) if err != nil { - portal.log.Warnfln("Failed to send bridging error message:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to send bridging error message") return "" } return resp.EventID } -func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) { +func (portal *Portal) sendStatusEvent(ctx context.Context, evtID, lastRetry id.EventID, err error, deliveredTo *[]id.UserID) { if !portal.bridge.Config.Bridge.MessageStatusEvents { return } @@ -179,75 +194,56 @@ func (portal *Portal) sendStatusEvent(evtID, lastRetry id.EventID, err error, de content.Reason, content.Status, _, _, content.Message = errorToStatusReason(err) content.Error = err.Error() } - _, err = intent.SendMessageEvent(portal.MXID, event.BeeperMessageStatus, &content) + _, err = intent.SendMessageEvent(ctx, portal.MXID, event.BeeperMessageStatus, &content) if err != nil { - portal.log.Warnln("Failed to send message status event:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to send message status event") } } -func (portal *Portal) sendDeliveryReceipt(eventID id.EventID) { +func (portal *Portal) sendDeliveryReceipt(ctx context.Context, eventID id.EventID) { if portal.bridge.Config.Bridge.DeliveryReceipts { - err := portal.bridge.Bot.SendReceipt(portal.MXID, eventID, event.ReceiptTypeRead, nil) + err := portal.bridge.Bot.SendReceipt(ctx, portal.MXID, eventID, event.ReceiptTypeRead, nil) if err != nil { - portal.log.Debugfln("Failed to send delivery receipt for %s: %v", eventID, err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to mark message as read by bot (Matrix-side delivery receipt)") } } } -func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part string, ms *metricSender) { - var msgType string - switch evt.Type { - case event.EventMessage: - msgType = "message" - case event.EventReaction: - msgType = "reaction" - case event.EventRedaction: - msgType = "redaction" - case TypeMSC3381PollResponse, TypeMSC3381V2PollResponse: - msgType = "poll response" - case TypeMSC3381PollStart: - msgType = "poll start" - default: - msgType = "unknown event" - } - evtDescription := evt.ID.String() - if evt.Type == event.EventRedaction { - evtDescription += fmt.Sprintf(" of %s", evt.Redacts) - } +func (portal *Portal) sendMessageMetrics(ctx context.Context, evt *event.Event, err error, part string, ms *metricSender) { origEvtID := evt.ID if retryMeta := evt.Content.AsMessage().MessageSendRetry; retryMeta != nil { origEvtID = retryMeta.OriginalEventID } if err != nil { - level := log.LevelError + level := zerolog.ErrorLevel if part == "Ignoring" { - level = log.LevelDebug + level = zerolog.DebugLevel } - portal.log.Logfln(level, "%s %s %s from %s: %v", part, msgType, evtDescription, evt.Sender, err) + zerolog.Ctx(ctx).WithLevel(level).Err(err).Msg(part + " Matrix event") reason, statusCode, isCertain, sendNotice, _ := errorToStatusReason(err) checkpointStatus := status.ReasonToCheckpointStatus(reason, statusCode) portal.bridge.SendMessageCheckpoint(evt, status.MsgStepRemote, err, checkpointStatus, ms.getRetryNum()) if sendNotice { - ms.setNoticeID(portal.sendErrorMessage(evt, err, msgType, isCertain, ms.getNoticeID())) + ms.setNoticeID(portal.sendErrorMessage(ctx, evt, err, isCertain, ms.getNoticeID())) } - portal.sendStatusEvent(origEvtID, evt.ID, err, nil) + portal.sendStatusEvent(ctx, origEvtID, evt.ID, err, nil) } else { - portal.log.Debugfln("Handled Matrix %s %s", msgType, evtDescription) - portal.sendDeliveryReceipt(evt.ID) + zerolog.Ctx(ctx).Debug().Msg("Successfully handled Matrix event") + portal.sendDeliveryReceipt(ctx, evt.ID) portal.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepRemote, ms.getRetryNum()) var deliveredTo *[]id.UserID if portal.IsPrivateChat() { deliveredTo = &[]id.UserID{} } - portal.sendStatusEvent(origEvtID, evt.ID, nil, deliveredTo) + portal.sendStatusEvent(ctx, origEvtID, evt.ID, nil, deliveredTo) if prevNotice := ms.popNoticeID(); prevNotice != "" { - _, _ = portal.MainIntent().RedactEvent(portal.MXID, prevNotice, mautrix.ReqRedact{ + _, _ = portal.MainIntent().RedactEvent(ctx, portal.MXID, prevNotice, mautrix.ReqRedact{ Reason: "error resolved", }) } } if ms != nil { - portal.log.Debugfln("Timings for %s: %s", evt.ID, ms.timings.String()) + zerolog.Ctx(ctx).Debug().Object("timings", ms.timings).Msg("Matrix event timings") } } @@ -264,47 +260,16 @@ type messageTimings struct { totalSend time.Duration } -func niceRound(dur time.Duration) time.Duration { - switch { - case dur < time.Millisecond: - return dur - case dur < time.Second: - return dur.Round(100 * time.Microsecond) - default: - return dur.Round(time.Millisecond) - } -} - -func (mt *messageTimings) String() string { - mt.initReceive = niceRound(mt.initReceive) - mt.decrypt = niceRound(mt.decrypt) - mt.portalQueue = niceRound(mt.portalQueue) - mt.totalReceive = niceRound(mt.totalReceive) - mt.implicitRR = niceRound(mt.implicitRR) - mt.preproc = niceRound(mt.preproc) - mt.convert = niceRound(mt.convert) - mt.whatsmeow.Queue = niceRound(mt.whatsmeow.Queue) - mt.whatsmeow.Marshal = niceRound(mt.whatsmeow.Marshal) - mt.whatsmeow.GetParticipants = niceRound(mt.whatsmeow.GetParticipants) - mt.whatsmeow.GetDevices = niceRound(mt.whatsmeow.GetDevices) - mt.whatsmeow.GroupEncrypt = niceRound(mt.whatsmeow.GroupEncrypt) - mt.whatsmeow.PeerEncrypt = niceRound(mt.whatsmeow.PeerEncrypt) - mt.whatsmeow.Send = niceRound(mt.whatsmeow.Send) - mt.whatsmeow.Resp = niceRound(mt.whatsmeow.Resp) - mt.whatsmeow.Retry = niceRound(mt.whatsmeow.Retry) - mt.totalSend = niceRound(mt.totalSend) - whatsmeowTimings := "N/A" - if mt.totalSend > 0 { - format := "queue: %[1]s, marshal: %[2]s, ske: %[3]s, pcp: %[4]s, dev: %[5]s, encrypt: %[6]s, send: %[7]s, resp: %[8]s" - if mt.whatsmeow.GetParticipants == 0 && mt.whatsmeow.GroupEncrypt == 0 { - format = "queue: %[1]s, marshal: %[2]s, dev: %[5]s, encrypt: %[6]s, send: %[7]s, resp: %[8]s" - } - if mt.whatsmeow.Retry > 0 { - format += ", retry: %[9]s" - } - whatsmeowTimings = fmt.Sprintf(format, mt.whatsmeow.Queue, mt.whatsmeow.Marshal, mt.whatsmeow.GroupEncrypt, mt.whatsmeow.GetParticipants, mt.whatsmeow.GetDevices, mt.whatsmeow.PeerEncrypt, mt.whatsmeow.Send, mt.whatsmeow.Resp, mt.whatsmeow.Retry) - } - return fmt.Sprintf("BRIDGE: receive: %s, decrypt: %s, queue: %s, total hs->portal: %s, implicit rr: %s -- PORTAL: preprocess: %s, convert: %s, total send: %s -- WHATSMEOW: %s", mt.initReceive, mt.decrypt, mt.implicitRR, mt.portalQueue, mt.totalReceive, mt.preproc, mt.convert, mt.totalSend, whatsmeowTimings) +func (mt *messageTimings) MarshalZerologObject(e *zerolog.Event) { + e.Dur("init_receive", mt.initReceive). + Dur("decrypt", mt.decrypt). + Dur("implicit_rr", mt.implicitRR). + Dur("portal_queue", mt.portalQueue). + Dur("total_receive", mt.totalReceive). + Dur("preproc", mt.preproc). + Dur("convert", mt.convert). + Object("whatsmeow", mt.whatsmeow). + Dur("total_send", mt.totalSend) } type metricSender struct { @@ -345,13 +310,13 @@ func (ms *metricSender) setNoticeID(evtID id.EventID) { } } -func (ms *metricSender) sendMessageMetrics(evt *event.Event, err error, part string, completed bool) { +func (ms *metricSender) sendMessageMetrics(ctx context.Context, evt *event.Event, err error, part string, completed bool) { ms.lock.Lock() defer ms.lock.Unlock() if !completed && ms.completed { return } - ms.portal.sendMessageMetrics(evt, err, part, ms) + ms.portal.sendMessageMetrics(ctx, evt, err, part, ms) ms.retryNum++ ms.completed = completed } diff --git a/metrics.go b/metrics.go index 6a93b04..a74e9a7 100644 --- a/metrics.go +++ b/metrics.go @@ -18,6 +18,7 @@ package main import ( "context" + "errors" "net/http" "runtime/debug" "strconv" @@ -27,7 +28,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" - log "maunium.net/go/maulogger/v2" + "github.com/rs/zerolog" "go.mau.fi/whatsmeow/types" @@ -40,7 +41,7 @@ import ( type MetricsHandler struct { db *database.Database server *http.Server - log log.Logger + log zerolog.Logger running bool ctx context.Context @@ -70,7 +71,7 @@ type MetricsHandler struct { loggedInStateLock sync.Mutex } -func NewMetricsHandler(address string, log log.Logger, db *database.Database) *MetricsHandler { +func NewMetricsHandler(address string, log zerolog.Logger, db *database.Database) *MetricsHandler { portalCount := promauto.NewGaugeVec(prometheus.GaugeOpts{ Name: "whatsapp_portals_total", Help: "Number of portal rooms on Matrix", @@ -232,31 +233,31 @@ func (mh *MetricsHandler) TrackConnectionState(jid types.JID, connected bool) { func (mh *MetricsHandler) updateStats() { start := time.Now() var puppetCount int - err := mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount) + err := mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount) if err != nil { - mh.log.Warnln("Failed to scan number of puppets:", err) + mh.log.Err(err).Msg("Failed to scan number of puppets") } else { mh.puppetCount.Set(float64(puppetCount)) } var userCount int - err = mh.db.QueryRowContext(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount) + err = mh.db.QueryRow(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount) if err != nil { - mh.log.Warnln("Failed to scan number of users:", err) + mh.log.Err(err).Msg("Failed to scan number of users") } else { mh.userCount.Set(float64(userCount)) } var messageCount int - err = mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount) + err = mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount) if err != nil { - mh.log.Warnln("Failed to scan number of messages:", err) + mh.log.Err(err).Msg("Failed to scan number of messages") } else { mh.messageCount.Set(float64(messageCount)) } var encryptedGroupCount, encryptedPrivateCount, unencryptedGroupCount, unencryptedPrivateCount int - err = mh.db.QueryRowContext(mh.ctx, ` + err = mh.db.QueryRow(mh.ctx, ` SELECT COUNT(CASE WHEN jid LIKE '%@g.us' AND encrypted THEN 1 END) AS encrypted_group_portals, COUNT(CASE WHEN jid LIKE '%@s.whatsapp.net' AND encrypted THEN 1 END) AS encrypted_private_portals, @@ -265,7 +266,7 @@ func (mh *MetricsHandler) updateStats() { FROM portal WHERE mxid<>'' `).Scan(&encryptedGroupCount, &encryptedPrivateCount, &unencryptedGroupCount, &unencryptedPrivateCount) if err != nil { - mh.log.Warnln("Failed to scan number of portals:", err) + mh.log.Err(err).Msg("Failed to scan number of portals") } else { mh.encryptedGroupCount.Set(float64(encryptedGroupCount)) mh.encryptedPrivateCount.Set(float64(encryptedPrivateCount)) @@ -279,7 +280,10 @@ func (mh *MetricsHandler) startUpdatingStats() { defer func() { err := recover() if err != nil { - mh.log.Fatalfln("Panic in metric updater: %v\n%s", err, string(debug.Stack())) + mh.log.WithLevel(zerolog.PanicLevel). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Interface(zerolog.ErrorFieldName, err). + Msg("Panic in metric updater") } }() ticker := time.Tick(10 * time.Second) @@ -299,8 +303,8 @@ func (mh *MetricsHandler) Start() { go mh.startUpdatingStats() err := mh.server.ListenAndServe() mh.running = false - if err != nil && err != http.ErrServerClosed { - mh.log.Fatalln("Error in metrics listener:", err) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + mh.log.Err(err).Msg("Error in metrics listener") } } @@ -311,6 +315,6 @@ func (mh *MetricsHandler) Stop() { mh.stopRecorder() err := mh.server.Close() if err != nil { - mh.log.Errorln("Error closing metrics listener:", err) + mh.log.Err(err).Msg("Failed to close metrics listener") } } diff --git a/portal.go b/portal.go index f330bb4..fc3f361 100644 --- a/portal.go +++ b/portal.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2021 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -31,6 +31,7 @@ import ( "image/jpeg" "image/png" "io" + "maps" "math" "mime" "net/http" @@ -43,13 +44,7 @@ import ( "github.com/rs/zerolog" "github.com/tidwall/gjson" - "go.mau.fi/util/dbutil" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/exmime" - "go.mau.fi/util/ffmpeg" - "go.mau.fi/util/jsontime" - "go.mau.fi/util/random" - "go.mau.fi/util/variationselector" + "go.mau.fi/util/exzerolog" cwebp "go.mau.fi/webp" "go.mau.fi/whatsmeow" waProto "go.mau.fi/whatsmeow/binary/proto" @@ -59,7 +54,13 @@ import ( "golang.org/x/image/draw" "golang.org/x/image/webp" "google.golang.org/protobuf/proto" - log "maunium.net/go/maulogger/v2" + + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exmime" + "go.mau.fi/util/ffmpeg" + "go.mau.fi/util/jsontime" + "go.mau.fi/util/random" + "go.mau.fi/util/variationselector" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridge" @@ -82,11 +83,17 @@ const PrivateChatTopic = "WhatsApp private chat" var ErrStatusBroadcastDisabled = errors.New("status bridging is disabled") func (br *WABridge) GetPortalByMXID(mxid id.RoomID) *Portal { + ctx := context.TODO() br.portalsLock.Lock() defer br.portalsLock.Unlock() portal, ok := br.portalsByMXID[mxid] if !ok { - return br.loadDBPortal(br.DB.Portal.GetByMXID(mxid), nil) + dbPortal, err := br.DB.Portal.GetByMXID(ctx, mxid) + if err != nil { + br.ZLog.Err(err).Stringer("mxid", mxid).Msg("Failed to get portal by MXID") + return nil + } + return br.loadDBPortal(ctx, dbPortal, nil) } return portal } @@ -105,7 +112,10 @@ func (portal *Portal) IsEncrypted() bool { func (portal *Portal) MarkEncrypted() { portal.Encrypted = true - portal.Update(nil) + err := portal.Update(context.TODO()) + if err != nil { + portal.zlog.Err(err).Msg("Failed to mark portal as encrypted") + } } func (portal *Portal) ReceiveMatrixEvent(user bridge.User, evt *event.Event) { @@ -121,27 +131,39 @@ func (portal *Portal) ReceiveMatrixEvent(user bridge.User, evt *event.Event) { } func (br *WABridge) GetPortalByJID(key database.PortalKey) *Portal { + ctx := context.TODO() br.portalsLock.Lock() defer br.portalsLock.Unlock() portal, ok := br.portalsByJID[key] if !ok { - return br.loadDBPortal(br.DB.Portal.GetByJID(key), &key) + dbPortal, err := br.DB.Portal.GetByJID(ctx, key) + if err != nil { + br.ZLog.Err(err).Str("key", key.String()).Msg("Failed to get portal by JID") + return nil + } + return br.loadDBPortal(ctx, dbPortal, &key) } return portal } func (br *WABridge) GetExistingPortalByJID(key database.PortalKey) *Portal { + ctx := context.TODO() br.portalsLock.Lock() defer br.portalsLock.Unlock() portal, ok := br.portalsByJID[key] if !ok { - return br.loadDBPortal(br.DB.Portal.GetByJID(key), nil) + dbPortal, err := br.DB.Portal.GetByJID(ctx, key) + if err != nil { + br.ZLog.Err(err).Str("key", key.String()).Msg("Failed to get portal by JID") + return nil + } + return br.loadDBPortal(ctx, dbPortal, nil) } return portal } func (br *WABridge) GetAllPortals() []*Portal { - return br.dbPortalsToPortals(br.DB.Portal.GetAll()) + return br.dbPortalsToPortals(br.DB.Portal.GetAll(context.TODO())) } func (br *WABridge) GetAllIPortals() (iportals []bridge.Portal) { @@ -154,14 +176,18 @@ func (br *WABridge) GetAllIPortals() (iportals []bridge.Portal) { } func (br *WABridge) GetAllPortalsByJID(jid types.JID) []*Portal { - return br.dbPortalsToPortals(br.DB.Portal.GetAllByJID(jid)) + return br.dbPortalsToPortals(br.DB.Portal.GetAllByJID(context.TODO(), jid)) } func (br *WABridge) GetAllByParentGroup(jid types.JID) []*Portal { - return br.dbPortalsToPortals(br.DB.Portal.GetAllByParentGroup(jid)) + return br.dbPortalsToPortals(br.DB.Portal.GetAllByParentGroup(context.TODO(), jid)) } -func (br *WABridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal { +func (br *WABridge) dbPortalsToPortals(dbPortals []*database.Portal, err error) []*Portal { + if err != nil { + br.ZLog.Err(err).Msg("Failed to get portals") + return nil + } br.portalsLock.Lock() defer br.portalsLock.Unlock() output := make([]*Portal, len(dbPortals)) @@ -171,21 +197,25 @@ func (br *WABridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal { } portal, ok := br.portalsByJID[dbPortal.Key] if !ok { - portal = br.loadDBPortal(dbPortal, nil) + portal = br.loadDBPortal(context.TODO(), dbPortal, nil) } output[index] = portal } return output } -func (br *WABridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal { +func (br *WABridge) loadDBPortal(ctx context.Context, dbPortal *database.Portal, key *database.PortalKey) *Portal { if dbPortal == nil { if key == nil { return nil } dbPortal = br.DB.Portal.New() dbPortal.Key = *key - dbPortal.Insert() + err := dbPortal.Insert(ctx) + if err != nil { + br.ZLog.Err(err).Str("key", key.String()).Msg("Failed to insert new portal") + return nil + } } portal := br.NewPortal(dbPortal) br.portalsByJID[portal.Key] = portal @@ -196,34 +226,34 @@ func (br *WABridge) loadDBPortal(dbPortal *database.Portal, key *database.Portal } func (portal *Portal) GetUsers() []*User { + // TODO what's this for? return nil } -func (br *WABridge) newBlankPortal(key database.PortalKey) *Portal { +func (br *WABridge) NewManualPortal(key database.PortalKey) *Portal { + dbPortal := br.DB.Portal.New() + dbPortal.Key = key + return br.NewPortal(dbPortal) +} + +func (br *WABridge) NewPortal(dbPortal *database.Portal) *Portal { portal := &Portal{ - bridge: br, - log: br.Log.Sub(fmt.Sprintf("Portal/%s", key)), - zlog: br.ZLog.With().Str("portal_key", key.String()).Logger(), - - events: make(chan *PortalEvent, br.Config.Bridge.PortalMessageBuffer), - + Portal: dbPortal, + bridge: br, + events: make(chan *PortalEvent, br.Config.Bridge.PortalMessageBuffer), mediaErrorCache: make(map[types.MessageID]*FailedMediaMeta), } + portal.updateLogger() go portal.handleMessageLoop() return portal } -func (br *WABridge) NewManualPortal(key database.PortalKey) *Portal { - portal := br.newBlankPortal(key) - portal.Portal = br.DB.Portal.New() - portal.Key = key - return portal -} - -func (br *WABridge) NewPortal(dbPortal *database.Portal) *Portal { - portal := br.newBlankPortal(dbPortal.Key) - portal.Portal = dbPortal - return portal +func (portal *Portal) updateLogger() { + logWith := portal.bridge.ZLog.With().Stringer("portal_key", portal.Key) + if portal.MXID != "" { + logWith = logWith.Stringer("room_id", portal.MXID) + } + portal.zlog = logWith.Logger() } const recentlyHandledLength = 100 @@ -270,9 +300,7 @@ type Portal struct { *database.Portal bridge *WABridge - // Deprecated: use zerolog - log log.Logger - zlog zerolog.Logger + zlog zerolog.Logger roomCreateLock sync.Mutex encryptLock sync.Mutex @@ -346,15 +374,21 @@ var ( ) func (portal *Portal) handleWhatsAppMessageLoopItem(msg *PortalMessage) { + log := portal.zlog.With(). + Str("action", "handle whatsapp event"). + Stringer("source_user_jid", msg.source.JID). + Stringer("source_user_mxid", msg.source.MXID). + Logger() + ctx := log.WithContext(context.TODO()) if len(portal.MXID) == 0 { if msg.fake == nil && msg.undecryptable == nil && (msg.evt == nil || !containsSupportedMessage(msg.evt.Message)) { - portal.log.Debugln("Not creating portal room for incoming message: message is not a chat message") + log.Debug().Msg("Not creating portal room for incoming message: message is not a chat message") return } - portal.log.Debugln("Creating Matrix room from incoming message") - err := portal.CreateMatrixRoom(msg.source, nil, nil, false, true) + log.Debug().Msg("Creating Matrix room from incoming message") + err := portal.CreateMatrixRoom(ctx, msg.source, nil, nil, false, true) if err != nil { - portal.log.Errorln("Failed to create portal room:", err) + log.Err(err).Msg("Failed to create portal room") return } } @@ -362,22 +396,48 @@ func (portal *Portal) handleWhatsAppMessageLoopItem(msg *PortalMessage) { defer portal.latestEventBackfillLock.Unlock() switch { case msg.evt != nil: - portal.handleMessage(msg.source, msg.evt, false) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c. + Str("message_id", msg.evt.Info.ID). + Stringer("message_sender", msg.evt.Info.Sender) + }) + portal.handleMessage(ctx, msg.source, msg.evt, false) case msg.receipt != nil: - portal.handleReceipt(msg.receipt, msg.source) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("receipt_type", msg.receipt.Type.GoString()) + }) + portal.handleReceipt(ctx, msg.receipt, msg.source) case msg.undecryptable != nil: + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c. + Str("message_id", msg.undecryptable.Info.ID). + Stringer("message_sender", msg.undecryptable.Info.Sender). + Bool("undecryptable", true) + }) portal.stopGallery() - portal.handleUndecryptableMessage(msg.source, msg.undecryptable) + portal.handleUndecryptableMessage(ctx, msg.source, msg.undecryptable) case msg.fake != nil: + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c. + Str("fake_message_id", msg.fake.ID). + Stringer("message_sender", msg.fake.Sender) + }) portal.stopGallery() msg.fake.ID = "FAKE::" + msg.fake.ID - portal.handleFakeMessage(*msg.fake) + portal.handleFakeMessage(ctx, *msg.fake) default: - portal.log.Warnln("Unexpected PortalMessage with no message: %+v", msg) + log.Warn().Any("event_data", msg).Msg("Unexpected PortalMessage with no message") } } func (portal *Portal) handleMatrixMessageLoopItem(msg *PortalMatrixMessage) { + log := portal.zlog.With(). + Str("action", "handle matrix event"). + Stringer("event_id", msg.evt.ID). + Str("event_type", msg.evt.Type.Type). + Stringer("sender", msg.evt.Sender). + Logger() + ctx := log.WithContext(context.TODO()) portal.latestEventBackfillLock.Lock() defer portal.latestEventBackfillLock.Unlock() evtTS := time.UnixMilli(msg.evt.Timestamp) @@ -388,27 +448,34 @@ func (portal *Portal) handleMatrixMessageLoopItem(msg *PortalMatrixMessage) { totalReceive: time.Since(evtTS), } implicitRRStart := time.Now() - portal.handleMatrixReadReceipt(msg.user, "", evtTS, false) + portal.handleMatrixReadReceipt(ctx, msg.user, "", evtTS, false) timings.implicitRR = time.Since(implicitRRStart) switch msg.evt.Type { case event.EventMessage, event.EventSticker, TypeMSC3381V2PollResponse, TypeMSC3381PollResponse, TypeMSC3381PollStart: - portal.HandleMatrixMessage(msg.user, msg.evt, timings) + portal.HandleMatrixMessage(ctx, msg.user, msg.evt, timings) case event.EventRedaction: - portal.HandleMatrixRedaction(msg.user, msg.evt) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("redaction_target_mxid", msg.evt.Redacts) + }) + portal.HandleMatrixRedaction(ctx, msg.user, msg.evt) case event.EventReaction: - portal.HandleMatrixReaction(msg.user, msg.evt) + portal.HandleMatrixReaction(ctx, msg.user, msg.evt) default: - portal.log.Warnln("Unsupported event type %+v in portal message channel", msg.evt.Type) + log.Warn().Msg("Unsupported event type in portal message channel") } } -func (portal *Portal) handleDeliveryReceipt(receipt *events.Receipt, source *User) { +func (portal *Portal) handleDeliveryReceipt(ctx context.Context, receipt *events.Receipt, source *User) { if !portal.IsPrivateChat() { return } + log := zerolog.Ctx(ctx) for _, msgID := range receipt.MessageIDs { - msg := portal.bridge.DB.Message.GetByJID(portal.Key, msgID) - if msg == nil || msg.IsFakeMXID() { + msg, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, msgID) + if err != nil { + log.Err(err).Str("message_id", msgID).Msg("Failed to get receipt target message") + continue + } else if msg == nil || msg.IsFakeMXID() { continue } if msg.Sender == source.JID { @@ -420,18 +487,18 @@ func (portal *Portal) handleDeliveryReceipt(receipt *events.Receipt, source *Use Status: status.MsgStatusDelivered, ReportedBy: status.MsgReportedByBridge, }) - portal.sendStatusEvent(msg.MXID, "", nil, &[]id.UserID{portal.MainIntent().UserID}) + portal.sendStatusEvent(ctx, msg.MXID, "", nil, &[]id.UserID{portal.MainIntent().UserID}) } } } -func (portal *Portal) handleReceipt(receipt *events.Receipt, source *User) { +func (portal *Portal) handleReceipt(ctx context.Context, receipt *events.Receipt, source *User) { if receipt.Sender.Server != types.DefaultUserServer { // TODO handle lids return } if receipt.Type == types.ReceiptTypeDelivered { - portal.handleDeliveryReceipt(receipt, source) + portal.handleDeliveryReceipt(ctx, receipt, source) return } // The order of the message ID array depends on the sender's platform, so we just have to find @@ -440,9 +507,12 @@ func (portal *Portal) handleReceipt(receipt *events.Receipt, source *User) { // know which one is last markAsRead := make([]*database.Message, 0, 1) var bestTimestamp time.Time + log := zerolog.Ctx(ctx) for _, msgID := range receipt.MessageIDs { - msg := portal.bridge.DB.Message.GetByJID(portal.Key, msgID) - if msg == nil || msg.IsFakeMXID() { + msg, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, msgID) + if err != nil { + log.Err(err).Str("message_id", msgID).Msg("Failed to get receipt target message") + } else if msg == nil || msg.IsFakeMXID() { continue } if msg.Timestamp.After(bestTimestamp) { @@ -454,18 +524,24 @@ func (portal *Portal) handleReceipt(receipt *events.Receipt, source *User) { } if receipt.Sender.User == source.JID.User { if len(markAsRead) > 0 { - source.SetLastReadTS(portal.Key, markAsRead[0].Timestamp) + source.SetLastReadTS(ctx, portal.Key, markAsRead[0].Timestamp) } else { - source.SetLastReadTS(portal.Key, receipt.Timestamp) + source.SetLastReadTS(ctx, portal.Key, receipt.Timestamp) } } intent := portal.bridge.GetPuppetByJID(receipt.Sender).IntentFor(portal) for _, msg := range markAsRead { - err := intent.SetReadMarkers(portal.MXID, source.makeReadMarkerContent(msg.MXID, intent.IsCustomPuppet)) + err := intent.SetReadMarkers(ctx, portal.MXID, source.makeReadMarkerContent(msg.MXID, intent.IsCustomPuppet)) if err != nil { - portal.log.Warnfln("Failed to mark message %s as read by %s: %v", msg.MXID, intent.UserID, err) + log.Err(err). + Stringer("message_mxid", msg.MXID). + Stringer("read_by_user_mxid", intent.UserID). + Msg("Failed to mark message as read") } else { - portal.log.Debugfln("Marked %s as read by %s", msg.MXID, intent.UserID) + log.Debug(). + Stringer("message_mxid", msg.MXID). + Stringer("read_by_user_mxid", intent.UserID). + Msg("Marked message as read") } } } @@ -499,7 +575,7 @@ func (portal *Portal) handleOneMessageLoopItem() { } else if msg.MediaRetry != nil { portal.handleMediaRetry(msg.MediaRetry.evt, msg.MediaRetry.source) } else { - portal.log.Warn("Portal event loop returned an event without any data") + portal.zlog.Warn().Msg("Unexpected PortalEvent with no data") } } } @@ -651,57 +727,60 @@ func formatDuration(d time.Duration) string { return naturalJoin(parts) } -func (portal *Portal) convertMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, waMsg *waProto.Message, isBackfill bool) *ConvertedMessage { +func (portal *Portal) convertMessage(ctx context.Context, intent *appservice.IntentAPI, source *User, info *types.MessageInfo, waMsg *waProto.Message, isBackfill bool) *ConvertedMessage { switch { case waMsg.Conversation != nil || waMsg.ExtendedTextMessage != nil: - return portal.convertTextMessage(intent, source, waMsg) + return portal.convertTextMessage(ctx, intent, source, waMsg) case waMsg.TemplateMessage != nil: - return portal.convertTemplateMessage(intent, source, info, waMsg.GetTemplateMessage()) + return portal.convertTemplateMessage(ctx, intent, source, info, waMsg.GetTemplateMessage()) case waMsg.HighlyStructuredMessage != nil: - return portal.convertTemplateMessage(intent, source, info, waMsg.GetHighlyStructuredMessage().GetHydratedHsm()) + return portal.convertTemplateMessage(ctx, intent, source, info, waMsg.GetHighlyStructuredMessage().GetHydratedHsm()) case waMsg.TemplateButtonReplyMessage != nil: - return portal.convertTemplateButtonReplyMessage(intent, waMsg.GetTemplateButtonReplyMessage()) + return portal.convertTemplateButtonReplyMessage(ctx, intent, waMsg.GetTemplateButtonReplyMessage()) case waMsg.ListMessage != nil: - return portal.convertListMessage(intent, source, waMsg.GetListMessage()) + return portal.convertListMessage(ctx, intent, source, waMsg.GetListMessage()) case waMsg.ListResponseMessage != nil: - return portal.convertListResponseMessage(intent, waMsg.GetListResponseMessage()) + return portal.convertListResponseMessage(ctx, intent, waMsg.GetListResponseMessage()) case waMsg.PollCreationMessage != nil: - return portal.convertPollCreationMessage(intent, waMsg.GetPollCreationMessage()) + return portal.convertPollCreationMessage(ctx, intent, waMsg.GetPollCreationMessage()) case waMsg.PollCreationMessageV2 != nil: - return portal.convertPollCreationMessage(intent, waMsg.GetPollCreationMessageV2()) + return portal.convertPollCreationMessage(ctx, intent, waMsg.GetPollCreationMessageV2()) case waMsg.PollCreationMessageV3 != nil: - return portal.convertPollCreationMessage(intent, waMsg.GetPollCreationMessageV3()) + return portal.convertPollCreationMessage(ctx, intent, waMsg.GetPollCreationMessageV3()) case waMsg.PollUpdateMessage != nil: - return portal.convertPollUpdateMessage(intent, source, info, waMsg.GetPollUpdateMessage()) + return portal.convertPollUpdateMessage(ctx, intent, source, info, waMsg.GetPollUpdateMessage()) case waMsg.ImageMessage != nil: - return portal.convertMediaMessage(intent, source, info, waMsg.GetImageMessage(), "photo", isBackfill) + return portal.convertMediaMessage(ctx, intent, source, info, waMsg.GetImageMessage(), "photo", isBackfill) case waMsg.StickerMessage != nil: - return portal.convertMediaMessage(intent, source, info, waMsg.GetStickerMessage(), "sticker", isBackfill) + return portal.convertMediaMessage(ctx, intent, source, info, waMsg.GetStickerMessage(), "sticker", isBackfill) case waMsg.VideoMessage != nil: - return portal.convertMediaMessage(intent, source, info, waMsg.GetVideoMessage(), "video attachment", isBackfill) + return portal.convertMediaMessage(ctx, intent, source, info, waMsg.GetVideoMessage(), "video attachment", isBackfill) case waMsg.PtvMessage != nil: - return portal.convertMediaMessage(intent, source, info, waMsg.GetPtvMessage(), "video message", isBackfill) + return portal.convertMediaMessage(ctx, intent, source, info, waMsg.GetPtvMessage(), "video message", isBackfill) case waMsg.AudioMessage != nil: typeName := "audio attachment" if waMsg.GetAudioMessage().GetPtt() { typeName = "voice message" } - return portal.convertMediaMessage(intent, source, info, waMsg.GetAudioMessage(), typeName, isBackfill) + return portal.convertMediaMessage(ctx, intent, source, info, waMsg.GetAudioMessage(), typeName, isBackfill) case waMsg.DocumentMessage != nil: - return portal.convertMediaMessage(intent, source, info, waMsg.GetDocumentMessage(), "file attachment", isBackfill) + return portal.convertMediaMessage(ctx, intent, source, info, waMsg.GetDocumentMessage(), "file attachment", isBackfill) case waMsg.ContactMessage != nil: - return portal.convertContactMessage(intent, waMsg.GetContactMessage()) + return portal.convertContactMessage(ctx, intent, waMsg.GetContactMessage()) case waMsg.ContactsArrayMessage != nil: - return portal.convertContactsArrayMessage(intent, waMsg.GetContactsArrayMessage()) + return portal.convertContactsArrayMessage(ctx, intent, waMsg.GetContactsArrayMessage()) case waMsg.LocationMessage != nil: - return portal.convertLocationMessage(intent, waMsg.GetLocationMessage()) + return portal.convertLocationMessage(ctx, intent, waMsg.GetLocationMessage()) case waMsg.LiveLocationMessage != nil: - return portal.convertLiveLocationMessage(intent, waMsg.GetLiveLocationMessage()) + return portal.convertLiveLocationMessage(ctx, intent, waMsg.GetLiveLocationMessage()) case waMsg.GroupInviteMessage != nil: - return portal.convertGroupInviteMessage(intent, info, waMsg.GetGroupInviteMessage()) + return portal.convertGroupInviteMessage(ctx, intent, info, waMsg.GetGroupInviteMessage()) case waMsg.ProtocolMessage != nil && waMsg.ProtocolMessage.GetType() == waProto.ProtocolMessage_EPHEMERAL_SETTING: portal.ExpirationTime = waMsg.ProtocolMessage.GetEphemeralExpiration() - portal.Update(nil) + err := portal.Update(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating expiration timer") + } return &ConvertedMessage{ Intent: intent, Type: event.EventMessage, @@ -715,41 +794,50 @@ func (portal *Portal) convertMessage(intent *appservice.IntentAPI, source *User, } } -func (portal *Portal) implicitlyEnableDisappearingMessages(timer time.Duration) { +func (portal *Portal) implicitlyEnableDisappearingMessages(ctx context.Context, timer time.Duration) { portal.ExpirationTime = uint32(timer.Seconds()) - portal.Update(nil) + err := portal.Update(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after implicitly enabling disappearing timer") + } intent := portal.MainIntent() if portal.Encrypted { intent = portal.bridge.Bot } duration := formatDuration(time.Duration(portal.ExpirationTime) * time.Second) - _, err := portal.sendMessage(intent, event.EventMessage, &event.MessageEventContent{ + _, err = portal.sendMessage(ctx, intent, event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgNotice, Body: fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", duration), }, nil, 0) if err != nil { - portal.zlog.Warn().Err(err).Msg("Failed to send notice about implicit disappearing timer") + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to send notice about implicit disappearing timer") } } -func (portal *Portal) UpdateGroupDisappearingMessages(sender *types.JID, timestamp time.Time, timer uint32) { +func (portal *Portal) UpdateGroupDisappearingMessages(ctx context.Context, sender *types.JID, timestamp time.Time, timer uint32) { if portal.ExpirationTime == timer { return } portal.ExpirationTime = timer - portal.Update(nil) + err := portal.Update(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating expiration timer") + } intent := portal.MainIntent() if sender != nil && sender.Server == types.DefaultUserServer { intent = portal.bridge.GetPuppetByJID(sender.ToNonAD()).IntentFor(portal) } else { sender = &types.EmptyJID } - _, err := portal.sendMessage(intent, event.EventMessage, &event.MessageEventContent{ + _, err = portal.sendMessage(ctx, intent, event.EventMessage, &event.MessageEventContent{ Body: portal.formatDisappearingMessageNotice(), MsgType: event.MsgNotice, }, nil, timestamp.UnixMilli()) if err != nil { - portal.log.Warnfln("Failed to notify portal about disappearing message timer change by %s to %d", *sender, timer) + zerolog.Ctx(ctx).Warn().Err(err). + Uint32("new_timer", timer). + Stringer("sender_jid", sender). + Msg("Failed to notify portal about disappearing message timer change") } } @@ -770,15 +858,19 @@ func init() { undecryptableMessageContent.MsgType = event.MsgNotice } -func (portal *Portal) handleUndecryptableMessage(source *User, evt *events.UndecryptableMessage) { +func (portal *Portal) handleUndecryptableMessage(ctx context.Context, source *User, evt *events.UndecryptableMessage) { + log := zerolog.Ctx(ctx) if len(portal.MXID) == 0 { - portal.log.Warnln("handleUndecryptableMessage called even though portal.MXID is empty") + log.Warn().Msg("handleUndecryptableMessage called even though portal.MXID is empty") return } else if portal.isRecentlyHandled(evt.Info.ID, database.MsgErrDecryptionFailed) { - portal.log.Debugfln("Not handling %s (undecryptable): message was recently handled", evt.Info.ID) + log.Debug().Msg("Not handling recently handled message") return - } else if existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, evt.Info.ID); existingMsg != nil { - portal.log.Debugfln("Not handling %s (undecryptable): message is duplicate", evt.Info.ID) + } else if existingMsg, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, evt.Info.ID); err != nil { + log.Err(err).Msg("Failed to get message from database to check if undecryptable message is duplicate") + return + } else if existingMsg != nil { + log.Debug().Msg("Not handling duplicate message") return } metricType := "error" @@ -789,49 +881,53 @@ func (portal *Portal) handleUndecryptableMessage(source *User, evt *events.Undec "messageID": evt.Info.ID, "undecryptableType": metricType, }) - intent := portal.getMessageIntent(source, &evt.Info, "undecryptable") + intent := portal.getMessageIntent(ctx, source, &evt.Info) if intent == nil { return } content := undecryptableMessageContent - resp, err := portal.sendMessage(intent, event.EventMessage, &content, nil, evt.Info.Timestamp.UnixMilli()) + resp, err := portal.sendMessage(ctx, intent, event.EventMessage, &content, nil, evt.Info.Timestamp.UnixMilli()) if err != nil { - portal.log.Errorfln("Failed to send decryption error of %s to Matrix: %v", evt.Info.ID, err) + log.Err(err).Msg("Failed to send WhatsApp decryption error message to Matrix") return } - portal.finishHandling(nil, &evt.Info, resp.EventID, intent.UserID, database.MsgUnknown, 0, database.MsgErrDecryptionFailed) + portal.finishHandling(ctx, nil, &evt.Info, resp.EventID, intent.UserID, database.MsgUnknown, 0, database.MsgErrDecryptionFailed) } -func (portal *Portal) handleFakeMessage(msg fakeMessage) { +func (portal *Portal) handleFakeMessage(ctx context.Context, msg fakeMessage) { + log := zerolog.Ctx(ctx) if portal.isRecentlyHandled(msg.ID, database.MsgNoError) { - portal.log.Debugfln("Not handling %s (fake): message was recently handled", msg.ID) + log.Debug().Msg("Not handling recently handled message") return - } else if existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, msg.ID); existingMsg != nil { - portal.log.Debugfln("Not handling %s (fake): message is duplicate", msg.ID) + } else if existingMsg, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, msg.ID); err != nil { + log.Err(err).Msg("Failed to get message from database to check if fake message is duplicate") + return + } else if existingMsg != nil { + log.Debug().Msg("Not handling duplicate message") return } if msg.Sender.Server != types.DefaultUserServer { - portal.log.Debugfln("Not handling %s (fake): message is from a lid user (%s)", msg.ID, msg.Sender) + log.Debug().Msg("Not handling message from @lid user") // TODO handle lids return } intent := portal.bridge.GetPuppetByJID(msg.Sender).IntentFor(portal) if !intent.IsCustomPuppet && portal.IsPrivateChat() && msg.Sender.User == portal.Key.Receiver.User && portal.Key.Receiver != portal.Key.JID { - portal.log.Debugfln("Not handling %s (fake): user doesn't have double puppeting enabled", msg.ID) + log.Debug().Msg("Not handling fake message for user who doesn't have double puppeting enabled") return } msgType := event.MsgNotice if msg.Important { msgType = event.MsgText } - resp, err := portal.sendMessage(intent, event.EventMessage, &event.MessageEventContent{ + resp, err := portal.sendMessage(ctx, intent, event.EventMessage, &event.MessageEventContent{ MsgType: msgType, Body: msg.Text, }, nil, msg.Time.UnixMilli()) if err != nil { - portal.log.Errorfln("Failed to send %s to Matrix: %v", msg.ID, err) + log.Err(err).Msg("Failed to send fake message to Matrix") } else { - portal.finishHandling(nil, &types.MessageInfo{ + portal.finishHandling(ctx, nil, &types.MessageInfo{ ID: msg.ID, Timestamp: msg.Time, MessageSource: types.MessageSource{ @@ -841,9 +937,10 @@ func (portal *Portal) handleFakeMessage(msg fakeMessage) { } } -func (portal *Portal) handleMessage(source *User, evt *events.Message, historical bool) { +func (portal *Portal) handleMessage(ctx context.Context, source *User, evt *events.Message, historical bool) { + log := zerolog.Ctx(ctx) if len(portal.MXID) == 0 { - portal.log.Warnln("handleMessage called even though portal.MXID is empty") + log.Warn().Msg("handleMessage called even though portal.MXID is empty") return } msgID := evt.Info.ID @@ -851,10 +948,17 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica if msgType == "ignore" { return } else if portal.isRecentlyHandled(msgID, database.MsgNoError) { - portal.log.Debugfln("Not handling %s (%s): message was recently handled", msgID, msgType) + log.Debug().Msg("Not handling recently handled message") + return + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("wa_message_type", msgType) + }) + existingMsg, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, msgID) + if err != nil { + log.Err(err).Msg("Failed to get message from database to check if message is duplicate") return } - existingMsg := portal.bridge.DB.Message.GetByJID(portal.Key, msgID) if existingMsg != nil { if existingMsg.Error == database.MsgErrDecryptionFailed { resolveType := "sender" @@ -865,34 +969,42 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica "messageID": evt.Info.ID, "resolveType": resolveType, }) - portal.log.Debugfln("Got decryptable version of previously undecryptable message %s (%s) via %s", msgID, msgType, resolveType) + log.Debug().Str("resolved_via", resolveType).Msg("Got decryptable version of previously undecryptable message") } else { - portal.log.Debugfln("Not handling %s (%s): message is duplicate", msgID, msgType) + log.Debug().Msg("Not handling duplicate message") return } } var editTargetMsg *database.Message if msgType == "edit" { editTargetID := evt.Message.GetProtocolMessage().GetKey().GetId() - editTargetMsg = portal.bridge.DB.Message.GetByJID(portal.Key, editTargetID) - if editTargetMsg == nil { - portal.log.Warnfln("Not handling %s: couldn't find edit target %s", msgID, editTargetID) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("edit_target_id", editTargetID) + }) + editTargetMsg, err = portal.bridge.DB.Message.GetByJID(ctx, portal.Key, editTargetID) + if err != nil { + log.Err(err).Msg("Failed to get edit target message from database") + return + } else if editTargetMsg == nil { + log.Warn().Msg("Not handling edit: couldn't find edit target") return } else if editTargetMsg.Type != database.MsgNormal { - portal.log.Warnfln("Not handling %s: edit target %s is not a normal message (it's %s)", msgID, editTargetID, editTargetMsg.Type) + log.Warn().Str("edit_target_db_type", string(editTargetMsg.Type)). + Msg("Not handling edit: edit target is not a normal message") return } else if editTargetMsg.Sender.User != evt.Info.Sender.User { - portal.log.Warnfln("Not handling %s: edit target %s was sent by %s, not %s", msgID, editTargetID, editTargetMsg.Sender.User, evt.Info.Sender.User) + log.Warn().Stringer("edit_target_sender", editTargetMsg.Sender). + Msg("Not handling edit: edit was sent by another user") return } evt.Message = evt.Message.GetProtocolMessage().GetEditedMessage() } - intent := portal.getMessageIntent(source, &evt.Info, msgType) + intent := portal.getMessageIntent(ctx, source, &evt.Info) if intent == nil { return } - converted := portal.convertMessage(intent, source, &evt.Info, evt.Message, false) + converted := portal.convertMessage(ctx, intent, source, &evt.Info, evt.Message, false) if converted != nil { isGalleriable := portal.bridge.Config.Bridge.BeeperGalleries && (evt.Message.ImageMessage != nil || evt.Message.VideoMessage != nil) && @@ -906,16 +1018,14 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica editTargetMsg == nil if !historical && portal.IsPrivateChat() && evt.Info.Sender.Device == 0 && converted.ExpiresIn > 0 && portal.ExpirationTime == 0 { - portal.zlog.Info(). + log.Info(). Str("timer", converted.ExpiresIn.String()). - Str("sender_jid", evt.Info.Sender.String()). - Str("message_id", evt.Info.ID). Msg("Implicitly enabling disappearing messages as incoming message is disappearing") - portal.implicitlyEnableDisappearingMessages(converted.ExpiresIn) + portal.implicitlyEnableDisappearingMessages(ctx, converted.ExpiresIn) } if evt.Info.IsIncomingBroadcast() { if converted.Extra == nil { - converted.Extra = map[string]interface{}{} + converted.Extra = map[string]any{} } converted.Extra["fi.mau.whatsapp.source_broadcast_list"] = evt.Info.Chat.String() } @@ -925,10 +1035,10 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica var eventID id.EventID var lastEventID id.EventID if existingMsg != nil { - portal.MarkDisappearing(nil, existingMsg.MXID, converted.ExpiresIn, evt.Info.Timestamp) + portal.MarkDisappearing(ctx, existingMsg.MXID, converted.ExpiresIn, evt.Info.Timestamp) converted.Content.SetEdit(existingMsg.MXID) } else if converted.ReplyTo != nil { - portal.SetReply(evt.Info.ID, converted.Content, converted.ReplyTo, false) + portal.SetReply(ctx, converted.Content, converted.ReplyTo, false) } dbMsgType := database.MsgNormal if editTargetMsg != nil { @@ -949,12 +1059,13 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica // Stop collecting a gallery (except if it's an edit) portal.stopGallery() } - resp, err := portal.sendMessage(converted.Intent, converted.Type, converted.Content, converted.Extra, evt.Info.Timestamp.UnixMilli()) + var resp *mautrix.RespSendEvent + resp, err = portal.sendMessage(ctx, converted.Intent, converted.Type, converted.Content, converted.Extra, evt.Info.Timestamp.UnixMilli()) if err != nil { - portal.log.Errorfln("Failed to send %s to Matrix: %v", msgID, err) + log.Err(err).Msg("Failed to send WhatsApp message to Matrix") } else { if editTargetMsg == nil { - portal.MarkDisappearing(nil, resp.EventID, converted.ExpiresIn, evt.Info.Timestamp) + portal.MarkDisappearing(ctx, resp.EventID, converted.ExpiresIn, evt.Info.Timestamp) } eventID = resp.EventID lastEventID = eventID @@ -966,21 +1077,21 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica } // TODO figure out how to handle captions with undecryptable messages turning decryptable if converted.Caption != nil && existingMsg == nil && editTargetMsg == nil { - resp, err = portal.sendMessage(converted.Intent, converted.Type, converted.Caption, nil, evt.Info.Timestamp.UnixMilli()) + resp, err = portal.sendMessage(ctx, converted.Intent, converted.Type, converted.Caption, nil, evt.Info.Timestamp.UnixMilli()) if err != nil { - portal.log.Errorfln("Failed to send caption of %s to Matrix: %v", msgID, err) + log.Err(err).Msg("Failed to send caption of WhatsApp message to Matrix") } else { - portal.MarkDisappearing(nil, resp.EventID, converted.ExpiresIn, evt.Info.Timestamp) + portal.MarkDisappearing(ctx, resp.EventID, converted.ExpiresIn, evt.Info.Timestamp) lastEventID = resp.EventID } } if converted.MultiEvent != nil && existingMsg == nil && editTargetMsg == nil { for index, subEvt := range converted.MultiEvent { - resp, err = portal.sendMessage(converted.Intent, converted.Type, subEvt, nil, evt.Info.Timestamp.UnixMilli()) + resp, err = portal.sendMessage(ctx, converted.Intent, converted.Type, subEvt, nil, evt.Info.Timestamp.UnixMilli()) if err != nil { - portal.log.Errorfln("Failed to send sub-event %d of %s to Matrix: %v", index+1, msgID, err) + log.Err(err).Int("part_number", index+1).Msg("Failed to send sub-event of WhatsApp message to Matrix") } else { - portal.MarkDisappearing(nil, resp.EventID, converted.ExpiresIn, evt.Info.Timestamp) + portal.MarkDisappearing(ctx, resp.EventID, converted.ExpiresIn, evt.Info.Timestamp) lastEventID = resp.EventID } } @@ -989,40 +1100,50 @@ func (portal *Portal) handleMessage(source *User, evt *events.Message, historica // There are some edge cases (like call notices) where previous messages aren't marked as read // when the user sends a message from another device, so just mark the new message as read to be safe. // Hungryserv does this automatically, so the bridge doesn't need to do it manually. - err = intent.SetReadMarkers(portal.MXID, source.makeReadMarkerContent(lastEventID, true)) + err = intent.SetReadMarkers(ctx, portal.MXID, source.makeReadMarkerContent(lastEventID, true)) if err != nil { - portal.log.Warnfln("Failed to mark own message %s as read by %s: %v", lastEventID, source.MXID, err) + log.Warn().Err(err).Stringer("last_event_id", lastEventID). + Msg("Failed to mark last message as read after sending") } } if len(eventID) != 0 { - portal.finishHandling(existingMsg, &evt.Info, eventID, intent.UserID, dbMsgType, galleryPart, converted.Error) + portal.finishHandling(ctx, existingMsg, &evt.Info, eventID, intent.UserID, dbMsgType, galleryPart, converted.Error) } } else if msgType == "reaction" || msgType == "encrypted reaction" { if evt.Message.GetEncReactionMessage() != nil { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("reaction_target_id", evt.Message.GetEncReactionMessage().GetTargetMessageKey().GetId()) + }) decryptedReaction, err := source.Client.DecryptReaction(evt) if err != nil { - portal.log.Errorfln("Failed to decrypt reaction from %s to %s: %v", evt.Info.Sender, evt.Message.GetEncReactionMessage().GetTargetMessageKey().GetId(), err) + log.Err(err).Msg("Failed to decrypt reaction") } else { - portal.HandleMessageReaction(intent, source, &evt.Info, decryptedReaction, existingMsg) + portal.HandleMessageReaction(ctx, intent, source, &evt.Info, decryptedReaction, existingMsg) } } else { - portal.HandleMessageReaction(intent, source, &evt.Info, evt.Message.GetReactionMessage(), existingMsg) + portal.HandleMessageReaction(ctx, intent, source, &evt.Info, evt.Message.GetReactionMessage(), existingMsg) } } else if msgType == "revoke" { - portal.HandleMessageRevoke(source, &evt.Info, evt.Message.GetProtocolMessage().GetKey()) + portal.HandleMessageRevoke(ctx, source, &evt.Info, evt.Message.GetProtocolMessage().GetKey()) if existingMsg != nil { - _, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{ + _, _ = portal.MainIntent().RedactEvent(ctx, portal.MXID, existingMsg.MXID, mautrix.ReqRedact{ Reason: "The undecryptable message was actually the deletion of another message", }) - existingMsg.UpdateMXID(nil, "net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError) + err = existingMsg.UpdateMXID(ctx, "net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError) + if err != nil { + log.Err(err).Msg("Failed to update message in database after finding undecryptable message was a revoke message") + } } } else { - portal.log.Warnfln("Unhandled message: %+v (%s)", evt.Info, msgType) + log.Warn().Any("event_info", evt.Info).Msg("Unhandled message") if existingMsg != nil { - _, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{ + _, _ = portal.MainIntent().RedactEvent(ctx, portal.MXID, existingMsg.MXID, mautrix.ReqRedact{ Reason: "The undecryptable message contained an unsupported message type", }) - existingMsg.UpdateMXID(nil, "net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError) + err = existingMsg.UpdateMXID(ctx, "net.maunium.whatsapp.fake::"+existingMsg.MXID, database.MsgFake, database.MsgNoError) + if err != nil { + log.Err(err).Msg("Failed to update message in database after finding undecryptable message was an unknown message") + } } return } @@ -1040,7 +1161,7 @@ func (portal *Portal) isRecentlyHandled(id types.MessageID, error database.Messa return false } -func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message, info *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, isSent, recent bool, msgType database.MessageType, galleryPart int, errType database.MessageErrorType) *database.Message { +func (portal *Portal) markHandled(ctx context.Context, msg *database.Message, info *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, isSent, recent bool, msgType database.MessageType, galleryPart int, errType database.MessageErrorType) *database.Message { if msg == nil { msg = portal.bridge.DB.Message.New() msg.Chat = portal.Key @@ -1056,9 +1177,15 @@ func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message, if info.IsIncomingBroadcast() { msg.BroadcastListJID = info.Chat } - msg.Insert(txn) + err := msg.Insert(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to insert message to database") + } } else { - msg.UpdateMXID(txn, mxid, msgType, errType) + err := msg.UpdateMXID(ctx, mxid, msgType, errType) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to update message in database") + } } if recent { @@ -1071,7 +1198,7 @@ func (portal *Portal) markHandled(txn dbutil.Transaction, msg *database.Message, return msg } -func (portal *Portal) getMessagePuppet(user *User, info *types.MessageInfo) (puppet *Puppet) { +func (portal *Portal) getMessagePuppet(ctx context.Context, user *User, info *types.MessageInfo) (puppet *Puppet) { if info.IsFromMe { return portal.bridge.GetPuppetByJID(user.JID) } else if portal.IsPrivateChat() { @@ -1080,46 +1207,45 @@ func (portal *Portal) getMessagePuppet(user *User, info *types.MessageInfo) (pup puppet = portal.bridge.GetPuppetByJID(info.Sender) } if puppet == nil { - portal.log.Warnfln("Message %+v doesn't seem to have a valid sender (%s): puppet is nil", *info, info.Sender) + zerolog.Ctx(ctx).Warn().Msg("Message doesn't seem to have a valid sender: puppet is nil") return nil } user.EnqueuePortalResync(portal) - puppet.SyncContact(user, true, true, "handling message") + puppet.SyncContact(ctx, user, true, true, "handling message") return puppet } -func (portal *Portal) getMessageIntent(user *User, info *types.MessageInfo, msgType string) *appservice.IntentAPI { +func (portal *Portal) getMessageIntent(ctx context.Context, user *User, info *types.MessageInfo) *appservice.IntentAPI { if portal.IsNewsletter() && info.Sender == info.Chat { return portal.MainIntent() } - puppet := portal.getMessagePuppet(user, info) + puppet := portal.getMessagePuppet(ctx, user, info) if puppet == nil { return nil } intent := puppet.IntentFor(portal) if !intent.IsCustomPuppet && portal.IsPrivateChat() && info.Sender.User == portal.Key.Receiver.User && portal.Key.Receiver != portal.Key.JID { - portal.log.Debugfln("Not handling %s (%s): user doesn't have double puppeting enabled", info.ID, msgType) + zerolog.Ctx(ctx).Debug().Msg("Not handling message: user doesn't have double puppeting enabled") return nil } return intent } -func (portal *Portal) finishHandling(existing *database.Message, message *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, msgType database.MessageType, galleryPart int, errType database.MessageErrorType) { - portal.markHandled(nil, existing, message, mxid, senderMXID, true, true, msgType, galleryPart, errType) - portal.sendDeliveryReceipt(mxid) - var suffix string - if errType == database.MsgErrDecryptionFailed { - suffix = "(undecryptable message error notice)" - } else if errType == database.MsgErrMediaNotFound { - suffix = "(media not found notice)" +func (portal *Portal) finishHandling(ctx context.Context, existing *database.Message, message *types.MessageInfo, mxid id.EventID, senderMXID id.UserID, msgType database.MessageType, galleryPart int, errType database.MessageErrorType) { + portal.markHandled(ctx, existing, message, mxid, senderMXID, true, true, msgType, galleryPart, errType) + portal.sendDeliveryReceipt(ctx, mxid) + logEvt := zerolog.Ctx(ctx).Debug(). + Stringer("matrix_event_id", mxid) + if errType != database.MsgNoError { + logEvt.Str("error_type", string(errType)) } - portal.log.Debugfln("Handled message %s (%s) -> %s %s", message.ID, msgType, mxid, suffix) + logEvt.Msg("Successfully handled WhatsApp message") } -func (portal *Portal) kickExtraUsers(participantMap map[types.JID]bool) { - members, err := portal.MainIntent().JoinedMembers(portal.MXID) +func (portal *Portal) kickExtraUsers(ctx context.Context, participantMap map[types.JID]bool) { + members, err := portal.MainIntent().JoinedMembers(ctx, portal.MXID) if err != nil { - portal.log.Warnln("Failed to get member list:", err) + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get member list to kick extra users") return } for member := range members.Joined { @@ -1127,12 +1253,14 @@ func (portal *Portal) kickExtraUsers(participantMap map[types.JID]bool) { if ok { _, shouldBePresent := participantMap[jid] if !shouldBePresent { - _, err = portal.MainIntent().KickUser(portal.MXID, &mautrix.ReqKickUser{ + _, err = portal.MainIntent().KickUser(ctx, portal.MXID, &mautrix.ReqKickUser{ UserID: member, Reason: "User had left this WhatsApp chat", }) if err != nil { - portal.log.Warnfln("Failed to kick user %s who had left: %v", member, err) + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("user_id", member). + Msg("Failed to kick extra user from room") } } } @@ -1154,28 +1282,34 @@ func (portal *Portal) kickExtraUsers(participantMap map[types.JID]bool) { // portal.kickExtraUsers(participantMap) //} -func (portal *Portal) syncParticipant(source *User, participant types.GroupParticipant, puppet *Puppet, user *User, wg *sync.WaitGroup) { +func (portal *Portal) syncParticipant(ctx context.Context, source *User, participant types.GroupParticipant, puppet *Puppet, user *User, wg *sync.WaitGroup) { defer func() { wg.Done() if err := recover(); err != nil { - portal.log.Errorfln("Syncing participant %s panicked: %v\n%s", participant.JID, err, debug.Stack()) + zerolog.Ctx(ctx).Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). + Stringer("participant_jid", participant.JID). + Msg("Syncing participant panicked") } }() - puppet.SyncContact(source, true, false, "group participant") + puppet.SyncContact(ctx, source, true, false, "group participant") if portal.MXID != "" { if user != nil && user != source { - portal.ensureUserInvited(user) + portal.ensureUserInvited(ctx, user) } if user == nil || !puppet.IntentFor(portal).IsCustomPuppet { - err := puppet.IntentFor(portal).EnsureJoined(portal.MXID) + err := puppet.IntentFor(portal).EnsureJoined(ctx, portal.MXID) if err != nil { - portal.log.Warnfln("Failed to make puppet of %s join %s: %v", participant.JID, portal.MXID, err) + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("participant_jid", participant.JID). + Msg("Failed to make ghost user join portal") } } } } -func (portal *Portal) SyncParticipants(source *User, metadata *types.GroupInfo) ([]id.UserID, *event.PowerLevelsEventContent) { +func (portal *Portal) SyncParticipants(ctx context.Context, source *User, metadata *types.GroupInfo) ([]id.UserID, *event.PowerLevelsEventContent) { if portal.IsNewsletter() { return nil, nil } @@ -1183,7 +1317,7 @@ func (portal *Portal) SyncParticipants(source *User, metadata *types.GroupInfo) var levels *event.PowerLevelsEventContent var err error if portal.MXID != "" { - levels, err = portal.MainIntent().PowerLevels(portal.MXID) + levels, err = portal.MainIntent().PowerLevels(ctx, portal.MXID) } if levels == nil || err != nil { levels = portal.GetBasePowerLevels() @@ -1194,20 +1328,24 @@ func (portal *Portal) SyncParticipants(source *User, metadata *types.GroupInfo) wg.Add(len(metadata.Participants)) participantMap := make(map[types.JID]bool) userIDs := make([]id.UserID, 0, len(metadata.Participants)) + log := zerolog.Ctx(ctx) for _, participant := range metadata.Participants { if participant.JID.IsEmpty() || participant.JID.Server != types.DefaultUserServer { wg.Done() // TODO handle lids continue } - portal.log.Debugfln("Syncing participant %s (admin: %t)", participant.JID, participant.IsAdmin) + log.Debug(). + Stringer("participant_jid", participant.JID). + Bool("is_admin", participant.IsAdmin). + Msg("Syncing participant") participantMap[participant.JID] = true puppet := portal.bridge.GetPuppetByJID(participant.JID) user := portal.bridge.GetUserByJID(participant.JID) if portal.bridge.Config.Bridge.ParallelMemberSync { - go portal.syncParticipant(source, participant, puppet, user, &wg) + go portal.syncParticipant(ctx, source, participant, puppet, user, &wg) } else { - portal.syncParticipant(source, participant, puppet, user, &wg) + portal.syncParticipant(ctx, source, participant, puppet, user, &wg) } expectedLevel := 0 @@ -1227,19 +1365,19 @@ func (portal *Portal) SyncParticipants(source *User, metadata *types.GroupInfo) } if portal.MXID != "" { if changed { - _, err = portal.MainIntent().SetPowerLevels(portal.MXID, levels) + _, err = portal.MainIntent().SetPowerLevels(ctx, portal.MXID, levels) if err != nil { - portal.log.Errorln("Failed to change power levels:", err) + log.Err(err).Msg("Failed to update power levels in room") } } - portal.kickExtraUsers(participantMap) + portal.kickExtraUsers(ctx, participantMap) } wg.Wait() - portal.log.Debugln("Participant sync completed") + log.Debug().Msg("Participant sync completed") return userIDs, levels } -func reuploadAvatar(intent *appservice.IntentAPI, url string) (id.ContentURI, error) { +func reuploadAvatar(ctx context.Context, intent *appservice.IntentAPI, url string) (id.ContentURI, error) { getResp, err := http.DefaultClient.Get(url) if err != nil { return id.ContentURI{}, fmt.Errorf("failed to download avatar: %w", err) @@ -1250,26 +1388,26 @@ func reuploadAvatar(intent *appservice.IntentAPI, url string) (id.ContentURI, er return id.ContentURI{}, fmt.Errorf("failed to read avatar bytes: %w", err) } - resp, err := intent.UploadBytes(data, http.DetectContentType(data)) + resp, err := intent.UploadBytes(ctx, data, http.DetectContentType(data)) if err != nil { return id.ContentURI{}, fmt.Errorf("failed to upload avatar to Matrix: %w", err) } return resp.ContentURI, nil } -func (user *User) reuploadAvatarDirectPath(intent *appservice.IntentAPI, directPath string) (id.ContentURI, error) { +func (user *User) reuploadAvatarDirectPath(ctx context.Context, intent *appservice.IntentAPI, directPath string) (id.ContentURI, error) { data, err := user.Client.DownloadMediaWithPath(directPath, nil, nil, nil, 0, "", "") if err != nil { return id.ContentURI{}, fmt.Errorf("failed to download avatar: %w", err) } - resp, err := intent.UploadBytes(data, http.DetectContentType(data)) + resp, err := intent.UploadBytes(ctx, data, http.DetectContentType(data)) if err != nil { return id.ContentURI{}, fmt.Errorf("failed to upload avatar to Matrix: %w", err) } return resp.ContentURI, nil } -func (user *User) updateAvatar(jid types.JID, isCommunity bool, avatarID *string, avatarURL *id.ContentURI, avatarSet *bool, log log.Logger, intent *appservice.IntentAPI) bool { +func (user *User) updateAvatar(ctx context.Context, jid types.JID, isCommunity bool, avatarID *string, avatarURL *id.ContentURI, avatarSet *bool, intent *appservice.IntentAPI) bool { currentID := "" if *avatarSet && *avatarID != "remove" && *avatarID != "unauthorized" { currentID = *avatarID @@ -1279,6 +1417,7 @@ func (user *User) updateAvatar(jid types.JID, isCommunity bool, avatarID *string ExistingID: currentID, IsCommunity: isCommunity, }) + log := zerolog.Ctx(ctx) if errors.Is(err, whatsmeow.ErrProfilePictureUnauthorized) { if *avatarID == "" { *avatarID = "unauthorized" @@ -1295,7 +1434,7 @@ func (user *User) updateAvatar(jid types.JID, isCommunity bool, avatarID *string *avatarURL = id.ContentURI{} return true } else if err != nil { - log.Warnln("Failed to get avatar URL:", err) + log.Err(err).Msg("Failed to get avatar URL") return false } else if avatar == nil { // Avatar hasn't changed @@ -1304,32 +1443,32 @@ func (user *User) updateAvatar(jid types.JID, isCommunity bool, avatarID *string if avatar.ID == *avatarID && *avatarSet { return false } else if len(avatar.URL) == 0 && len(avatar.DirectPath) == 0 { - log.Warnln("Didn't get URL in response to avatar query") + log.Warn().Msg("Didn't get URL in response to avatar query") return false } else if avatar.ID != *avatarID || avatarURL.IsEmpty() { var url id.ContentURI if len(avatar.URL) > 0 { - url, err = reuploadAvatar(intent, avatar.URL) + url, err = reuploadAvatar(ctx, intent, avatar.URL) if err != nil { - log.Warnln("Failed to reupload avatar:", err) + log.Err(err).Msg("Failed to reupload avatar") return false } } else { - url, err = user.reuploadAvatarDirectPath(intent, avatar.DirectPath) + url, err = user.reuploadAvatarDirectPath(ctx, intent, avatar.DirectPath) if err != nil { - log.Warnln("Failed to reupload avatar:", err) + log.Err(err).Msg("Failed to reupload avatar") return false } } *avatarURL = url } - log.Debugfln("Updated avatar %s -> %s", *avatarID, avatar.ID) + log.Debug().Str("old_avatar_id", *avatarID).Str("new_avatar_id", avatar.ID).Msg("Updated avatar") *avatarID = avatar.ID *avatarSet = false return true } -func (portal *Portal) UpdateNewsletterAvatar(user *User, meta *types.NewsletterMetadata) bool { +func (portal *Portal) UpdateNewsletterAvatar(ctx context.Context, user *User, meta *types.NewsletterMetadata) bool { portal.avatarLock.Lock() defer portal.avatarLock.Unlock() var picID string @@ -1342,51 +1481,56 @@ func (portal *Portal) UpdateNewsletterAvatar(user *User, meta *types.NewsletterM if picID == "" { picID = "remove" } - if portal.Avatar != picID || !portal.AvatarSet { - if picID == "remove" { - portal.AvatarURL = id.ContentURI{} - } else if portal.Avatar != picID || portal.AvatarURL.IsEmpty() { - var err error - if picture == nil { - meta, err = user.Client.GetNewsletterInfo(portal.Key.JID) - if err != nil { - portal.log.Warnln("Failed to fetch full res avatar info for newsletter:", err) - return false - } - picture = meta.ThreadMeta.Picture - if picture == nil { - portal.log.Warnln("Didn't get full res avatar info for newsletter") - return false - } - picID = picture.ID - } - portal.AvatarURL, err = user.reuploadAvatarDirectPath(portal.MainIntent(), picture.DirectPath) + if portal.Avatar == picID && portal.AvatarSet { + return false + } + log := zerolog.Ctx(ctx) + if picID == "remove" { + portal.AvatarURL = id.ContentURI{} + } else if portal.Avatar != picID || portal.AvatarURL.IsEmpty() { + var err error + if picture == nil { + meta, err = user.Client.GetNewsletterInfo(portal.Key.JID) if err != nil { - portal.log.Warnln("Failed to reupload newsletter avatar:", err) + log.Err(err).Msg("Failed to fetch full res avatar info for newsletter") return false } + picture = meta.ThreadMeta.Picture + if picture == nil { + log.Warn().Msg("Didn't get full res avatar info for newsletter") + return false + } + picID = picture.ID + } + portal.AvatarURL, err = user.reuploadAvatarDirectPath(ctx, portal.MainIntent(), picture.DirectPath) + if err != nil { + log.Err(err).Msg("Failed to reupload newsletter avatar") + return false } - portal.Avatar = picID - portal.AvatarSet = false - return portal.setRoomAvatar(true, types.EmptyJID, true) } - return false + portal.Avatar = picID + portal.AvatarSet = false + return portal.setRoomAvatar(ctx, true, types.EmptyJID, true) } -func (portal *Portal) UpdateAvatar(user *User, setBy types.JID, updateInfo bool) bool { +func (portal *Portal) UpdateAvatar(ctx context.Context, user *User, setBy types.JID, updateInfo bool) bool { if portal.IsNewsletter() { return false } portal.avatarLock.Lock() defer portal.avatarLock.Unlock() - changed := user.updateAvatar(portal.Key.JID, portal.IsParent, &portal.Avatar, &portal.AvatarURL, &portal.AvatarSet, portal.log, portal.MainIntent()) - return portal.setRoomAvatar(changed, setBy, updateInfo) + changed := user.updateAvatar(ctx, portal.Key.JID, portal.IsParent, &portal.Avatar, &portal.AvatarURL, &portal.AvatarSet, portal.MainIntent()) + return portal.setRoomAvatar(ctx, changed, setBy, updateInfo) } -func (portal *Portal) setRoomAvatar(changed bool, setBy types.JID, updateInfo bool) bool { +func (portal *Portal) setRoomAvatar(ctx context.Context, changed bool, setBy types.JID, updateInfo bool) bool { + log := zerolog.Ctx(ctx) if !changed || portal.Avatar == "unauthorized" { if changed || updateInfo { - portal.Update(nil) + err := portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal in setRoomAvatar") + } } return changed } @@ -1396,89 +1540,109 @@ func (portal *Portal) setRoomAvatar(changed bool, setBy types.JID, updateInfo bo if !setBy.IsEmpty() && setBy.Server == types.DefaultUserServer { intent = portal.bridge.GetPuppetByJID(setBy).IntentFor(portal) } - _, err := intent.SetRoomAvatar(portal.MXID, portal.AvatarURL) + _, err := intent.SetRoomAvatar(ctx, portal.MXID, portal.AvatarURL) if errors.Is(err, mautrix.MForbidden) && intent != portal.MainIntent() { - _, err = portal.MainIntent().SetRoomAvatar(portal.MXID, portal.AvatarURL) + _, err = portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, portal.AvatarURL) } if err != nil { - portal.log.Warnln("Failed to set room avatar:", err) + log.Err(err).Msg("Failed to set room avatar") return true } else { portal.AvatarSet = true } } if updateInfo { - portal.UpdateBridgeInfo() - portal.Update(nil) - portal.updateChildRooms() + portal.UpdateBridgeInfo(ctx) + err := portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal in setRoomAvatar") + } + portal.updateChildRooms(ctx) } return true } -func (portal *Portal) UpdateName(name string, setBy types.JID, updateInfo bool) bool { +func (portal *Portal) UpdateName(ctx context.Context, name string, setBy types.JID, updateInfo bool) bool { if name == "" && portal.IsBroadcastList() { name = UnnamedBroadcastName } - if portal.Name != name || (!portal.NameSet && len(portal.MXID) > 0 && portal.shouldSetDMRoomMetadata()) { - portal.log.Debugfln("Updating name %q -> %q", portal.Name, name) - portal.Name = name - portal.NameSet = false + if portal.Name == name && (portal.NameSet || len(portal.MXID) == 0 || !portal.shouldSetDMRoomMetadata()) { + return false + } + log := zerolog.Ctx(ctx) + log.Debug().Str("old_name", portal.Name).Str("new_name", name).Msg("Updating room name") + portal.Name = name + portal.NameSet = false + if updateInfo { + defer func() { + err := portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after updating name") + } + }() + } + if len(portal.MXID) == 0 { + return true + } + if !portal.shouldSetDMRoomMetadata() { + // TODO only do this if updateInfo? + portal.UpdateBridgeInfo(ctx) + return true + } + intent := portal.MainIntent() + if !setBy.IsEmpty() && setBy.Server == types.DefaultUserServer { + intent = portal.bridge.GetPuppetByJID(setBy).IntentFor(portal) + } + _, err := intent.SetRoomName(ctx, portal.MXID, name) + if errors.Is(err, mautrix.MForbidden) && intent != portal.MainIntent() { + _, err = portal.MainIntent().SetRoomName(ctx, portal.MXID, name) + } + if err != nil { + log.Err(err).Msg("Failed to set room name") + } else { + portal.NameSet = true if updateInfo { - defer portal.Update(nil) - } - - if len(portal.MXID) > 0 && !portal.shouldSetDMRoomMetadata() { - portal.UpdateBridgeInfo() - } else if len(portal.MXID) > 0 { - intent := portal.MainIntent() - if !setBy.IsEmpty() && setBy.Server == types.DefaultUserServer { - intent = portal.bridge.GetPuppetByJID(setBy).IntentFor(portal) - } - _, err := intent.SetRoomName(portal.MXID, name) - if errors.Is(err, mautrix.MForbidden) && intent != portal.MainIntent() { - _, err = portal.MainIntent().SetRoomName(portal.MXID, name) - } - if err == nil { - portal.NameSet = true - if updateInfo { - portal.UpdateBridgeInfo() - portal.updateChildRooms() - } - return true - } else { - portal.log.Warnln("Failed to set room name:", err) - } + portal.UpdateBridgeInfo(ctx) + portal.updateChildRooms(ctx) } } - return false + return true } -func (portal *Portal) UpdateTopic(topic string, setBy types.JID, updateInfo bool) bool { - if portal.Topic != topic || !portal.TopicSet { - portal.log.Debugfln("Updating topic %q -> %q", portal.Topic, topic) - portal.Topic = topic - portal.TopicSet = false - - intent := portal.MainIntent() - if !setBy.IsEmpty() && setBy.Server == types.DefaultUserServer { - intent = portal.bridge.GetPuppetByJID(setBy).IntentFor(portal) - } - _, err := intent.SetRoomTopic(portal.MXID, topic) - if errors.Is(err, mautrix.MForbidden) && intent != portal.MainIntent() { - _, err = portal.MainIntent().SetRoomTopic(portal.MXID, topic) - } - if err == nil { - portal.TopicSet = true - if updateInfo { - portal.UpdateBridgeInfo() - portal.Update(nil) +func (portal *Portal) UpdateTopic(ctx context.Context, topic string, setBy types.JID, updateInfo bool) bool { + if portal.Topic == topic && (portal.TopicSet || len(portal.MXID) == 0) { + return false + } + log := zerolog.Ctx(ctx) + log.Debug().Str("old_topic", portal.Topic).Str("new_topic", topic).Msg("Updating topic") + portal.Topic = topic + portal.TopicSet = false + if updateInfo { + defer func() { + err := portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after updating topic") } - return true - } else { - portal.log.Warnln("Failed to set room topic:", err) + }() + } + + intent := portal.MainIntent() + if !setBy.IsEmpty() && setBy.Server == types.DefaultUserServer { + intent = portal.bridge.GetPuppetByJID(setBy).IntentFor(portal) + } + _, err := intent.SetRoomTopic(ctx, portal.MXID, topic) + if errors.Is(err, mautrix.MForbidden) && intent != portal.MainIntent() { + _, err = portal.MainIntent().SetRoomTopic(ctx, portal.MXID, topic) + } + if err != nil { + log.Err(err).Msg("Failed to set room topic") + } else { + portal.TopicSet = true + if updateInfo { + portal.UpdateBridgeInfo(ctx) } } - return false + return true } func newsletterToGroupInfo(meta *types.NewsletterMetadata) *types.GroupInfo { @@ -1496,34 +1660,40 @@ func newsletterToGroupInfo(meta *types.NewsletterMetadata) *types.GroupInfo { return &out } -func (portal *Portal) UpdateParentGroup(source *User, parent types.JID, updateInfo bool) bool { +func (portal *Portal) UpdateParentGroup(ctx context.Context, source *User, parent types.JID, updateInfo bool) bool { portal.parentGroupUpdateLock.Lock() defer portal.parentGroupUpdateLock.Unlock() if portal.ParentGroup != parent { - portal.log.Debugfln("Updating parent group %v -> %v", portal.ParentGroup, parent) - portal.updateCommunitySpace(source, false, false) + zerolog.Ctx(ctx).Debug(). + Stringer("old_parent_group", portal.ParentGroup). + Stringer("new_parent_group", parent). + Msg("Updating parent group") + portal.updateCommunitySpace(ctx, source, false, false) portal.ParentGroup = parent portal.parentPortal = nil portal.InSpace = false - portal.updateCommunitySpace(source, true, false) + portal.updateCommunitySpace(ctx, source, true, false) if updateInfo { - portal.UpdateBridgeInfo() - portal.Update(nil) + portal.UpdateBridgeInfo(ctx) + err := portal.Update(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating parent group") + } } return true } else if !portal.ParentGroup.IsEmpty() && !portal.InSpace { - return portal.updateCommunitySpace(source, true, updateInfo) + return portal.updateCommunitySpace(ctx, source, true, updateInfo) } return false } -func (portal *Portal) UpdateMetadata(user *User, groupInfo *types.GroupInfo, newsletterMetadata *types.NewsletterMetadata) bool { +func (portal *Portal) UpdateMetadata(ctx context.Context, user *User, groupInfo *types.GroupInfo, newsletterMetadata *types.NewsletterMetadata) bool { if portal.IsPrivateChat() { return false } else if portal.IsStatusBroadcastList() { update := false - update = portal.UpdateName(StatusBroadcastName, types.EmptyJID, false) || update - update = portal.UpdateTopic(StatusBroadcastTopic, types.EmptyJID, false) || update + update = portal.UpdateName(ctx, StatusBroadcastName, types.EmptyJID, false) || update + update = portal.UpdateTopic(ctx, StatusBroadcastTopic, types.EmptyJID, false) || update return update } else if portal.IsBroadcastList() { update := false @@ -1545,7 +1715,7 @@ func (portal *Portal) UpdateMetadata(user *User, groupInfo *types.GroupInfo, new var err error newsletterMetadata, err = user.Client.GetNewsletterInfo(portal.Key.JID) if err != nil { - portal.zlog.Err(err).Msg("Failed to get newsletter info") + zerolog.Ctx(ctx).Err(err).Msg("Failed to get newsletter info") return false } } @@ -1555,66 +1725,75 @@ func (portal *Portal) UpdateMetadata(user *User, groupInfo *types.GroupInfo, new var err error groupInfo, err = user.Client.GetGroupInfo(portal.Key.JID) if err != nil { - portal.log.Errorln("Failed to get group info:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to get group info") return false } } - portal.SyncParticipants(user, groupInfo) + portal.SyncParticipants(ctx, user, groupInfo) update := false - update = portal.UpdateName(groupInfo.Name, groupInfo.NameSetBy, false) || update - update = portal.UpdateTopic(groupInfo.Topic, groupInfo.TopicSetBy, false) || update - update = portal.UpdateParentGroup(user, groupInfo.LinkedParentJID, false) || update + update = portal.UpdateName(ctx, groupInfo.Name, groupInfo.NameSetBy, false) || update + update = portal.UpdateTopic(ctx, groupInfo.Topic, groupInfo.TopicSetBy, false) || update + update = portal.UpdateParentGroup(ctx, user, groupInfo.LinkedParentJID, false) || update if portal.ExpirationTime != groupInfo.DisappearingTimer { update = true portal.ExpirationTime = groupInfo.DisappearingTimer } if portal.IsParent != groupInfo.IsParent { if portal.MXID != "" { - portal.log.Warnfln("Existing group changed is_parent from %t to %t", portal.IsParent, groupInfo.IsParent) + zerolog.Ctx(ctx).Warn().Bool("new_is_parent", groupInfo.IsParent).Msg("Existing group changed is_parent status") } portal.IsParent = groupInfo.IsParent update = true } - portal.RestrictMessageSending(groupInfo.IsAnnounce) - portal.RestrictMetadataChanges(groupInfo.IsLocked) + portal.RestrictMessageSending(ctx, groupInfo.IsAnnounce) + portal.RestrictMetadataChanges(ctx, groupInfo.IsLocked) if newsletterMetadata != nil && newsletterMetadata.ViewerMeta != nil { - portal.PromoteNewsletterUser(user, newsletterMetadata.ViewerMeta.Role) + portal.PromoteNewsletterUser(ctx, user, newsletterMetadata.ViewerMeta.Role) } return update } -func (portal *Portal) ensureUserInvited(user *User) bool { - return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat()) +func (portal *Portal) ensureUserInvited(ctx context.Context, user *User) bool { + return user.ensureInvited(ctx, portal.MainIntent(), portal.MXID, portal.IsPrivateChat()) } -func (portal *Portal) UpdateMatrixRoom(user *User, groupInfo *types.GroupInfo, newsletterMetadata *types.NewsletterMetadata) bool { +func (portal *Portal) UpdateMatrixRoom(ctx context.Context, user *User, groupInfo *types.GroupInfo, newsletterMetadata *types.NewsletterMetadata) bool { if len(portal.MXID) == 0 { return false } - portal.log.Infoln("Syncing portal for", user.MXID) + log := zerolog.Ctx(ctx).With(). + Str("action", "update matrix room"). + Str("portal_key", portal.Key.String()). + Stringer("source_mxid", user.MXID). + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Syncing portal") - portal.ensureUserInvited(user) - go portal.addToPersonalSpace(user) + portal.ensureUserInvited(ctx, user) + go portal.addToPersonalSpace(ctx, user) if groupInfo == nil && newsletterMetadata != nil { groupInfo = newsletterToGroupInfo(newsletterMetadata) } update := false - update = portal.UpdateMetadata(user, groupInfo, newsletterMetadata) || update + update = portal.UpdateMetadata(ctx, user, groupInfo, newsletterMetadata) || update if !portal.IsPrivateChat() && !portal.IsBroadcastList() && !portal.IsNewsletter() { - update = portal.UpdateAvatar(user, types.EmptyJID, false) || update + update = portal.UpdateAvatar(ctx, user, types.EmptyJID, false) || update } else if newsletterMetadata != nil { - update = portal.UpdateNewsletterAvatar(user, newsletterMetadata) || update + update = portal.UpdateNewsletterAvatar(ctx, user, newsletterMetadata) || update } if update || portal.LastSync.Add(24*time.Hour).Before(time.Now()) { portal.LastSync = time.Now() - portal.Update(nil) - portal.UpdateBridgeInfo() - portal.updateChildRooms() + err := portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after updating") + } + portal.UpdateBridgeInfo(ctx) + portal.updateChildRooms(ctx) } return true } @@ -1655,8 +1834,8 @@ func (portal *Portal) applyPowerLevelFixes(levels *event.PowerLevelsEventContent return changed } -func (portal *Portal) ChangeAdminStatus(jids []types.JID, setAdmin bool) id.EventID { - levels, err := portal.MainIntent().PowerLevels(portal.MXID) +func (portal *Portal) ChangeAdminStatus(ctx context.Context, jids []types.JID, setAdmin bool) id.EventID { + levels, err := portal.MainIntent().PowerLevels(ctx, portal.MXID) if err != nil { levels = portal.GetBasePowerLevels() } @@ -1679,9 +1858,9 @@ func (portal *Portal) ChangeAdminStatus(jids []types.JID, setAdmin bool) id.Even } } if changed { - resp, err := portal.MainIntent().SetPowerLevels(portal.MXID, levels) + resp, err := portal.MainIntent().SetPowerLevels(ctx, portal.MXID, levels) if err != nil { - portal.log.Errorln("Failed to change power levels:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to set power levels") } else { return resp.EventID } @@ -1689,8 +1868,8 @@ func (portal *Portal) ChangeAdminStatus(jids []types.JID, setAdmin bool) id.Even return "" } -func (portal *Portal) RestrictMessageSending(restrict bool) id.EventID { - levels, err := portal.MainIntent().PowerLevels(portal.MXID) +func (portal *Portal) RestrictMessageSending(ctx context.Context, restrict bool) id.EventID { + levels, err := portal.MainIntent().PowerLevels(ctx, portal.MXID) if err != nil { levels = portal.GetBasePowerLevels() } @@ -1706,17 +1885,17 @@ func (portal *Portal) RestrictMessageSending(restrict bool) id.EventID { } levels.EventsDefault = newLevel - resp, err := portal.MainIntent().SetPowerLevels(portal.MXID, levels) + resp, err := portal.MainIntent().SetPowerLevels(ctx, portal.MXID, levels) if err != nil { - portal.log.Errorln("Failed to change power levels:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to set power levels") return "" } else { return resp.EventID } } -func (portal *Portal) PromoteNewsletterUser(user *User, role types.NewsletterRole) id.EventID { - levels, err := portal.MainIntent().PowerLevels(portal.MXID) +func (portal *Portal) PromoteNewsletterUser(ctx context.Context, user *User, role types.NewsletterRole) id.EventID { + levels, err := portal.MainIntent().PowerLevels(ctx, portal.MXID) if err != nil { levels = portal.GetBasePowerLevels() } @@ -1735,17 +1914,17 @@ func (portal *Portal) PromoteNewsletterUser(user *User, role types.NewsletterRol return "" } - resp, err := portal.MainIntent().SetPowerLevels(portal.MXID, levels) + resp, err := portal.MainIntent().SetPowerLevels(ctx, portal.MXID, levels) if err != nil { - portal.log.Errorln("Failed to change power levels:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to set power levels") return "" } else { return resp.EventID } } -func (portal *Portal) RestrictMetadataChanges(restrict bool) id.EventID { - levels, err := portal.MainIntent().PowerLevels(portal.MXID) +func (portal *Portal) RestrictMetadataChanges(ctx context.Context, restrict bool) id.EventID { + levels, err := portal.MainIntent().PowerLevels(ctx, portal.MXID) if err != nil { levels = portal.GetBasePowerLevels() } @@ -1758,9 +1937,9 @@ func (portal *Portal) RestrictMetadataChanges(restrict bool) id.EventID { changed = levels.EnsureEventLevel(event.StateRoomAvatar, newLevel) || changed changed = levels.EnsureEventLevel(event.StateTopic, newLevel) || changed if changed { - resp, err := portal.MainIntent().SetPowerLevels(portal.MXID, levels) + resp, err := portal.MainIntent().SetPowerLevels(ctx, portal.MXID, levels) if err != nil { - portal.log.Errorln("Failed to change power levels:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to set power levels") } else { return resp.EventID } @@ -1798,34 +1977,40 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { return portal.getBridgeInfoStateKey(), bridgeInfo } -func (portal *Portal) UpdateBridgeInfo() { +func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { + log := zerolog.Ctx(ctx) if len(portal.MXID) == 0 { - portal.log.Debugln("Not updating bridge info: no Matrix room created") + log.Debug().Msg("Not updating bridge info: no Matrix room created") return } - portal.log.Debugln("Updating bridge info...") + log.Debug().Msg("Updating bridge info...") stateKey, content := portal.getBridgeInfo() - _, err := portal.MainIntent().SendStateEvent(portal.MXID, event.StateBridge, stateKey, content) + _, err := portal.MainIntent().SendStateEvent(ctx, portal.MXID, event.StateBridge, stateKey, content) if err != nil { - portal.log.Warnln("Failed to update m.bridge:", err) + log.Warn().Err(err).Msg("Failed to update m.bridge info") } // TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec - _, err = portal.MainIntent().SendStateEvent(portal.MXID, event.StateHalfShotBridge, stateKey, content) + _, err = portal.MainIntent().SendStateEvent(ctx, portal.MXID, event.StateHalfShotBridge, stateKey, content) if err != nil { - portal.log.Warnln("Failed to update uk.half-shot.bridge:", err) + log.Warn().Err(err).Msg("Failed to update uk.half-shot.bridge info") } } -func (portal *Portal) updateChildRooms() { +func (portal *Portal) updateChildRooms(ctx context.Context) { if !portal.IsParent { return } children := portal.bridge.GetAllByParentGroup(portal.Key.JID) for _, child := range children { - changed := child.updateCommunitySpace(nil, true, false) - child.UpdateBridgeInfo() + changed := child.updateCommunitySpace(ctx, nil, true, false) + // TODO set updateInfo to true instead of updating manually? + child.UpdateBridgeInfo(ctx) if changed { - portal.Update(nil) + // TODO is this saving the wrong portal? + err := portal.Update(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating") + } } } } @@ -1845,31 +2030,37 @@ func (portal *Portal) GetEncryptionEventContent() (evt *event.EncryptionEventCon return } -func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, newsletterMetadata *types.NewsletterMetadata, isFullInfo, backfill bool) error { +func (portal *Portal) CreateMatrixRoom(ctx context.Context, user *User, groupInfo *types.GroupInfo, newsletterMetadata *types.NewsletterMetadata, isFullInfo, backfill bool) error { portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() if len(portal.MXID) > 0 { return nil } + log := zerolog.Ctx(ctx).With(). + Str("action", "create matrix room"). + Str("portal_key", portal.Key.String()). + Stringer("source_mxid", user.MXID). + Logger() + ctx = log.WithContext(ctx) intent := portal.MainIntent() - if err := intent.EnsureRegistered(); err != nil { + if err := intent.EnsureRegistered(ctx); err != nil { return err } - portal.log.Infoln("Creating Matrix room. Info source:", user.MXID) + log.Info().Msg("Creating Matrix room") //var broadcastMetadata *types.BroadcastListInfo if portal.IsPrivateChat() { puppet := portal.bridge.GetPuppetByJID(portal.Key.JID) - puppet.SyncContact(user, true, false, "creating private chat portal") + puppet.SyncContact(ctx, user, true, false, "creating private chat portal") portal.Name = puppet.Displayname portal.AvatarURL = puppet.AvatarURL portal.Avatar = puppet.Avatar portal.Topic = PrivateChatTopic } else if portal.IsStatusBroadcastList() { if !portal.bridge.Config.Bridge.EnableStatusBroadcast { - portal.log.Debugln("Status bridging is disabled in config, not creating room after all") + log.Debug().Msg("Status bridging is disabled in config, not creating room after all") return ErrStatusBroadcastDisabled } portal.Name = StatusBroadcastName @@ -1889,7 +2080,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, n // portal.Name = UnnamedBroadcastName //} //portal.Topic = BroadcastTopic - portal.log.Debugln("Broadcast list is not yet supported, not creating room after all") + log.Debug().Msg("Broadcast list is not yet supported, not creating room after all") return fmt.Errorf("broadcast list bridging is currently not supported") } else { if portal.IsNewsletter() { @@ -1909,12 +2100,18 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, n // Ensure that the user is actually a participant in the conversation // before creating the matrix room if errors.Is(err, whatsmeow.ErrNotInGroup) { - user.log.Debugfln("Skipping creating matrix room for %s because the user is not a participant", portal.Key.JID) - user.bridge.DB.Backfill.DeleteAllForPortal(user.MXID, portal.Key) - user.bridge.DB.HistorySync.DeleteAllMessagesForPortal(user.MXID, portal.Key) + log.Debug().Msg("Skipping creating room because the user is not a participant") + err = user.bridge.DB.BackfillQueue.DeleteAllForPortal(ctx, user.MXID, portal.Key) + if err != nil { + log.Err(err).Msg("Failed to delete backfill queue for portal") + } + err = user.bridge.DB.HistorySync.DeleteAllMessagesForPortal(ctx, user.MXID, portal.Key) + if err != nil { + log.Err(err).Msg("Failed to delete historical messages for portal") + } return err } else if err != nil { - portal.log.Warnfln("Failed to get group info through %s: %v", user.JID, err) + log.Err(err).Msg("Failed to get group info") } else { groupInfo = foundInfo isFullInfo = true @@ -1930,9 +2127,9 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, n } } if portal.IsNewsletter() { - portal.UpdateNewsletterAvatar(user, newsletterMetadata) + portal.UpdateNewsletterAvatar(ctx, user, newsletterMetadata) } else { - portal.UpdateAvatar(user, types.EmptyJID, false) + portal.UpdateAvatar(ctx, user, types.EmptyJID, false) } } @@ -2019,10 +2216,10 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, n } autoJoinInvites := portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites) if autoJoinInvites { - portal.log.Debugfln("Hungryserv mode: adding all group members in create request") + log.Debug().Msg("Hungryserv mode: adding all group members in create request") if groupInfo != nil && !portal.IsNewsletter() { // TODO non-hungryserv could also include all members in invites, and then send joins manually? - participants, powerLevels := portal.SyncParticipants(user, groupInfo) + participants, powerLevels := portal.SyncParticipants(ctx, user, groupInfo) invite = append(invite, participants...) if initialState[0].Type != event.StatePowerLevels { panic(fmt.Errorf("unexpected type %s in first initial state event", initialState[0].Type.Type)) @@ -2057,19 +2254,23 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, n } }() } - resp, err := intent.CreateRoom(req) + resp, err := intent.CreateRoom(ctx, req) if err != nil { return err } - portal.log.Infoln("Matrix room created:", resp.RoomID) + log.Info().Stringer("room_id", resp.RoomID).Msg("Matrix room created") portal.InSpace = false portal.NameSet = len(req.Name) > 0 portal.TopicSet = len(req.Topic) > 0 portal.MXID = resp.RoomID + portal.updateLogger() portal.bridge.portalsLock.Lock() portal.bridge.portalsByMXID[portal.MXID] = portal portal.bridge.portalsLock.Unlock() - portal.Update(nil) + err = portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after creating room") + } // We set the memberships beforehand to make sure the encryption key exchange in initial backfill knows the users are here. inviteMembership := event.MembershipInvite @@ -2077,19 +2278,22 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, n inviteMembership = event.MembershipJoin } for _, userID := range invite { - portal.bridge.StateStore.SetMembership(portal.MXID, userID, inviteMembership) + err = portal.bridge.StateStore.SetMembership(ctx, portal.MXID, userID, inviteMembership) + if err != nil { + log.Err(err).Stringer("user_id", userID).Msg("Failed to update membership in state store") + } } if !autoJoinInvites { - portal.ensureUserInvited(user) + portal.ensureUserInvited(ctx, user) } - user.syncChatDoublePuppetDetails(portal, true) + user.syncChatDoublePuppetDetails(ctx, portal, true) - go portal.updateCommunitySpace(user, true, true) - go portal.addToPersonalSpace(user) + go portal.updateCommunitySpace(ctx, user, true, true) + go portal.addToPersonalSpace(ctx, user) if !portal.IsNewsletter() && groupInfo != nil && !autoJoinInvites { - portal.SyncParticipants(user, groupInfo) + portal.SyncParticipants(ctx, user, groupInfo) } //if broadcastMetadata != nil { // portal.SyncBroadcastRecipients(user, broadcastMetadata) @@ -2098,66 +2302,60 @@ func (portal *Portal) CreateMatrixRoom(user *User, groupInfo *types.GroupInfo, n puppet := user.bridge.GetPuppetByJID(portal.Key.JID) if portal.bridge.Config.Bridge.Encryption.Default { - err = portal.bridge.Bot.EnsureJoined(portal.MXID) + err = portal.bridge.Bot.EnsureJoined(ctx, portal.MXID) if err != nil { - portal.log.Errorln("Failed to join created portal with bridge bot for e2be:", err) + log.Err(err).Msg("Failed to ensure bridge bot is joined to created portal") } } - user.UpdateDirectChats(map[id.UserID][]id.RoomID{puppet.MXID: {portal.MXID}}) + user.UpdateDirectChats(ctx, map[id.UserID][]id.RoomID{puppet.MXID: {portal.MXID}}) } else if portal.IsParent { - portal.updateChildRooms() - } - - firstEventResp, err := portal.MainIntent().SendMessageEvent(portal.MXID, PortalCreationDummyEvent, struct{}{}) - if err != nil { - portal.log.Errorln("Failed to send dummy event to mark portal creation:", err) - } else { - portal.FirstEventID = firstEventResp.EventID - portal.Update(nil) + portal.updateChildRooms(ctx) } if user.bridge.Config.Bridge.HistorySync.Backfill && backfill { if legacyBackfill { backfillStarted = true - go portal.legacyBackfill(user) + go portal.legacyBackfill(context.WithoutCancel(ctx), user) } else { portals := []*Portal{portal} - user.EnqueueImmediateBackfills(portals) - user.EnqueueDeferredBackfills(portals) + user.EnqueueImmediateBackfills(ctx, portals) + user.EnqueueDeferredBackfills(ctx, portals) user.BackfillQueue.ReCheck() } } return nil } -func (portal *Portal) addToPersonalSpace(user *User) { - spaceID := user.GetSpaceRoom() - if len(spaceID) == 0 || user.IsInSpace(portal.Key) { +func (portal *Portal) addToPersonalSpace(ctx context.Context, user *User) { + spaceID := user.GetSpaceRoom(ctx) + if len(spaceID) == 0 || user.IsInSpace(ctx, portal.Key) { return } - _, err := portal.bridge.Bot.SendStateEvent(spaceID, event.StateSpaceChild, portal.MXID.String(), &event.SpaceChildEventContent{ + _, err := portal.bridge.Bot.SendStateEvent(ctx, spaceID, event.StateSpaceChild, portal.MXID.String(), &event.SpaceChildEventContent{ Via: []string{portal.bridge.Config.Homeserver.Domain}, }) if err != nil { - portal.log.Errorfln("Failed to add room to %s's personal filtering space (%s): %v", user.MXID, spaceID, err) + zerolog.Ctx(ctx).Err(err).Stringer("space_id", spaceID).Msg("Failed to add portal to user's personal filtering space") } else { - portal.log.Debugfln("Added room to %s's personal filtering space (%s)", user.MXID, spaceID) - user.MarkInSpace(portal.Key) + zerolog.Ctx(ctx).Debug().Stringer("space_id", spaceID).Msg("Added portal to user's personal filtering space") + user.MarkInSpace(ctx, portal.Key) } } func (portal *Portal) removeSpaceParentEvent(space id.RoomID) { - _, err := portal.MainIntent().SendStateEvent(portal.MXID, event.StateSpaceParent, space.String(), &event.SpaceParentEventContent{}) + _, err := portal.MainIntent().SendStateEvent(context.TODO(), portal.MXID, event.StateSpaceParent, space.String(), &event.SpaceParentEventContent{}) if err != nil { - portal.log.Warnfln("Failed to send m.space.parent event to remove portal from %s: %v", space, err) + portal.zlog.Err(err).Stringer("space_mxid", space).Msg("Failed to send m.space.parent event to remove portal from space") } } -func (portal *Portal) updateCommunitySpace(user *User, add, updateInfo bool) bool { +func (portal *Portal) updateCommunitySpace(ctx context.Context, user *User, add, updateInfo bool) bool { if add == portal.InSpace { return false } + // TODO if this function is changed to use the context logger, updateChildRooms should add the child portal info to the logger + log := portal.zlog.With().Stringer("room_id", portal.MXID).Logger() space := portal.GetParentPortal() if space == nil { return false @@ -2165,41 +2363,47 @@ func (portal *Portal) updateCommunitySpace(user *User, add, updateInfo bool) boo if !add || user == nil { return false } - portal.log.Debugfln("Creating portal for parent group %v", space.Key.JID) - err := space.CreateMatrixRoom(user, nil, nil, false, false) + log.Debug().Stringer("parent_group_jid", space.Key.JID).Msg("Creating portal for parent group") + err := space.CreateMatrixRoom(ctx, user, nil, nil, false, false) if err != nil { - portal.log.Debugfln("Failed to create portal for parent group: %v", err) + log.Err(err).Msg("Failed to create portal for parent group") return false } } - var action string var parentContent event.SpaceParentEventContent var childContent event.SpaceChildEventContent if add { parentContent.Canonical = true parentContent.Via = []string{portal.bridge.Config.Homeserver.Domain} childContent.Via = []string{portal.bridge.Config.Homeserver.Domain} - action = "add portal to" - portal.log.Debugfln("Adding %s to space %s (%s)", portal.MXID, space.MXID, space.Key.JID) + log.Debug(). + Stringer("space_mxid", space.MXID). + Stringer("parent_group_jid", space.Key.JID). + Msg("Adding room to parent group space") } else { - action = "remove portal from" - portal.log.Debugfln("Removing %s from space %s (%s)", portal.MXID, space.MXID, space.Key.JID) + log.Debug(). + Stringer("space_mxid", space.MXID). + Stringer("parent_group_jid", space.Key.JID). + Msg("Removing room from parent group space") } - _, err := space.MainIntent().SendStateEvent(space.MXID, event.StateSpaceChild, portal.MXID.String(), &childContent) + _, err := space.MainIntent().SendStateEvent(ctx, space.MXID, event.StateSpaceChild, portal.MXID.String(), &childContent) if err != nil { - portal.log.Errorfln("Failed to send m.space.child event to %s %s: %v", action, space.MXID, err) + log.Err(err).Stringer("space_mxid", space.MXID).Msg("Failed to send m.space.child event") return false } - _, err = portal.MainIntent().SendStateEvent(portal.MXID, event.StateSpaceParent, space.MXID.String(), &parentContent) + _, err = portal.MainIntent().SendStateEvent(ctx, portal.MXID, event.StateSpaceParent, space.MXID.String(), &parentContent) if err != nil { - portal.log.Warnfln("Failed to send m.space.parent event to %s %s: %v", action, space.MXID, err) + log.Err(err).Stringer("space_mxid", space.MXID).Msg("Failed to send m.space.parent event") } portal.InSpace = add if updateInfo { - portal.Update(nil) - portal.UpdateBridgeInfo() + err = portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after updating parent space") + } + portal.UpdateBridgeInfo(ctx) } return true } @@ -2271,12 +2475,11 @@ func (portal *Portal) addReplyMention(content *event.MessageEventContent, sender } } -func (portal *Portal) SetReply(msgID string, content *event.MessageEventContent, replyTo *ReplyInfo, isHungryBackfill bool) bool { +func (portal *Portal) SetReply(ctx context.Context, content *event.MessageEventContent, replyTo *ReplyInfo, isHungryBackfill bool) bool { if replyTo == nil { return false } - log := portal.zlog.With(). - Str("message_id", msgID). + log := zerolog.Ctx(ctx).With(). Object("reply_to", replyTo). Str("action", "SetReply"). Logger() @@ -2300,8 +2503,11 @@ func (portal *Portal) SetReply(msgID string, content *event.MessageEventContent, } } } - message := portal.bridge.DB.Message.GetByJID(key, replyTo.MessageID) - if message == nil || message.IsFakeMXID() { + message, err := portal.bridge.DB.Message.GetByJID(ctx, key, replyTo.MessageID) + if err != nil { + log.Err(err).Msg("Failed to get reply target from database") + return false + } else if message == nil || message.IsFakeMXID() { if isHungryBackfill { content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(targetPortal.deterministicEventID(replyTo.Sender, replyTo.MessageID, "")) portal.addReplyMention(content, replyTo.Sender, "") @@ -2316,14 +2522,14 @@ func (portal *Portal) SetReply(msgID string, content *event.MessageEventContent, if portal.bridge.Config.Bridge.DisableReplyFallbacks { return true } - evt, err := targetPortal.MainIntent().GetEvent(targetPortal.MXID, message.MXID) + evt, err := targetPortal.MainIntent().GetEvent(ctx, targetPortal.MXID, message.MXID) if err != nil { log.Warn().Err(err).Msg("Failed to get reply target event") return true } _ = evt.Content.ParseRaw(evt.Type) if evt.Type == event.EventEncrypted { - decryptedEvt, err := portal.bridge.Crypto.Decrypt(evt) + decryptedEvt, err := portal.bridge.Crypto.Decrypt(ctx, evt) if err != nil { log.Warn().Err(err).Msg("Failed to decrypt reply target event") } else { @@ -2334,31 +2540,45 @@ func (portal *Portal) SetReply(msgID string, content *event.MessageEventContent, return true } -func (portal *Portal) HandleMessageReaction(intent *appservice.IntentAPI, user *User, info *types.MessageInfo, reaction *waProto.ReactionMessage, existingMsg *database.Message) { +func (portal *Portal) HandleMessageReaction(ctx context.Context, intent *appservice.IntentAPI, user *User, info *types.MessageInfo, reaction *waProto.ReactionMessage, existingMsg *database.Message) { if existingMsg != nil { - _, _ = portal.MainIntent().RedactEvent(portal.MXID, existingMsg.MXID, mautrix.ReqRedact{ + _, _ = portal.MainIntent().RedactEvent(ctx, portal.MXID, existingMsg.MXID, mautrix.ReqRedact{ Reason: "The undecryptable message was actually a reaction", }) } targetJID := reaction.GetKey().GetId() + log := zerolog.Ctx(ctx).With(). + Str("reaction_target_id", targetJID). + Logger() if reaction.GetText() == "" { - existing := portal.bridge.DB.Reaction.GetByTargetJID(portal.Key, targetJID, info.Sender) - if existing == nil { - portal.log.Debugfln("Dropping removal %s of unknown reaction to %s from %s", info.ID, targetJID, info.Sender) + existing, err := portal.bridge.DB.Reaction.GetByTargetJID(ctx, portal.Key, targetJID, info.Sender) + if err != nil { + log.Err(err).Msg("Failed to get existing reaction to remove") + return + } else if existing == nil { + log.Debug().Msg("Dropping removal of unknown reaction") return } - resp, err := intent.RedactEvent(portal.MXID, existing.MXID) + resp, err := intent.RedactEvent(ctx, portal.MXID, existing.MXID) if err != nil { - portal.log.Errorfln("Failed to redact reaction %s/%s from %s to %s: %v", existing.MXID, existing.JID, info.Sender, targetJID, err) + log.Err(err). + Stringer("reaction_mxid", existing.MXID). + Msg("Failed to redact reaction") + } + portal.finishHandling(ctx, existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, 0, database.MsgNoError) + err = existing.Delete(ctx) + if err != nil { + log.Err(err).Msg("Failed to delete reaction from database") } - portal.finishHandling(existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, 0, database.MsgNoError) - existing.Delete() } else { - target := portal.bridge.DB.Message.GetByJID(portal.Key, targetJID) - if target == nil { - portal.log.Debugfln("Dropping reaction %s from %s to unknown message %s", info.ID, info.Sender, targetJID) + target, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, targetJID) + if err != nil { + log.Err(err).Msg("Failed to get reaction target message from database") + return + } else if target == nil { + log.Debug().Msg("Dropping reaction to unknown message") return } @@ -2368,64 +2588,66 @@ func (portal *Portal) HandleMessageReaction(intent *appservice.IntentAPI, user * EventID: target.MXID, Key: variationselector.Add(reaction.GetText()), } - resp, err := intent.SendMassagedMessageEvent(portal.MXID, event.EventReaction, &content, info.Timestamp.UnixMilli()) + resp, err := intent.SendMassagedMessageEvent(ctx, portal.MXID, event.EventReaction, &content, info.Timestamp.UnixMilli()) if err != nil { - portal.log.Errorfln("Failed to bridge reaction %s from %s to %s: %v", info.ID, info.Sender, target.JID, err) + log.Err(err).Msg("Failed to bridge reaction") return } - portal.finishHandling(existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, 0, database.MsgNoError) - portal.upsertReaction(nil, intent, target.JID, info.Sender, resp.EventID, info.ID) + portal.finishHandling(ctx, existingMsg, info, resp.EventID, intent.UserID, database.MsgReaction, 0, database.MsgNoError) + portal.upsertReaction(ctx, intent, target.JID, info.Sender, resp.EventID, info.ID) } } -func (portal *Portal) HandleMessageRevoke(user *User, info *types.MessageInfo, key *waProto.MessageKey) bool { - msg := portal.bridge.DB.Message.GetByJID(portal.Key, key.GetId()) - if msg == nil || msg.IsFakeMXID() { +func (portal *Portal) HandleMessageRevoke(ctx context.Context, user *User, info *types.MessageInfo, key *waProto.MessageKey) bool { + log := zerolog.Ctx(ctx).With().Str("revoke_target_id", key.GetId()).Logger() + msg, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, key.GetId()) + if err != nil { + log.Err(err).Msg("Failed to get revoke target message from database") + return false + } else if msg == nil || msg.IsFakeMXID() { return false } intent := portal.bridge.GetPuppetByJID(info.Sender).IntentFor(portal) - _, err := intent.RedactEvent(portal.MXID, msg.MXID) + _, err = intent.RedactEvent(ctx, portal.MXID, msg.MXID) + if errors.Is(err, mautrix.MForbidden) { + _, err = portal.MainIntent().RedactEvent(ctx, portal.MXID, msg.MXID) + } if err != nil { - if errors.Is(err, mautrix.MForbidden) { - _, err = portal.MainIntent().RedactEvent(portal.MXID, msg.MXID) - if err != nil { - portal.log.Errorln("Failed to redact %s: %v", msg.JID, err) - } - } - } else { - msg.Delete() + log.Err(err).Stringer("revoke_target_mxid", msg.MXID).Msg("Failed to redact message from revoke") + } else if err = msg.Delete(ctx); err != nil { + log.Err(err).Msg("Failed to delete message from database after revoke") } return true } -func (portal *Portal) deleteForMe(user *User, content *events.DeleteForMe) bool { - matrixUsers, err := portal.GetMatrixUsers() +func (portal *Portal) deleteForMe(ctx context.Context, user *User, content *events.DeleteForMe) bool { + matrixUsers, err := portal.GetMatrixUsers(ctx) if err != nil { - portal.log.Errorln("Failed to get Matrix users in portal to see if DeleteForMe should be handled:", err) + portal.zlog.Err(err).Msg("Failed to get Matrix users in portal to see if DeleteForMe should be handled") return false } if len(matrixUsers) == 1 && matrixUsers[0] == user.MXID { - msg := portal.bridge.DB.Message.GetByJID(portal.Key, content.MessageID) + msg, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, content.MessageID) if msg == nil || msg.IsFakeMXID() { return false } - _, err := portal.MainIntent().RedactEvent(portal.MXID, msg.MXID) + _, err = portal.MainIntent().RedactEvent(ctx, portal.MXID, msg.MXID) if err != nil { - portal.log.Errorln("Failed to redact %s: %v", msg.JID, err) - } else { - msg.Delete() + portal.zlog.Err(err).Str("message_id", msg.JID).Msg("Failed to redact message from DeleteForMe") + } else if err = msg.Delete(ctx); err != nil { + portal.zlog.Err(err).Str("message_id", msg.JID).Msg("Failed to delete message from database after DeleteForMe") } return true } return false } -func (portal *Portal) sendMainIntentMessage(content *event.MessageEventContent) (*mautrix.RespSendEvent, error) { - return portal.sendMessage(portal.MainIntent(), event.EventMessage, content, nil, 0) +func (portal *Portal) sendMainIntentMessage(ctx context.Context, content *event.MessageEventContent) (*mautrix.RespSendEvent, error) { + return portal.sendMessage(ctx, portal.MainIntent(), event.EventMessage, content, nil, 0) } -func (portal *Portal) encrypt(intent *appservice.IntentAPI, content *event.Content, eventType event.Type) (event.Type, error) { +func (portal *Portal) encrypt(ctx context.Context, intent *appservice.IntentAPI, content *event.Content, eventType event.Type) (event.Type, error) { if !portal.Encrypted || portal.bridge.Crypto == nil { return eventType, nil } @@ -2433,26 +2655,26 @@ func (portal *Portal) encrypt(intent *appservice.IntentAPI, content *event.Conte // TODO maybe the locking should be inside mautrix-go? portal.encryptLock.Lock() defer portal.encryptLock.Unlock() - err := portal.bridge.Crypto.Encrypt(portal.MXID, eventType, content) + err := portal.bridge.Crypto.Encrypt(ctx, portal.MXID, eventType, content) if err != nil { return eventType, fmt.Errorf("failed to encrypt event: %w", err) } return event.EventEncrypted, nil } -func (portal *Portal) sendMessage(intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, timestamp int64) (*mautrix.RespSendEvent, error) { +func (portal *Portal) sendMessage(ctx context.Context, intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, timestamp int64) (*mautrix.RespSendEvent, error) { wrappedContent := event.Content{Parsed: content, Raw: extraContent} var err error - eventType, err = portal.encrypt(intent, &wrappedContent, eventType) + eventType, err = portal.encrypt(ctx, intent, &wrappedContent, eventType) if err != nil { return nil, err } - _, _ = intent.UserTyping(portal.MXID, false, 0) + _, _ = intent.UserTyping(ctx, portal.MXID, false, 0) if timestamp == 0 { - return intent.SendMessageEvent(portal.MXID, eventType, &wrappedContent) + return intent.SendMessageEvent(ctx, portal.MXID, eventType, &wrappedContent) } else { - return intent.SendMassagedMessageEvent(portal.MXID, eventType, &wrappedContent, timestamp) + return intent.SendMassagedMessageEvent(ctx, portal.MXID, eventType, &wrappedContent, timestamp) } } @@ -2531,7 +2753,7 @@ func (cm *ConvertedMessage) MergeCaption() { } cm.Caption = nil } -func (portal *Portal) convertTextMessage(intent *appservice.IntentAPI, source *User, msg *waProto.Message) *ConvertedMessage { +func (portal *Portal) convertTextMessage(ctx context.Context, intent *appservice.IntentAPI, source *User, msg *waProto.Message) *ConvertedMessage { content := &event.MessageEventContent{ Body: msg.GetConversation(), MsgType: event.MsgText, @@ -2541,10 +2763,10 @@ func (portal *Portal) convertTextMessage(intent *appservice.IntentAPI, source *U } contextInfo := msg.GetExtendedTextMessage().GetContextInfo() - portal.bridge.Formatter.ParseWhatsApp(portal.MXID, content, contextInfo.GetMentionedJid(), false, false) + portal.bridge.Formatter.ParseWhatsApp(ctx, portal.MXID, content, contextInfo.GetMentionedJid(), false, false) expiresIn := time.Duration(contextInfo.GetExpiration()) * time.Second extraAttrs := map[string]interface{}{} - extraAttrs["com.beeper.linkpreviews"] = portal.convertURLPreviewToBeeper(intent, source, msg.GetExtendedTextMessage()) + extraAttrs["com.beeper.linkpreviews"] = portal.convertURLPreviewToBeeper(ctx, intent, source, msg.GetExtendedTextMessage()) return &ConvertedMessage{ Intent: intent, @@ -2556,7 +2778,7 @@ func (portal *Portal) convertTextMessage(intent *appservice.IntentAPI, source *U } } -func (portal *Portal) convertTemplateMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, tplMsg *waProto.TemplateMessage) *ConvertedMessage { +func (portal *Portal) convertTemplateMessage(ctx context.Context, intent *appservice.IntentAPI, source *User, info *types.MessageInfo, tplMsg *waProto.TemplateMessage) *ConvertedMessage { converted := &ConvertedMessage{ Intent: intent, Type: event.EventMessage, @@ -2600,11 +2822,11 @@ func (portal *Portal) convertTemplateMessage(intent *appservice.IntentAPI, sourc var convertedTitle *ConvertedMessage switch title := tpl.GetTitle().(type) { case *waProto.TemplateMessage_HydratedFourRowTemplate_DocumentMessage: - convertedTitle = portal.convertMediaMessage(intent, source, info, title.DocumentMessage, "file attachment", false) + convertedTitle = portal.convertMediaMessage(ctx, intent, source, info, title.DocumentMessage, "file attachment", false) case *waProto.TemplateMessage_HydratedFourRowTemplate_ImageMessage: - convertedTitle = portal.convertMediaMessage(intent, source, info, title.ImageMessage, "photo", false) + convertedTitle = portal.convertMediaMessage(ctx, intent, source, info, title.ImageMessage, "photo", false) case *waProto.TemplateMessage_HydratedFourRowTemplate_VideoMessage: - convertedTitle = portal.convertMediaMessage(intent, source, info, title.VideoMessage, "video attachment", false) + convertedTitle = portal.convertMediaMessage(ctx, intent, source, info, title.VideoMessage, "video attachment", false) case *waProto.TemplateMessage_HydratedFourRowTemplate_LocationMessage: content = fmt.Sprintf("Unsupported location message\n\n%s", content) case *waProto.TemplateMessage_HydratedFourRowTemplate_HydratedTitleText: @@ -2612,7 +2834,7 @@ func (portal *Portal) convertTemplateMessage(intent *appservice.IntentAPI, sourc } converted.Content.Body = content - portal.bridge.Formatter.ParseWhatsApp(portal.MXID, converted.Content, nil, true, false) + portal.bridge.Formatter.ParseWhatsApp(ctx, portal.MXID, converted.Content, nil, true, false) if convertedTitle != nil { converted.MediaKey = convertedTitle.MediaKey converted.Extra = convertedTitle.Extra @@ -2627,7 +2849,7 @@ func (portal *Portal) convertTemplateMessage(intent *appservice.IntentAPI, sourc return converted } -func (portal *Portal) convertTemplateButtonReplyMessage(intent *appservice.IntentAPI, msg *waProto.TemplateButtonReplyMessage) *ConvertedMessage { +func (portal *Portal) convertTemplateButtonReplyMessage(ctx context.Context, intent *appservice.IntentAPI, msg *waProto.TemplateButtonReplyMessage) *ConvertedMessage { return &ConvertedMessage{ Intent: intent, Type: event.EventMessage, @@ -2646,7 +2868,7 @@ func (portal *Portal) convertTemplateButtonReplyMessage(intent *appservice.Inten } } -func (portal *Portal) convertListMessage(intent *appservice.IntentAPI, source *User, msg *waProto.ListMessage) *ConvertedMessage { +func (portal *Portal) convertListMessage(ctx context.Context, intent *appservice.IntentAPI, source *User, msg *waProto.ListMessage) *ConvertedMessage { converted := &ConvertedMessage{ Intent: intent, Type: event.EventMessage, @@ -2671,7 +2893,7 @@ func (portal *Portal) convertListMessage(intent *appservice.IntentAPI, source *U body = fmt.Sprintf("%s\n\n%s", body, msg.GetFooterText()) } converted.Content.Body = body - portal.bridge.Formatter.ParseWhatsApp(portal.MXID, converted.Content, nil, false, true) + portal.bridge.Formatter.ParseWhatsApp(ctx, portal.MXID, converted.Content, nil, false, true) var optionsMarkdown strings.Builder _, _ = fmt.Fprintf(&optionsMarkdown, "#### %s\n", msg.GetButtonText()) @@ -2696,7 +2918,7 @@ func (portal *Portal) convertListMessage(intent *appservice.IntentAPI, source *U return converted } -func (portal *Portal) convertListResponseMessage(intent *appservice.IntentAPI, msg *waProto.ListResponseMessage) *ConvertedMessage { +func (portal *Portal) convertListResponseMessage(ctx context.Context, intent *appservice.IntentAPI, msg *waProto.ListResponseMessage) *ConvertedMessage { var body string if msg.GetTitle() != "" { if msg.GetDescription() != "" { @@ -2726,10 +2948,16 @@ func (portal *Portal) convertListResponseMessage(intent *appservice.IntentAPI, m } } -func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg *waProto.PollUpdateMessage) *ConvertedMessage { - pollMessage := portal.bridge.DB.Message.GetByJID(portal.Key, msg.GetPollCreationMessageKey().GetId()) - if pollMessage == nil { - portal.log.Warnfln("Failed to convert vote message %s: poll message %s not found", info.ID, msg.GetPollCreationMessageKey().GetId()) +func (portal *Portal) convertPollUpdateMessage(ctx context.Context, intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg *waProto.PollUpdateMessage) *ConvertedMessage { + log := zerolog.Ctx(ctx).With(). + Str("poll_id", msg.GetPollCreationMessageKey().GetId()). + Logger() + pollMessage, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, msg.GetPollCreationMessageKey().GetId()) + if err != nil { + log.Err(err).Msg("Failed to get poll message to convert vote") + return nil + } else if pollMessage == nil { + log.Warn().Msg("Poll message not found for converting vote message") return nil } vote, err := source.Client.DecryptPollVote(&events.Message{ @@ -2737,21 +2965,25 @@ func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, sou Message: &waProto.Message{PollUpdateMessage: msg}, }) if err != nil { - portal.log.Errorfln("Failed to decrypt vote message %s: %v", info.ID, err) + log.Err(err).Msg("Failed to decrypt vote message") return nil } selectedHashes := make([]string, len(vote.GetSelectedOptions())) if pollMessage.Type == database.MsgMatrixPoll { - mappedAnswers := pollMessage.GetPollOptionIDs(vote.GetSelectedOptions()) + mappedAnswers, err := pollMessage.GetPollOptionIDs(ctx, vote.GetSelectedOptions()) + if err != nil { + log.Err(err).Msg("Failed to get poll option IDs") + return nil + } for i, opt := range vote.GetSelectedOptions() { if len(opt) != 32 { - portal.log.Warnfln("Unexpected option hash length %d in %s's vote to %s", len(opt), info.Sender, pollMessage.MXID) + log.Warn().Int("hash_len", len(opt)).Msg("Unexpected option hash length in vote") continue } var ok bool - selectedHashes[i], ok = mappedAnswers[*(*[32]byte)(opt)] + selectedHashes[i], ok = mappedAnswers[[32]byte(opt)] if !ok { - portal.log.Warnfln("Didn't find ID for option %X in %s's vote to %s", opt, info.Sender, pollMessage.MXID) + log.Warn().Hex("option_hash", opt).Msg("Didn't find ID for option in vote") } } } else { @@ -2777,12 +3009,12 @@ func (portal *Portal) convertPollUpdateMessage(intent *appservice.IntentAPI, sou "org.matrix.msc3381.poll.response": map[string]any{ "answers": selectedHashes, }, - "org.matrix.msc3381.v2.selections": selectedHashes, + //"org.matrix.msc3381.v2.selections": selectedHashes, }, } } -func (portal *Portal) convertPollCreationMessage(intent *appservice.IntentAPI, msg *waProto.PollCreationMessage) *ConvertedMessage { +func (portal *Portal) convertPollCreationMessage(ctx context.Context, intent *appservice.IntentAPI, msg *waProto.PollCreationMessage) *ConvertedMessage { optionNames := make([]string, len(msg.GetOptions())) optionsListText := make([]string, len(optionNames)) optionsListHTML := make([]string, len(optionNames)) @@ -2834,23 +3066,23 @@ func (portal *Portal) convertPollCreationMessage(intent *appservice.IntentAPI, m "selectable_options_count": msg.GetSelectableOptionsCount(), }, - // Current extensible events (as of November 2022) - "org.matrix.msc1767.markup": []map[string]any{ - {"mimetype": "text/html", "body": formattedBody}, - {"mimetype": "text/plain", "body": body}, - }, - "org.matrix.msc3381.v2.poll": map[string]any{ - "kind": "org.matrix.msc3381.v2.disclosed", - "max_selections": maxChoices, - "question": map[string]any{ - "org.matrix.msc1767.markup": []map[string]any{ - {"mimetype": "text/plain", "body": msg.GetName()}, - }, - }, - "answers": msc3381V2Answers, - }, + // Slightly less extensible events (November 2022) + //"org.matrix.msc1767.markup": []map[string]any{ + // {"mimetype": "text/html", "body": formattedBody}, + // {"mimetype": "text/plain", "body": body}, + //}, + //"org.matrix.msc3381.v2.poll": map[string]any{ + // "kind": "org.matrix.msc3381.v2.disclosed", + // "max_selections": maxChoices, + // "question": map[string]any{ + // "org.matrix.msc1767.markup": []map[string]any{ + // {"mimetype": "text/plain", "body": msg.GetName()}, + // }, + // }, + // "answers": msc3381V2Answers, + //}, - // Legacy extensible events + // Legacyest extensible events "org.matrix.msc1767.message": []map[string]any{ {"mimetype": "text/html", "body": formattedBody}, {"mimetype": "text/plain", "body": body}, @@ -2869,7 +3101,7 @@ func (portal *Portal) convertPollCreationMessage(intent *appservice.IntentAPI, m } } -func (portal *Portal) convertLiveLocationMessage(intent *appservice.IntentAPI, msg *waProto.LiveLocationMessage) *ConvertedMessage { +func (portal *Portal) convertLiveLocationMessage(ctx context.Context, intent *appservice.IntentAPI, msg *waProto.LiveLocationMessage) *ConvertedMessage { content := &event.MessageEventContent{ Body: "Started sharing live location", MsgType: event.MsgNotice, @@ -2886,7 +3118,7 @@ func (portal *Portal) convertLiveLocationMessage(intent *appservice.IntentAPI, m } } -func (portal *Portal) convertLocationMessage(intent *appservice.IntentAPI, msg *waProto.LocationMessage) *ConvertedMessage { +func (portal *Portal) convertLocationMessage(ctx context.Context, intent *appservice.IntentAPI, msg *waProto.LocationMessage) *ConvertedMessage { url := msg.GetUrl() if len(url) == 0 { url = fmt.Sprintf("https://maps.google.com/?q=%.5f,%.5f", msg.GetDegreesLatitude(), msg.GetDegreesLongitude()) @@ -2914,7 +3146,7 @@ func (portal *Portal) convertLocationMessage(intent *appservice.IntentAPI, msg * if len(msg.GetJpegThumbnail()) > 0 { thumbnailMime := http.DetectContentType(msg.GetJpegThumbnail()) - uploadedThumbnail, _ := intent.UploadBytes(msg.GetJpegThumbnail(), thumbnailMime) + uploadedThumbnail, _ := intent.UploadBytes(ctx, msg.GetJpegThumbnail(), thumbnailMime) if uploadedThumbnail != nil { cfg, _, _ := image.DecodeConfig(bytes.NewReader(msg.GetJpegThumbnail())) content.Info = &event.FileInfo{ @@ -2939,6 +3171,7 @@ func (portal *Portal) convertLocationMessage(intent *appservice.IntentAPI, msg * } const inviteMsg = `%s
This invitation to join "%s" expires at %s. Reply to this message with !wa accept to accept the invite.` +const inviteMsgBroken = `%s
This invitation to join "%s" expires at %s. However, the invite message is broken or unsupported and cannot be accepted.` const inviteMetaField = "fi.mau.whatsapp.invite" const escapedInviteMetaField = `fi\.mau\.whatsapp\.invite` @@ -2949,27 +3182,32 @@ type InviteMeta struct { Inviter types.JID `json:"inviter"` } -func (portal *Portal) convertGroupInviteMessage(intent *appservice.IntentAPI, info *types.MessageInfo, msg *waProto.GroupInviteMessage) *ConvertedMessage { +func (portal *Portal) convertGroupInviteMessage(ctx context.Context, intent *appservice.IntentAPI, info *types.MessageInfo, msg *waProto.GroupInviteMessage) *ConvertedMessage { expiry := time.Unix(msg.GetInviteExpiration(), 0) - htmlMessage := fmt.Sprintf(inviteMsg, event.TextToHTML(msg.GetCaption()), msg.GetGroupName(), expiry) + template := inviteMsg + var extraAttrs map[string]any + groupJID, err := types.ParseJID(msg.GetGroupJid()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("invite_group_jid", msg.GetGroupJid()).Msg("Failed to parse invite group JID") + template = inviteMsgBroken + } else { + extraAttrs = map[string]interface{}{ + inviteMetaField: InviteMeta{ + JID: groupJID, + Code: msg.GetInviteCode(), + Expiration: msg.GetInviteExpiration(), + Inviter: info.Sender.ToNonAD(), + }, + } + } + + htmlMessage := fmt.Sprintf(template, event.TextToHTML(msg.GetCaption()), msg.GetGroupName(), expiry) content := &event.MessageEventContent{ MsgType: event.MsgText, Body: format.HTMLToText(htmlMessage), Format: event.FormatHTML, FormattedBody: htmlMessage, } - groupJID, err := types.ParseJID(msg.GetGroupJid()) - if err != nil { - portal.log.Errorfln("Failed to parse invite group JID: %v", err) - } - extraAttrs := map[string]interface{}{ - inviteMetaField: InviteMeta{ - JID: groupJID, - Code: msg.GetInviteCode(), - Expiration: msg.GetInviteExpiration(), - Inviter: info.Sender.ToNonAD(), - }, - } return &ConvertedMessage{ Intent: intent, Type: event.EventMessage, @@ -2980,15 +3218,15 @@ func (portal *Portal) convertGroupInviteMessage(intent *appservice.IntentAPI, in } } -func (portal *Portal) convertContactMessage(intent *appservice.IntentAPI, msg *waProto.ContactMessage) *ConvertedMessage { +func (portal *Portal) convertContactMessage(ctx context.Context, intent *appservice.IntentAPI, msg *waProto.ContactMessage) *ConvertedMessage { fileName := fmt.Sprintf("%s.vcf", msg.GetDisplayName()) data := []byte(msg.GetVcard()) mimeType := "text/vcard" uploadMimeType, file := portal.encryptFileInPlace(data, mimeType) - uploadResp, err := intent.UploadBytesWithName(data, uploadMimeType, fileName) + uploadResp, err := intent.UploadBytesWithName(ctx, data, uploadMimeType, fileName) if err != nil { - portal.log.Errorfln("Failed to upload vcard of %s: %v", msg.GetDisplayName(), err) + zerolog.Ctx(ctx).Err(err).Str("displayname", msg.GetDisplayName()).Msg("Failed to upload vcard") return nil } @@ -3016,14 +3254,14 @@ func (portal *Portal) convertContactMessage(intent *appservice.IntentAPI, msg *w } } -func (portal *Portal) convertContactsArrayMessage(intent *appservice.IntentAPI, msg *waProto.ContactsArrayMessage) *ConvertedMessage { +func (portal *Portal) convertContactsArrayMessage(ctx context.Context, intent *appservice.IntentAPI, msg *waProto.ContactsArrayMessage) *ConvertedMessage { name := msg.GetDisplayName() if len(name) == 0 { name = fmt.Sprintf("%d contacts", len(msg.GetContacts())) } contacts := make([]*event.MessageEventContent, 0, len(msg.GetContacts())) for _, contact := range msg.GetContacts() { - converted := portal.convertContactMessage(intent, contact) + converted := portal.convertContactMessage(ctx, intent, contact) if converted != nil { contacts = append(contacts, converted.Content) } @@ -3041,37 +3279,34 @@ func (portal *Portal) convertContactsArrayMessage(intent *appservice.IntentAPI, } } -func (portal *Portal) tryKickUser(userID id.UserID, intent *appservice.IntentAPI) error { - _, err := intent.KickUser(portal.MXID, &mautrix.ReqKickUser{UserID: userID}) - if err != nil { - httpErr, ok := err.(mautrix.HTTPError) - if ok && httpErr.RespError != nil && httpErr.RespError.ErrCode == "M_FORBIDDEN" { - _, err = portal.MainIntent().KickUser(portal.MXID, &mautrix.ReqKickUser{UserID: userID}) - } +func (portal *Portal) tryKickUser(ctx context.Context, userID id.UserID, intent *appservice.IntentAPI) error { + _, err := intent.KickUser(ctx, portal.MXID, &mautrix.ReqKickUser{UserID: userID}) + if errors.Is(err, mautrix.MForbidden) { + _, err = portal.MainIntent().KickUser(ctx, portal.MXID, &mautrix.ReqKickUser{UserID: userID}) } return err } -func (portal *Portal) removeUser(isSameUser bool, kicker *appservice.IntentAPI, target id.UserID, targetIntent *appservice.IntentAPI) { +func (portal *Portal) removeUser(ctx context.Context, isSameUser bool, kicker *appservice.IntentAPI, target id.UserID, targetIntent *appservice.IntentAPI) { if !isSameUser || targetIntent == nil { - err := portal.tryKickUser(target, kicker) + err := portal.tryKickUser(ctx, target, kicker) if err != nil { - portal.log.Warnfln("Failed to kick %s from %s: %v", target, portal.MXID, err) + zerolog.Ctx(ctx).Warn().Err(err).Stringer("target_mxid", target).Msg("Failed to kick user from portal") if targetIntent != nil { - _, _ = targetIntent.LeaveRoom(portal.MXID) + _, _ = targetIntent.LeaveRoom(ctx, portal.MXID) } } } else { - _, err := targetIntent.LeaveRoom(portal.MXID) + _, err := targetIntent.LeaveRoom(ctx, portal.MXID) if err != nil { - portal.log.Warnfln("Failed to leave portal as %s: %v", target, err) - _, _ = portal.MainIntent().KickUser(portal.MXID, &mautrix.ReqKickUser{UserID: target}) + zerolog.Ctx(ctx).Warn().Err(err).Stringer("target_mxid", target).Msg("Failed to leave portal as user") + _, _ = portal.MainIntent().KickUser(ctx, portal.MXID, &mautrix.ReqKickUser{UserID: target}) } } - portal.CleanupIfEmpty() + portal.CleanupIfEmpty(ctx) } -func (portal *Portal) HandleWhatsAppKick(source *User, senderJID types.JID, jids []types.JID) { +func (portal *Portal) HandleWhatsAppKick(ctx context.Context, source *User, senderJID types.JID, jids []types.JID) { sender := portal.bridge.GetPuppetByJID(senderJID) senderIntent := sender.IntentFor(portal) for _, jid := range jids { @@ -3084,7 +3319,7 @@ func (portal *Portal) HandleWhatsAppKick(source *User, senderJID types.JID, jids // continue //} puppet := portal.bridge.GetPuppetByJID(jid) - portal.removeUser(puppet.JID == sender.JID, senderIntent, puppet.MXID, puppet.DefaultIntent()) + portal.removeUser(ctx, puppet.JID == sender.JID, senderIntent, puppet.MXID, puppet.DefaultIntent()) if !portal.IsBroadcastList() { user := portal.bridge.GetUserByJID(jid) @@ -3093,13 +3328,13 @@ func (portal *Portal) HandleWhatsAppKick(source *User, senderJID types.JID, jids if puppet.CustomMXID == user.MXID { customIntent = puppet.CustomIntent() } - portal.removeUser(puppet.JID == sender.JID, senderIntent, user.MXID, customIntent) + portal.removeUser(ctx, puppet.JID == sender.JID, senderIntent, user.MXID, customIntent) } } } } -func (portal *Portal) HandleWhatsAppInvite(source *User, senderJID *types.JID, jids []types.JID) (evtID id.EventID) { +func (portal *Portal) HandleWhatsAppInvite(ctx context.Context, source *User, senderJID *types.JID, jids []types.JID) (evtID id.EventID) { intent := portal.MainIntent() if senderJID != nil && !senderJID.IsEmpty() { sender := portal.bridge.GetPuppetByJID(*senderJID) @@ -3111,42 +3346,47 @@ func (portal *Portal) HandleWhatsAppInvite(source *User, senderJID *types.JID, j continue } puppet := portal.bridge.GetPuppetByJID(jid) - puppet.SyncContact(source, true, false, "handling whatsapp invite") - resp, err := intent.SendStateEvent(portal.MXID, event.StateMember, puppet.MXID.String(), &event.MemberEventContent{ + puppet.SyncContact(ctx, source, true, false, "handling whatsapp invite") + resp, err := intent.SendStateEvent(ctx, portal.MXID, event.StateMember, puppet.MXID.String(), &event.MemberEventContent{ Membership: event.MembershipInvite, Displayname: puppet.Displayname, AvatarURL: puppet.AvatarURL.CUString(), }) if err != nil { - portal.log.Warnfln("Failed to invite %s as %s: %v", puppet.MXID, intent.UserID, err) - _ = portal.MainIntent().EnsureInvited(portal.MXID, puppet.MXID) + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("target_mxid", puppet.MXID). + Stringer("inviter_mxid", intent.UserID). + Msg("Failed to invite user") + _ = portal.MainIntent().EnsureInvited(ctx, portal.MXID, puppet.MXID) } else { evtID = resp.EventID } - err = puppet.DefaultIntent().EnsureJoined(portal.MXID) + err = puppet.DefaultIntent().EnsureJoined(ctx, portal.MXID) if err != nil { - portal.log.Errorfln("Failed to ensure %s is joined: %v", puppet.MXID, err) + zerolog.Ctx(ctx).Err(err). + Stringer("target_mxid", puppet.MXID). + Msg("Failed to ensure user is joined to portal") } } return } -func (portal *Portal) HandleWhatsAppDeleteChat(user *User) { +func (portal *Portal) HandleWhatsAppDeleteChat(ctx context.Context, user *User) { if portal.MXID == "" { return } - matrixUsers, err := portal.GetMatrixUsers() + matrixUsers, err := portal.GetMatrixUsers(ctx) if err != nil { - portal.log.Errorln("Failed to get Matrix users to see if DeleteChat should be handled:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to get Matrix users to see if DeleteChat should be handled") return } if len(matrixUsers) > 1 { - portal.log.Infoln("Portal contains more than one Matrix user, so deleteChat will not be handled.") + zerolog.Ctx(ctx).Debug().Msg("Portal contains more than one Matrix user, ignoring DeleteChat event") return } else if (len(matrixUsers) == 1 && matrixUsers[0] == user.MXID) || len(matrixUsers) < 1 { - portal.log.Debugln("User deleted chat and there are no other Matrix users using it, deleting portal...") - portal.Delete() - portal.Cleanup(false) + zerolog.Ctx(ctx).Debug().Msg("User deleted chat and there are no other Matrix users, deleting portal...") + portal.Delete(ctx) + portal.Cleanup(ctx, false) } } @@ -3167,19 +3407,11 @@ type FailedMediaMeta struct { Media FailedMediaKeys `json:"whatsapp_media"` } -func shallowCopyMap(data map[string]interface{}) map[string]interface{} { - newMap := make(map[string]interface{}, len(data)) - for key, value := range data { - newMap[key] = value - } - return newMap -} - func (portal *Portal) makeMediaBridgeFailureMessage(info *types.MessageInfo, bridgeErr error, converted *ConvertedMessage, keys *FailedMediaKeys, userFriendlyError string) *ConvertedMessage { if errors.Is(bridgeErr, whatsmeow.ErrMediaDownloadFailedWith403) || errors.Is(bridgeErr, whatsmeow.ErrMediaDownloadFailedWith404) || errors.Is(bridgeErr, whatsmeow.ErrMediaDownloadFailedWith410) { - portal.log.Debugfln("Failed to bridge media for %s: %v", info.ID, bridgeErr) + portal.zlog.Debug().Err(bridgeErr).Str("message_id", info.ID).Msg("Failed to bridge media for message") } else { - portal.log.Errorfln("Failed to bridge media for %s: %v", info.ID, bridgeErr) + portal.zlog.Err(bridgeErr).Str("message_id", info.ID).Msg("Failed to bridge media for message") } if keys != nil { if portal.bridge.Config.Bridge.CaptionInMessage { @@ -3188,7 +3420,7 @@ func (portal *Portal) makeMediaBridgeFailureMessage(info *types.MessageInfo, bri meta := &FailedMediaMeta{ Type: converted.Type, Content: converted.Content, - ExtraContent: shallowCopyMap(converted.Extra), + ExtraContent: maps.Clone(converted.Extra), Media: *keys, } converted.Extra[failedMediaField] = meta @@ -3254,7 +3486,7 @@ type MediaMessageWithDuration interface { const WhatsAppStickerSize = 190 -func (portal *Portal) convertMediaMessageContent(intent *appservice.IntentAPI, msg MediaMessage) *ConvertedMessage { +func (portal *Portal) convertMediaMessageContent(ctx context.Context, intent *appservice.IntentAPI, msg MediaMessage) *ConvertedMessage { content := &event.MessageEventContent{ Info: &event.FileInfo{ MimeType: msg.GetMimetype(), @@ -3309,9 +3541,9 @@ func (portal *Portal) convertMediaMessageContent(intent *appservice.IntentAPI, m thumbnailCfg, _, _ := image.DecodeConfig(bytes.NewReader(thumbnailData)) thumbnailSize := len(thumbnailData) thumbnailUploadMime, thumbnailFile := portal.encryptFileInPlace(thumbnailData, thumbnailMime) - uploadedThumbnail, err := intent.UploadBytes(thumbnailData, thumbnailUploadMime) + uploadedThumbnail, err := intent.UploadBytes(ctx, thumbnailData, thumbnailUploadMime) if err != nil { - portal.log.Warnfln("Failed to upload thumbnail: %v", err) + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to upload thumbnail") } else if uploadedThumbnail != nil { if thumbnailFile != nil { thumbnailFile.URL = uploadedThumbnail.ContentURI.CUString() @@ -3351,7 +3583,7 @@ func (portal *Portal) convertMediaMessageContent(intent *appservice.IntentAPI, m case *waProto.DocumentMessage: content.MsgType = event.MsgFile default: - portal.log.Warnfln("Unexpected media type %T in convertMediaMessageContent", msg) + zerolog.Ctx(ctx).Warn().Type("content_struct", msg).Msg("Unexpected media type in convertMediaMessageContent") content.MsgType = event.MsgFile } @@ -3360,16 +3592,16 @@ func (portal *Portal) convertMediaMessageContent(intent *appservice.IntentAPI, m var waveform []int if audioMessage.Waveform != nil { waveform = make([]int, len(audioMessage.Waveform)) - max := 0 + maxWave := 0 for i, part := range audioMessage.Waveform { waveform[i] = int(part) - if waveform[i] > max { - max = waveform[i] + if waveform[i] > maxWave { + maxWave = waveform[i] } } multiplier := 0 - if max > 0 { - multiplier = 1024 / max + if maxWave > 0 { + multiplier = 1024 / maxWave } if multiplier > 32 { multiplier = 32 @@ -3395,7 +3627,7 @@ func (portal *Portal) convertMediaMessageContent(intent *appservice.IntentAPI, m MsgType: event.MsgNotice, } - portal.bridge.Formatter.ParseWhatsApp(portal.MXID, captionContent, msg.GetContextInfo().GetMentionedJid(), false, false) + portal.bridge.Formatter.ParseWhatsApp(ctx, portal.MXID, captionContent, msg.GetContextInfo().GetMentionedJid(), false, false) } return &ConvertedMessage{ @@ -3409,7 +3641,7 @@ func (portal *Portal) convertMediaMessageContent(intent *appservice.IntentAPI, m } } -func (portal *Portal) uploadMedia(intent *appservice.IntentAPI, data []byte, content *event.MessageEventContent) error { +func (portal *Portal) uploadMedia(ctx context.Context, intent *appservice.IntentAPI, data []byte, content *event.MessageEventContent) error { uploadMimeType, file := portal.encryptFileInPlace(data, content.Info.MimeType) req := mautrix.ReqUploadMedia{ @@ -3418,13 +3650,13 @@ func (portal *Portal) uploadMedia(intent *appservice.IntentAPI, data []byte, con } var mxc id.ContentURI if portal.bridge.Config.Homeserver.AsyncMedia { - uploaded, err := intent.UploadAsync(req) + uploaded, err := intent.UploadAsync(ctx, req) if err != nil { return err } mxc = uploaded.ContentURI } else { - uploaded, err := intent.UploadMedia(req) + uploaded, err := intent.UploadMedia(ctx, req) if err != nil { return err } @@ -3457,8 +3689,8 @@ func (portal *Portal) uploadMedia(intent *appservice.IntentAPI, data []byte, con return nil } -func (portal *Portal) convertMediaMessage(intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg MediaMessage, typeName string, isBackfill bool) *ConvertedMessage { - converted := portal.convertMediaMessageContent(intent, msg) +func (portal *Portal) convertMediaMessage(ctx context.Context, intent *appservice.IntentAPI, source *User, info *types.MessageInfo, msg MediaMessage, typeName string, isBackfill bool) *ConvertedMessage { + converted := portal.convertMediaMessageContent(ctx, intent, msg) if msg.GetFileLength() > uint64(portal.bridge.MediaConfig.UploadSize) { return portal.makeMediaBridgeFailureMessage(info, errors.New("file is too large"), converted, nil, fmt.Sprintf("Large %s not bridged - please use WhatsApp app to view", typeName)) } @@ -3482,19 +3714,19 @@ func (portal *Portal) convertMediaMessage(intent *appservice.IntentAPI, source * EncSHA256: msg.GetFileEncSha256(), }, errorText) } else if errors.Is(err, whatsmeow.ErrNoURLPresent) { - portal.log.Debugfln("No URL present error for media message %s, ignoring...", info.ID) + zerolog.Ctx(ctx).Debug().Msg("No URL present error for media message, ignoring...") return nil } else if errors.Is(err, whatsmeow.ErrFileLengthMismatch) || errors.Is(err, whatsmeow.ErrInvalidMediaSHA256) { - portal.log.Warnfln("Mismatching media checksums in %s: %v. Ignoring because WhatsApp seems to ignore them too", info.ID, err) + zerolog.Ctx(ctx).Warn().Err(err).Msg("Mismatching media checksums in message. Ignoring because WhatsApp seems to ignore them too") } else if err != nil { return portal.makeMediaBridgeFailureMessage(info, err, converted, nil, "") } - err = portal.uploadMedia(intent, data, converted.Content) + err = portal.uploadMedia(ctx, intent, data, converted.Content) if err != nil { if errors.Is(err, mautrix.MTooLarge) { return portal.makeMediaBridgeFailureMessage(info, errors.New("homeserver rejected too large file"), converted, nil, "") - } else if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.IsStatus(413) { + } else if httpErr := (mautrix.HTTPError{}); errors.As(err, &httpErr) && httpErr.IsStatus(413) { return portal.makeMediaBridgeFailureMessage(info, errors.New("proxy rejected too large file"), converted, nil, "") } else { return portal.makeMediaBridgeFailureMessage(info, fmt.Errorf("failed to upload media: %w", err), converted, nil, "") @@ -3503,12 +3735,12 @@ func (portal *Portal) convertMediaMessage(intent *appservice.IntentAPI, source * return converted } -func (portal *Portal) fetchMediaRetryEvent(msg *database.Message) (*FailedMediaMeta, error) { +func (portal *Portal) fetchMediaRetryEvent(ctx context.Context, msg *database.Message) (*FailedMediaMeta, error) { errorMeta, ok := portal.mediaErrorCache[msg.JID] if ok { return errorMeta, nil } - evt, err := portal.MainIntent().GetEvent(portal.MXID, msg.MXID) + evt, err := portal.MainIntent().GetEvent(ctx, portal.MXID, msg.MXID) if err != nil { return nil, fmt.Errorf("failed to fetch event %s: %w", msg.MXID, err) } @@ -3517,7 +3749,7 @@ func (portal *Portal) fetchMediaRetryEvent(msg *database.Message) (*FailedMediaM if err != nil { return nil, fmt.Errorf("failed to parse encrypted content in %s: %w", msg.MXID, err) } - evt, err = portal.bridge.Crypto.Decrypt(evt) + evt, err = portal.bridge.Crypto.Decrypt(ctx, evt) if err != nil { return nil, fmt.Errorf("failed to decrypt event %s: %w", msg.MXID, err) } @@ -3539,7 +3771,7 @@ func (portal *Portal) fetchMediaRetryEvent(msg *database.Message) (*FailedMediaM return errorMeta, nil } -func (portal *Portal) sendMediaRetryFailureEdit(intent *appservice.IntentAPI, msg *database.Message, err error) { +func (portal *Portal) sendMediaRetryFailureEdit(ctx context.Context, intent *appservice.IntentAPI, msg *database.Message, err error) { content := event.MessageEventContent{ MsgType: event.MsgNotice, Body: fmt.Sprintf("Failed to bridge media after re-requesting it from your phone: %v", err), @@ -3550,28 +3782,37 @@ func (portal *Portal) sendMediaRetryFailureEdit(intent *appservice.IntentAPI, ms EventID: msg.MXID, Type: event.RelReplace, } - resp, sendErr := portal.sendMessage(intent, event.EventMessage, &content, nil, time.Now().UnixMilli()) + resp, sendErr := portal.sendMessage(ctx, intent, event.EventMessage, &content, nil, time.Now().UnixMilli()) if sendErr != nil { - portal.log.Warnfln("Failed to edit %s after retry failure for %s: %v", msg.MXID, msg.JID, sendErr) + zerolog.Ctx(ctx).Err(err).Msg("Failed to edit message after media retry failure") } else { - portal.log.Debugfln("Successfully edited %s -> %s after retry failure for %s", msg.MXID, resp.EventID, msg.JID) + zerolog.Ctx(ctx).Debug().Stringer("edit_mxid", resp.EventID). + Msg("Successfully edited message after media retry failure") } - } func (portal *Portal) handleMediaRetry(retry *events.MediaRetry, source *User) { - msg := portal.bridge.DB.Message.GetByJID(portal.Key, retry.MessageID) + log := portal.zlog.With(). + Str("action", "handle media retry"). + Str("retry_message_id", retry.MessageID). + Logger() + ctx := log.WithContext(context.TODO()) + msg, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, retry.MessageID) if msg == nil { - portal.log.Warnfln("Dropping media retry notification for unknown message %s", retry.MessageID) + log.Warn().Msg("Dropping media retry notification for unknown message") return - } else if msg.Error != database.MsgErrMediaNotFound { - portal.log.Warnfln("Dropping media retry notification for non-errored message %s / %s", retry.MessageID, msg.MXID) + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("retry_message_mxid", msg.MXID) + }) + if msg.Error != database.MsgErrMediaNotFound { + log.Warn().Msg("Dropping media retry notification for non-errored message") return } - meta, err := portal.fetchMediaRetryEvent(msg) + meta, err := portal.fetchMediaRetryEvent(ctx, msg) if err != nil { - portal.log.Warnfln("Can't handle media retry notification for %s: %v", retry.MessageID, err) + log.Warn().Err(err).Msg("Can't handle media retry notification for message") return } @@ -3591,35 +3832,35 @@ func (portal *Portal) handleMediaRetry(retry *events.MediaRetry, source *User) { retryData, err := whatsmeow.DecryptMediaRetryNotification(retry, meta.Media.Key) if err != nil { - portal.log.Warnfln("Failed to handle media retry notification for %s: %v", retry.MessageID, err) - portal.sendMediaRetryFailureEdit(intent, msg, err) + log.Warn().Err(err).Msg("Failed to decrypt media retry notification") + portal.sendMediaRetryFailureEdit(ctx, intent, msg, err) return } else if retryData.GetResult() != waProto.MediaRetryNotification_SUCCESS { errorName := waProto.MediaRetryNotification_ResultType_name[int32(retryData.GetResult())] if retryData.GetDirectPath() == "" { - portal.log.Warnfln("Got error response in media retry notification for %s: %s", retry.MessageID, errorName) - portal.log.Debugfln("Error response contents: %+v", retryData) + log.Warn().Str("error_name", errorName).Msg("Got error response in media retry notification") + log.Debug().Any("error_content", retryData).Msg("Full error response content") if retryData.GetResult() == waProto.MediaRetryNotification_NOT_FOUND { - portal.sendMediaRetryFailureEdit(intent, msg, whatsmeow.ErrMediaNotAvailableOnPhone) + portal.sendMediaRetryFailureEdit(ctx, intent, msg, whatsmeow.ErrMediaNotAvailableOnPhone) } else { - portal.sendMediaRetryFailureEdit(intent, msg, fmt.Errorf("phone sent error response: %s", errorName)) + portal.sendMediaRetryFailureEdit(ctx, intent, msg, fmt.Errorf("phone sent error response: %s", errorName)) } return } else { - portal.log.Debugfln("Got error response %s in media retry notification for %s, but response also contains a new download URL - trying to download", retry.MessageID, errorName) + log.Debug().Msg("Got error response in media retry notification, but response also contains a new download URL - trying to download") } } data, err := source.Client.DownloadMediaWithPath(retryData.GetDirectPath(), meta.Media.EncSHA256, meta.Media.SHA256, meta.Media.Key, meta.Media.Length, meta.Media.Type, "") if err != nil { - portal.log.Warnfln("Failed to download media in %s after retry notification: %v", retry.MessageID, err) - portal.sendMediaRetryFailureEdit(intent, msg, err) + log.Warn().Err(err).Msg("Failed to download media after retry notification") + portal.sendMediaRetryFailureEdit(ctx, intent, msg, err) return } - err = portal.uploadMedia(intent, data, meta.Content) + err = portal.uploadMedia(ctx, intent, data, meta.Content) if err != nil { - portal.log.Warnfln("Failed to re-upload media for %s after retry notification: %v", retry.MessageID, err) - portal.sendMediaRetryFailureEdit(intent, msg, fmt.Errorf("re-uploading media failed: %v", err)) + log.Err(err).Msg("Failed to re-upload media after retry notification") + portal.sendMediaRetryFailureEdit(ctx, intent, msg, fmt.Errorf("re-uploading media failed: %v", err)) return } replaceContent := &event.MessageEventContent{ @@ -3633,40 +3874,49 @@ func (portal *Portal) handleMediaRetry(retry *events.MediaRetry, source *User) { } // Move the extra content into m.new_content too meta.ExtraContent = map[string]interface{}{ - "m.new_content": shallowCopyMap(meta.ExtraContent), + "m.new_content": maps.Clone(meta.ExtraContent), } - resp, err := portal.sendMessage(intent, meta.Type, replaceContent, meta.ExtraContent, time.Now().UnixMilli()) + resp, err := portal.sendMessage(ctx, intent, meta.Type, replaceContent, meta.ExtraContent, time.Now().UnixMilli()) if err != nil { - portal.log.Warnfln("Failed to edit %s after retry notification for %s: %v", msg.MXID, retry.MessageID, err) + log.Err(err).Msg("Failed to edit message after reuploading media from retry notification") return } - portal.log.Debugfln("Successfully edited %s -> %s after retry notification for %s", msg.MXID, resp.EventID, retry.MessageID) - msg.UpdateMXID(nil, resp.EventID, database.MsgNormal, database.MsgNoError) + log.Debug().Stringer("edit_mxid", resp.EventID).Msg("Successfully edited message after retry notification") + err = msg.UpdateMXID(ctx, resp.EventID, database.MsgNormal, database.MsgNoError) + if err != nil { + log.Err(err).Msg("Failed to save message to database after editing with retry notification") + } } -func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID, mediaKey []byte) (bool, error) { - msg := portal.bridge.DB.Message.GetByMXID(eventID) - if msg == nil { - err := errors.New(fmt.Sprintf("%s requested a media retry for unknown event %s", user.MXID, eventID)) - portal.log.Debugfln(err.Error()) - return false, err - } else if msg.Error != database.MsgErrMediaNotFound { - err := errors.New(fmt.Sprintf("%s requested a media retry for non-errored event %s", user.MXID, eventID)) - portal.log.Debugfln(err.Error()) - return false, err +func (portal *Portal) requestMediaRetry(ctx context.Context, user *User, eventID id.EventID, mediaKey []byte) (bool, error) { + log := zerolog.Ctx(ctx).With().Stringer("target_event_id", eventID).Logger() + msg, err := portal.bridge.DB.Message.GetByMXID(ctx, eventID) + if err != nil { + log.Err(err).Msg("Failed to get media retry target from database") + return false, fmt.Errorf("failed to get media retry target") + } else if msg == nil { + log.Debug().Msg("Can't send media retry request for unknown message") + return false, fmt.Errorf("unknown message") + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("target_message_id", msg.JID) + }) + if msg.Error != database.MsgErrMediaNotFound { + log.Debug().Msg("Dropping media retry request for non-errored message") + return false, fmt.Errorf("message is not errored") } // If the media key is not provided, grab it from the event in Matrix if mediaKey == nil { - evt, err := portal.fetchMediaRetryEvent(msg) + evt, err := portal.fetchMediaRetryEvent(ctx, msg) if err != nil { - portal.log.Warnfln("Can't send media retry request for %s: %v", msg.JID, err) + log.Warn().Err(err).Msg("Dropping media retry request as media key couldn't be fetched") return true, nil } mediaKey = evt.Media.Key } - err := user.Client.SendMediaRetryReceipt(&types.MessageInfo{ + err = user.Client.SendMediaRetryReceipt(&types.MessageInfo{ ID: msg.JID, MessageSource: types.MessageSource{ IsFromMe: msg.Sender.User == user.JID.User, @@ -3676,9 +3926,9 @@ func (portal *Portal) requestMediaRetry(user *User, eventID id.EventID, mediaKey }, }, mediaKey) if err != nil { - portal.log.Warnfln("Failed to send media retry request for %s: %v", msg.JID, err) + log.Err(err).Msg("Failed to send media retry request") } else { - portal.log.Debugfln("Sent media retry request for %s", msg.JID) + log.Debug().Msg("Sent media retry request") } return true, err } @@ -3740,9 +3990,9 @@ func (portal *Portal) downloadThumbnail(ctx context.Context, original []byte, th if len(thumbnailURL) == 0 { // just fall back to making thumbnail of original } else if mxc, err := thumbnailURL.Parse(); err != nil { - portal.log.Warnfln("Malformed thumbnail URL in %s: %v (falling back to generating thumbnail from source)", eventID, err) - } else if thumbnail, err := portal.MainIntent().DownloadBytesContext(ctx, mxc); err != nil { - portal.log.Warnfln("Failed to download thumbnail in %s: %v (falling back to generating thumbnail from source)", eventID, err) + zerolog.Ctx(ctx).Warn().Err(err).Msg("Malformed thumbnail URL in event, falling back to generating thumbnail from source") + } else if thumbnail, err := portal.MainIntent().DownloadBytes(ctx, mxc); err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to download thumbnail in event, falling back to generating thumbnail from source") } else { return createThumbnail(thumbnail, png) } @@ -3835,7 +4085,7 @@ func (portal *Portal) preprocessMatrixMedia(ctx context.Context, sender *User, r if err != nil { return nil, err } - data, err := portal.MainIntent().DownloadBytesContext(ctx, mxc) + data, err := portal.MainIntent().DownloadBytes(ctx, mxc) if err != nil { return nil, exerrors.NewDualError(errMediaDownloadFailed, err) } @@ -3903,7 +4153,7 @@ func (portal *Portal) preprocessMatrixMedia(ctx context.Context, sender *User, r return nil, exerrors.NewDualError(fmt.Errorf("%w (%s to %s)", errMediaConvertFailed, mimeType, content.Info.MimeType), convertErr) } else { // If the mime type didn't change and the errored conversion function returned the original data, just log a warning and continue - portal.log.Warnfln("Failed to re-encode %s media: %v, continuing with original file", mimeType, convertErr) + zerolog.Ctx(ctx).Warn().Err(convertErr).Str("source_mime", mimeType).Msg("Failed to re-encode media, continuing with original file") } } var uploadResp whatsmeow.UploadResponse @@ -3922,7 +4172,7 @@ func (portal *Portal) preprocessMatrixMedia(ctx context.Context, sender *User, r thumbnail, err = portal.downloadThumbnail(ctx, data, content.GetInfo().ThumbnailURL, eventID, isSticker) // Ignore format errors for non-image files, we don't care about those thumbnails if err != nil && (!errors.Is(err, image.ErrFormat) || mediaType == whatsmeow.MediaImage) { - portal.log.Warnfln("Failed to generate thumbnail for %s: %v", eventID, err) + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to generate thumbnail for image message") } } @@ -3945,15 +4195,15 @@ type MediaUpload struct { FileLength int } -func (portal *Portal) addRelaybotFormat(userID id.UserID, content *event.MessageEventContent) bool { - member := portal.MainIntent().Member(portal.MXID, userID) +func (portal *Portal) addRelaybotFormat(ctx context.Context, userID id.UserID, content *event.MessageEventContent) bool { + member := portal.MainIntent().Member(ctx, portal.MXID, userID) if member == nil { member = &event.MemberEventContent{} } content.EnsureHasHTML() data, err := portal.bridge.Config.Bridge.Relay.FormatMessage(content, userID, *member) if err != nil { - portal.log.Errorln("Failed to apply relaybot format:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to apply relaybot format") } content.FormattedBody = data return true @@ -4098,7 +4348,7 @@ func init() { event.TypeMap[TypeMSC3381PollStart] = reflect.TypeOf(PollStartContent{}) } -func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) { +func (portal *Portal) convertMatrixPollVote(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) { content, ok := evt.Content.Parsed.(*PollResponseContent) if !ok { return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) @@ -4109,8 +4359,12 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt } else if content.V2Selections != nil { answers = content.V2Selections } - pollMsg := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID) - if pollMsg == nil { + log := zerolog.Ctx(ctx) + pollMsg, err := portal.bridge.DB.Message.GetByMXID(ctx, content.RelatesTo.EventID) + if err != nil { + log.Err(err).Msg("Failed to get poll message from database") + return nil, sender, nil, fmt.Errorf("failed to get poll message") + } else if pollMsg == nil { return nil, sender, nil, errTargetNotFound } pollMsgInfo := &types.MessageInfo{ @@ -4125,13 +4379,17 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt } optionHashes := make([][]byte, 0, len(answers)) if pollMsg.Type == database.MsgMatrixPoll { - mappedAnswers := pollMsg.GetPollOptionHashes(answers) + mappedAnswers, err := pollMsg.GetPollOptionHashes(ctx, answers) + if err != nil { + log.Err(err).Msg("Failed to get poll option hashes from database") + return nil, sender, nil, fmt.Errorf("failed to get poll option hashes") + } for _, selection := range answers { hash, ok := mappedAnswers[selection] if ok { optionHashes = append(optionHashes, hash[:]) } else { - portal.log.Warnfln("Didn't find hash for option %s in %s's vote to %s", selection, evt.Sender, pollMsg.MXID) + log.Warn().Str("option", selection).Msg("Didn't find hash for selected option") } } } else { @@ -4148,7 +4406,7 @@ func (portal *Portal) convertMatrixPollVote(_ context.Context, sender *User, evt return &waProto.Message{PollUpdateMessage: pollUpdate}, sender, nil, err } -func (portal *Portal) convertMatrixPollStart(_ context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) { +func (portal *Portal) convertMatrixPollStart(ctx context.Context, sender *User, evt *event.Event) (*waProto.Message, *User, *extraConvertMeta, error) { content, ok := evt.Content.Parsed.(*PollStartContent) if !ok { return nil, sender, nil, fmt.Errorf("%w %T", errUnexpectedParsedContentType, evt.Content.Parsed) @@ -4157,7 +4415,7 @@ func (portal *Portal) convertMatrixPollStart(_ context.Context, sender *User, ev if maxAnswers >= len(content.PollStart.Answers) || maxAnswers < 0 { maxAnswers = 0 } - ctxInfo := portal.generateContextInfo(content.RelatesTo) + ctxInfo := portal.generateContextInfo(ctx, content.RelatesTo) var question string question, ctxInfo.MentionedJid = portal.msc1767ToWhatsApp(content.PollStart.Question, true) if len(question) == 0 { @@ -4169,7 +4427,7 @@ func (portal *Portal) convertMatrixPollStart(_ context.Context, sender *User, ev body, _ := portal.msc1767ToWhatsApp(opt.MSC1767Message, false) hash := sha256.Sum256([]byte(body)) if _, alreadyExists := optionMap[hash]; alreadyExists { - portal.log.Warnfln("Poll %s by %s has option %q more than once, rejecting", evt.ID, evt.Sender, body) + zerolog.Ctx(ctx).Warn().Str("option", body).Msg("Poll has duplicate options, rejecting") return nil, sender, nil, errPollDuplicateOption } optionMap[hash] = opt.ID @@ -4192,11 +4450,16 @@ func (portal *Portal) convertMatrixPollStart(_ context.Context, sender *User, ev }, sender, &extraConvertMeta{PollOptions: optionMap}, err } -func (portal *Portal) generateContextInfo(relatesTo *event.RelatesTo) *waProto.ContextInfo { +func (portal *Portal) generateContextInfo(ctx context.Context, relatesTo *event.RelatesTo) *waProto.ContextInfo { var ctxInfo waProto.ContextInfo replyToID := relatesTo.GetReplyTo() if len(replyToID) > 0 { - replyToMsg := portal.bridge.DB.Message.GetByMXID(replyToID) + replyToMsg, err := portal.bridge.DB.Message.GetByMXID(ctx, replyToID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("reply_to_mxid", replyToID). + Msg("Failed to get reply target from database") + } if replyToMsg != nil && !replyToMsg.IsFakeJID() && (replyToMsg.Type == database.MsgNormal || replyToMsg.Type == database.MsgMatrixPoll || replyToMsg.Type == database.MsgBeeperGallery) { ctxInfo.StanzaId = &replyToMsg.JID ctxInfo.Participant = proto.String(replyToMsg.Sender.ToNonAD().String()) @@ -4261,12 +4524,23 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev } isRelay = true } + log := zerolog.Ctx(ctx) var editRootMsg *database.Message if editEventID := content.RelatesTo.GetReplaceID(); editEventID != "" { - editRootMsg = portal.bridge.DB.Message.GetByMXID(editEventID) - if editErr := getEditError(editRootMsg, sender); editErr != nil { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("edit_target_mxid", editEventID) + }) + var err error + editRootMsg, err = portal.bridge.DB.Message.GetByMXID(ctx, editEventID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get edit target message from database") + return nil, sender, extraMeta, errEditUnknownTarget + } else if editErr := getEditError(editRootMsg, sender); editErr != nil { return nil, sender, extraMeta, editErr } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("edit_target_id", editRootMsg.JID) + }) extraMeta.EditRootMsg = editRootMsg if content.NewContent != nil { content = content.NewContent @@ -4274,8 +4548,8 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev } msg := &waProto.Message{} - ctxInfo := portal.generateContextInfo(content.RelatesTo) - relaybotFormatted := isRelay && portal.addRelaybotFormat(realSenderMXID, content) + ctxInfo := portal.generateContextInfo(ctx, content.RelatesTo) + relaybotFormatted := isRelay && portal.addRelaybotFormat(ctx, realSenderMXID, content) if evt.Type == event.EventSticker { if relaybotFormatted { // Stickers can't have captions, so force relaybot stickers to be images @@ -4308,7 +4582,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev Text: &text, ContextInfo: ctxInfo, } - hasPreview := portal.convertURLPreviewToWhatsApp(ctx, sender, evt, msg.ExtendedTextMessage) + hasPreview := portal.convertURLPreviewToWhatsApp(ctx, sender, content, msg.ExtendedTextMessage) if ctx.Err() != nil { return nil, sender, extraMeta, ctx.Err() } @@ -4514,20 +4788,17 @@ func (portal *Portal) generateMessageInfo(sender *User) *types.MessageInfo { } } -func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timings messageTimings) { +func (portal *Portal) HandleMatrixMessage(ctx context.Context, sender *User, evt *event.Event, timings messageTimings) { start := time.Now() ms := metricSender{portal: portal, timings: &timings} - log := portal.zlog.With(). - Str("event_id", evt.ID.String()). - Str("action", "handle matrix message"). - Logger() + log := zerolog.Ctx(ctx) allowRelay := evt.Type != TypeMSC3381PollResponse && evt.Type != TypeMSC3381V2PollResponse && evt.Type != TypeMSC3381PollStart if err := portal.canBridgeFrom(sender, allowRelay, true); err != nil { - go ms.sendMessageMetrics(evt, err, "Ignoring", true) + go ms.sendMessageMetrics(ctx, evt, err, "Ignoring", true) return } else if portal.Key.JID == types.StatusBroadcastJID && portal.bridge.Config.Bridge.DisableStatusBroadcastSend { - go ms.sendMessageMetrics(evt, errBroadcastSendDisabled, "Ignoring", true) + go ms.sendMessageMetrics(ctx, evt, errBroadcastSendDisabled, "Ignoring", true) return } @@ -4536,25 +4807,37 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing var dbMsg *database.Message if retryMeta := evt.Content.AsMessage().MessageSendRetry; retryMeta != nil { origEvtID = retryMeta.OriginalEventID - dbMsg = portal.bridge.DB.Message.GetByMXID(origEvtID) - if dbMsg != nil && dbMsg.Sent { - portal.log.Debugfln("Ignoring retry request %s (#%d, age: %s) for %s/%s from %s as message was already sent", evt.ID, retryMeta.RetryCount, messageAge, origEvtID, dbMsg.JID, evt.Sender) - go ms.sendMessageMetrics(evt, nil, "", true) + var err error + logEvt := log.Debug(). + Dur("message_age", messageAge). + Int("retry_count", retryMeta.RetryCount). + Stringer("orig_event_id", origEvtID) + dbMsg, err = portal.bridge.DB.Message.GetByMXID(ctx, origEvtID) + if err != nil { + log.Err(err).Msg("Failed to get retry request target message from database") + // TODO drop message? + } else if dbMsg != nil && dbMsg.Sent { + logEvt. + Str("wa_message_id", dbMsg.JID). + Msg("Ignoring retry request as message was already sent") + go ms.sendMessageMetrics(ctx, evt, nil, "", true) return } else if dbMsg != nil { - portal.log.Debugfln("Got retry request %s (#%d, age: %s) for %s/%s from %s", evt.ID, retryMeta.RetryCount, messageAge, origEvtID, dbMsg.JID, evt.Sender) + logEvt. + Str("wa_message_id", dbMsg.JID). + Msg("Got retry request for message") } else { - portal.log.Debugfln("Got retry request %s (#%d, age: %s) for %s from %s (original message not known)", evt.ID, retryMeta.RetryCount, messageAge, origEvtID, evt.Sender) + logEvt.Msg("Got retry request for message, but original message is not known") } } else { - portal.log.Debugfln("Received message %s from %s (age: %s)", evt.ID, evt.Sender, messageAge) + log.Debug().Dur("message_age", messageAge).Msg("Received Matrix message") } errorAfter := portal.bridge.Config.Bridge.MessageHandlingTimeout.ErrorAfter deadline := portal.bridge.Config.Bridge.MessageHandlingTimeout.Deadline isScheduled, _ := evt.Content.Raw["com.beeper.scheduled"].(bool) if isScheduled { - portal.log.Debugfln("%s is a scheduled message, extending handling timeouts", evt.ID) + log.Debug().Msg("Message is a scheduled message, extending handling timeouts") errorAfter *= 10 deadline *= 10 } @@ -4562,31 +4845,33 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing if errorAfter > 0 { remainingTime := errorAfter - messageAge if remainingTime < 0 { - go ms.sendMessageMetrics(evt, errTimeoutBeforeHandling, "Timeout handling", true) + go ms.sendMessageMetrics(ctx, evt, errTimeoutBeforeHandling, "Timeout handling", true) return } else if remainingTime < 1*time.Second { - portal.log.Warnfln("Message %s was delayed before reaching the bridge, only have %s (of %s timeout) until delay warning", evt.ID, remainingTime, errorAfter) + log.Warn(). + Dur("remaining_timeout", remainingTime). + Dur("warning_total_timeout", errorAfter). + Msg("Message was delayed before reaching the bridge") } go func() { time.Sleep(remainingTime) - ms.sendMessageMetrics(evt, errMessageTakingLong, "Timeout handling", false) + ms.sendMessageMetrics(ctx, evt, errMessageTakingLong, "Timeout handling", false) }() } - ctx := context.Background() + timedCtx := ctx if deadline > 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, deadline) + timedCtx, cancel = context.WithTimeout(ctx, deadline) defer cancel() } - ctx = log.WithContext(ctx) timings.preproc = time.Since(start) start = time.Now() - msg, sender, extraMeta, err := portal.convertMatrixMessage(ctx, sender, evt) + msg, sender, extraMeta, err := portal.convertMatrixMessage(timedCtx, sender, evt) timings.convert = time.Since(start) if msg == nil { - go ms.sendMessageMetrics(evt, err, "Error converting", true) + go ms.sendMessageMetrics(ctx, evt, err, "Error converting", true) return } if extraMeta == nil { @@ -4596,64 +4881,79 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event, timing if msg.PollCreationMessage != nil || msg.PollCreationMessageV2 != nil || msg.PollCreationMessageV3 != nil { dbMsgType = database.MsgMatrixPoll } else if msg.EditedMessage == nil { - portal.MarkDisappearing(nil, origEvtID, time.Duration(portal.ExpirationTime)*time.Second, time.Now()) + portal.MarkDisappearing(ctx, origEvtID, time.Duration(portal.ExpirationTime)*time.Second, time.Now()) } else { dbMsgType = database.MsgEdit } info := portal.generateMessageInfo(sender) if dbMsg == nil { - dbMsg = portal.markHandled(nil, nil, info, evt.ID, evt.Sender, false, true, dbMsgType, 0, database.MsgNoError) + dbMsg = portal.markHandled(ctx, nil, info, evt.ID, evt.Sender, false, true, dbMsgType, 0, database.MsgNoError) } else { info.ID = dbMsg.JID } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("wa_message_id", info.ID) + }) if dbMsgType == database.MsgMatrixPoll && extraMeta.PollOptions != nil { - dbMsg.PutPollOptions(extraMeta.PollOptions) + err = dbMsg.PutPollOptions(ctx, extraMeta.PollOptions) + if err != nil { + log.Err(err).Msg("Failed to save poll options in message to database") + } } - portal.log.Debugln("Sending event", evt.ID, "to WhatsApp", info.ID) + log.Debug().Msg("Sending Matrix event to WhatsApp") start = time.Now() - resp, err := sender.Client.SendMessage(ctx, portal.Key.JID, msg, whatsmeow.SendRequestExtra{ + resp, err := sender.Client.SendMessage(timedCtx, portal.Key.JID, msg, whatsmeow.SendRequestExtra{ ID: info.ID, MediaHandle: extraMeta.MediaHandle, }) timings.totalSend = time.Since(start) timings.whatsmeow = resp.DebugTimings if err != nil { - go ms.sendMessageMetrics(evt, err, "Error sending", true) + go ms.sendMessageMetrics(ctx, evt, err, "Error sending", true) return } - dbMsg.MarkSent(resp.Timestamp) + err = dbMsg.MarkSent(ctx, resp.Timestamp) + if err != nil { + log.Err(err).Msg("Failed to mark message as sent in database") + } if extraMeta != nil && len(extraMeta.GalleryExtraParts) > 0 { for i, part := range extraMeta.GalleryExtraParts { partInfo := portal.generateMessageInfo(sender) - partDBMsg := portal.markHandled(nil, nil, partInfo, evt.ID, evt.Sender, false, true, database.MsgBeeperGallery, i+1, database.MsgNoError) - portal.log.Debugln("Sending gallery part", i+2, "of event", evt.ID, "to WhatsApp", partInfo.ID) - resp, err = sender.Client.SendMessage(ctx, portal.Key.JID, part, whatsmeow.SendRequestExtra{ID: partInfo.ID}) + partDBMsg := portal.markHandled(ctx, nil, partInfo, evt.ID, evt.Sender, false, true, database.MsgBeeperGallery, i+1, database.MsgNoError) + log.Debug().Int("part_index", i+1).Str("wa_part_message_id", partInfo.ID).Msg("Sending gallery part to WhatsApp") + resp, err = sender.Client.SendMessage(timedCtx, portal.Key.JID, part, whatsmeow.SendRequestExtra{ID: partInfo.ID}) if err != nil { - go ms.sendMessageMetrics(evt, err, "Error sending", true) + go ms.sendMessageMetrics(ctx, evt, err, "Error sending", true) return } - portal.log.Debugfln("Sent gallery part", i+2, "of event", evt.ID) - partDBMsg.MarkSent(resp.Timestamp) + log.Debug().Int("part_index", i+1).Str("wa_part_message_id", partInfo.ID).Msg("Sent gallery part to WhatsApp") + err = partDBMsg.MarkSent(ctx, resp.Timestamp) + if err != nil { + log.Err(err). + Str("part_id", partInfo.ID). + Msg("Failed to mark gallery extra part as sent in database") + } } } - go ms.sendMessageMetrics(evt, nil, "", true) + go ms.sendMessageMetrics(ctx, evt, nil, "", true) } -func (portal *Portal) HandleMatrixReaction(sender *User, evt *event.Event) { +func (portal *Portal) HandleMatrixReaction(ctx context.Context, sender *User, evt *event.Event) { + log := zerolog.Ctx(ctx) if err := portal.canBridgeFrom(sender, false, true); err != nil { - go portal.sendMessageMetrics(evt, err, "Ignoring", nil) + go portal.sendMessageMetrics(ctx, evt, err, "Ignoring", nil) return } else if portal.Key.JID.Server == types.BroadcastServer { // TODO implement this, probably by only sending the reaction to the sender of the status message? // (whatsapp hasn't published the feature yet) - go portal.sendMessageMetrics(evt, errBroadcastReactionNotSupported, "Ignoring", nil) + go portal.sendMessageMetrics(ctx, evt, errBroadcastReactionNotSupported, "Ignoring", nil) return } content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if ok && strings.Contains(content.RelatesTo.Key, "retry") || strings.HasPrefix(content.RelatesTo.Key, "\u267b") { // ♻️ - if retryRequested, _ := portal.requestMediaRetry(sender, content.RelatesTo.EventID, nil); retryRequested { - _, _ = portal.MainIntent().RedactEvent(portal.MXID, evt.ID, mautrix.ReqRedact{ + if retryRequested, _ := portal.requestMediaRetry(ctx, sender, content.RelatesTo.EventID, nil); retryRequested { + _, _ = portal.MainIntent().RedactEvent(ctx, portal.MXID, evt.ID, mautrix.ReqRedact{ Reason: "requested media from phone", }) // Errored media, don't try to send as reaction @@ -4661,27 +4961,34 @@ func (portal *Portal) HandleMatrixReaction(sender *User, evt *event.Event) { } } - portal.log.Debugfln("Received reaction event %s from %s", evt.ID, evt.Sender) - err := portal.handleMatrixReaction(sender, evt) - go portal.sendMessageMetrics(evt, err, "Error sending", nil) + log.Debug().Msg("Received Matrix reaction event") + err := portal.handleMatrixReaction(ctx, sender, evt) + go portal.sendMessageMetrics(ctx, evt, err, "Error sending", nil) } -func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) error { +func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *User, evt *event.Event) error { + log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if !ok { return fmt.Errorf("unexpected parsed content type %T", evt.Content.Parsed) } - target := portal.bridge.DB.Message.GetByMXID(content.RelatesTo.EventID) - if target == nil || target.Type == database.MsgReaction { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("target_event_id", content.RelatesTo.EventID) + }) + target, err := portal.bridge.DB.Message.GetByMXID(ctx, content.RelatesTo.EventID) + if err != nil { + log.Err(err).Msg("Failed to get target message from database") + return fmt.Errorf("failed to get target event") + } else if target == nil || target.Type == database.MsgReaction { return fmt.Errorf("unknown target event %s", content.RelatesTo.EventID) } info := portal.generateMessageInfo(sender) - dbMsg := portal.markHandled(nil, nil, info, evt.ID, evt.Sender, false, true, database.MsgReaction, 0, database.MsgNoError) - portal.upsertReaction(nil, nil, target.JID, sender.JID, evt.ID, info.ID) - portal.log.Debugln("Sending reaction", evt.ID, "to WhatsApp", info.ID) + dbMsg := portal.markHandled(ctx, nil, info, evt.ID, evt.Sender, false, true, database.MsgReaction, 0, database.MsgNoError) + portal.upsertReaction(ctx, nil, target.JID, sender.JID, evt.ID, info.ID) + log.Debug().Str("whatsapp_reaction_id", info.ID).Msg("Sending Matrix reaction to WhatsApp") resp, err := portal.sendReactionToWhatsApp(sender, info.ID, target, content.RelatesTo.Key, evt.Timestamp) if err == nil { - dbMsg.MarkSent(resp.Timestamp) + err = dbMsg.MarkSent(ctx, resp.Timestamp) } return err } @@ -4708,37 +5015,49 @@ func (portal *Portal) sendReactionToWhatsApp(sender *User, id types.MessageID, t }, whatsmeow.SendRequestExtra{ID: id}) } -func (portal *Portal) upsertReaction(txn dbutil.Transaction, intent *appservice.IntentAPI, targetJID types.MessageID, senderJID types.JID, mxid id.EventID, jid types.MessageID) { - dbReaction := portal.bridge.DB.Reaction.GetByTargetJID(portal.Key, targetJID, senderJID) +func (portal *Portal) upsertReaction(ctx context.Context, intent *appservice.IntentAPI, targetJID types.MessageID, senderJID types.JID, mxid id.EventID, jid types.MessageID) { + log := zerolog.Ctx(ctx) + dbReaction, err := portal.bridge.DB.Reaction.GetByTargetJID(ctx, portal.Key, targetJID, senderJID) + if err != nil { + log.Err(err).Msg("Failed to get existing reaction from database for upsert") + return + } if dbReaction == nil { dbReaction = portal.bridge.DB.Reaction.New() dbReaction.Chat = portal.Key dbReaction.TargetJID = targetJID dbReaction.Sender = senderJID } else if intent != nil { - portal.log.Debugfln("Redacting old Matrix reaction %s after new one (%s) was sent", dbReaction.MXID, mxid) - var err error + log.Debug(). + Stringer("old_reaction_mxid", dbReaction.MXID). + Msg("Redacting old Matrix reaction after new one was sent") if intent != nil { - _, err = intent.RedactEvent(portal.MXID, dbReaction.MXID) + _, err = intent.RedactEvent(ctx, portal.MXID, dbReaction.MXID) } if intent == nil || errors.Is(err, mautrix.MForbidden) { - _, err = portal.MainIntent().RedactEvent(portal.MXID, dbReaction.MXID) + _, err = portal.MainIntent().RedactEvent(ctx, portal.MXID, dbReaction.MXID) } if err != nil { - portal.log.Warnfln("Failed to remove old reaction %s: %v", dbReaction.MXID, err) + log.Err(err). + Stringer("old_reaction_mxid", dbReaction.MXID). + Msg("Failed to redact old reaction") } } dbReaction.MXID = mxid dbReaction.JID = jid - dbReaction.Upsert(txn) + err = dbReaction.Upsert(ctx) + if err != nil { + log.Err(err).Msg("Failed to upsert reaction to database") + } } -func (portal *Portal) HandleMatrixRedaction(sender *User, evt *event.Event) { +func (portal *Portal) HandleMatrixRedaction(ctx context.Context, sender *User, evt *event.Event) { + log := zerolog.Ctx(ctx) if err := portal.canBridgeFrom(sender, true, true); err != nil { - go portal.sendMessageMetrics(evt, err, "Ignoring", nil) + go portal.sendMessageMetrics(ctx, evt, err, "Ignoring", nil) return } - portal.log.Debugfln("Received redaction %s from %s", evt.ID, evt.Sender) + log.Debug().Msg("Received Matrix redaction") senderLogIdentifier := sender.MXID if !sender.HasSession() { @@ -4746,24 +5065,33 @@ func (portal *Portal) HandleMatrixRedaction(sender *User, evt *event.Event) { senderLogIdentifier += " (through relaybot)" } - msg := portal.bridge.DB.Message.GetByMXID(evt.Redacts) - if msg == nil { - go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring", nil) + msg, err := portal.bridge.DB.Message.GetByMXID(ctx, evt.Redacts) + if err != nil { + log.Err(err).Msg("Failed to get redaction target event from database") + go portal.sendMessageMetrics(ctx, evt, errTargetNotFound, "Ignoring", nil) + } else if msg == nil { + go portal.sendMessageMetrics(ctx, evt, errTargetNotFound, "Ignoring", nil) } else if msg.IsFakeJID() { - go portal.sendMessageMetrics(evt, errTargetIsFake, "Ignoring", nil) + go portal.sendMessageMetrics(ctx, evt, errTargetIsFake, "Ignoring", nil) } else if portal.Key.JID == types.StatusBroadcastJID && portal.bridge.Config.Bridge.DisableStatusBroadcastSend { - go portal.sendMessageMetrics(evt, errBroadcastSendDisabled, "Ignoring", nil) + go portal.sendMessageMetrics(ctx, evt, errBroadcastSendDisabled, "Ignoring", nil) } else if msg.Type == database.MsgReaction { if msg.Sender.User != sender.JID.User { - go portal.sendMessageMetrics(evt, errReactionSentBySomeoneElse, "Ignoring", nil) - } else if reaction := portal.bridge.DB.Reaction.GetByMXID(evt.Redacts); reaction == nil { - go portal.sendMessageMetrics(evt, errReactionDatabaseNotFound, "Ignoring", nil) - } else if reactionTarget := reaction.GetTarget(); reactionTarget == nil { - go portal.sendMessageMetrics(evt, errReactionTargetNotFound, "Ignoring", nil) + go portal.sendMessageMetrics(ctx, evt, errReactionSentBySomeoneElse, "Ignoring", nil) + } else if reaction, err := portal.bridge.DB.Reaction.GetByMXID(ctx, evt.Redacts); err != nil { + log.Err(err).Msg("Failed to get target reaction from database") + go portal.sendMessageMetrics(ctx, evt, errReactionDatabaseNotFound, "Ignoring", nil) + } else if reaction == nil { + go portal.sendMessageMetrics(ctx, evt, errReactionDatabaseNotFound, "Ignoring", nil) + } else if reactionTarget, err := portal.bridge.DB.Message.GetByJID(ctx, reaction.Chat, reaction.TargetJID); err != nil { + log.Err(err).Msg("Failed to get target reaction's target message from database") + go portal.sendMessageMetrics(ctx, evt, errReactionTargetNotFound, "Ignoring", nil) + } else if reactionTarget == nil { + go portal.sendMessageMetrics(ctx, evt, errReactionTargetNotFound, "Ignoring", nil) } else { - portal.log.Debugfln("Sending redaction reaction %s of %s/%s to WhatsApp", evt.ID, msg.MXID, msg.JID) - _, err := portal.sendReactionToWhatsApp(sender, "", reactionTarget, "", evt.Timestamp) - go portal.sendMessageMetrics(evt, err, "Error sending", nil) + log.Debug().Str("reaction_target_message_id", msg.JID).Msg("Sending redaction of reaction to WhatsApp") + _, err = portal.sendReactionToWhatsApp(sender, "", reactionTarget, "", evt.Timestamp) + go portal.sendMessageMetrics(ctx, evt, err, "Error sending", nil) } } else { key := &waProto.MessageKey{ @@ -4773,33 +5101,42 @@ func (portal *Portal) HandleMatrixRedaction(sender *User, evt *event.Event) { } if msg.Sender.User != sender.JID.User { if portal.IsPrivateChat() { - go portal.sendMessageMetrics(evt, errDMSentByOtherUser, "Ignoring", nil) + go portal.sendMessageMetrics(ctx, evt, errDMSentByOtherUser, "Ignoring", nil) return } key.FromMe = proto.Bool(false) key.Participant = proto.String(msg.Sender.ToNonAD().String()) } - portal.log.Debugfln("Sending redaction %s of %s/%s to WhatsApp", evt.ID, msg.MXID, msg.JID) - ctx, cancel := context.WithTimeout(context.TODO(), 60*time.Second) + log.Debug().Str("target_message_id", msg.JID).Msg("Sending redaction of message to WhatsApp") + timedCtx, cancel := context.WithTimeout(ctx, 60*time.Second) defer cancel() - _, err := sender.Client.SendMessage(ctx, portal.Key.JID, &waProto.Message{ + _, err = sender.Client.SendMessage(timedCtx, portal.Key.JID, &waProto.Message{ ProtocolMessage: &waProto.ProtocolMessage{ Type: waProto.ProtocolMessage_REVOKE.Enum(), Key: key, }, }) - go portal.sendMessageMetrics(evt, err, "Error sending", nil) + go portal.sendMessageMetrics(ctx, evt, err, "Error sending", nil) } } func (portal *Portal) HandleMatrixReadReceipt(sender bridge.User, eventID id.EventID, receipt event.ReadReceipt) { - portal.handleMatrixReadReceipt(sender.(*User), eventID, receipt.Timestamp, true) + log := portal.zlog.With(). + Str("action", "handle matrix read receipt"). + Stringer("event_id", eventID). + Stringer("user_id", sender.GetMXID()). + Logger() + ctx := log.WithContext(context.TODO()) + portal.handleMatrixReadReceipt(ctx, sender.(*User), eventID, receipt.Timestamp, true) } -func (portal *Portal) handleMatrixReadReceipt(sender *User, eventID id.EventID, receiptTimestamp time.Time, isExplicit bool) { +func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, sender *User, eventID id.EventID, receiptTimestamp time.Time, isExplicit bool) { + log := zerolog.Ctx(ctx).With(). + Stringer("sender_jid", sender.JID). + Logger() if !sender.IsLoggedIn() { if isExplicit { - portal.log.Debugfln("Ignoring read receipt by %s/%s: user is not connected to WhatsApp", sender.MXID, sender.JID) + log.Debug().Msg("Ignoring read receipt: user is not connected to WhatsApp") } return } @@ -4807,21 +5144,27 @@ func (portal *Portal) handleMatrixReadReceipt(sender *User, eventID id.EventID, maxTimestamp := receiptTimestamp // Implicit read receipts don't have an event ID that's already bridged if isExplicit { - if message := portal.bridge.DB.Message.GetByMXID(eventID); message != nil { + if message, err := portal.bridge.DB.Message.GetByMXID(ctx, eventID); err != nil { + log.Err(err).Msg("Failed to get read receipt target message") + } else if message != nil { maxTimestamp = message.Timestamp } } - prevTimestamp := sender.GetLastReadTS(portal.Key) + prevTimestamp := sender.GetLastReadTS(ctx, portal.Key) lastReadIsZero := false if prevTimestamp.IsZero() { prevTimestamp = maxTimestamp.Add(-2 * time.Second) lastReadIsZero = true } - messages := portal.bridge.DB.Message.GetMessagesBetween(portal.Key, prevTimestamp, maxTimestamp) + messages, err := portal.bridge.DB.Message.GetMessagesBetween(ctx, portal.Key, prevTimestamp, maxTimestamp) + if err != nil { + log.Err(err).Msg("Failed to get messages that need receipts") + return + } if len(messages) > 0 { - sender.SetLastReadTS(portal.Key, messages[len(messages)-1].Timestamp) + sender.SetLastReadTS(ctx, portal.Key, messages[len(messages)-1].Timestamp) } groupedMessages := make(map[types.JID][]types.MessageID) for _, msg := range messages { @@ -4838,8 +5181,12 @@ func (portal *Portal) handleMatrixReadReceipt(sender *User, eventID id.EventID, } // For explicit read receipts, log even if there are no targets. For implicit ones only log when there are targets if len(groupedMessages) > 0 || isExplicit { - portal.log.Debugfln("Sending read receipts by %s (last read: %d, was zero: %t, explicit: %t): %v", - sender.JID, prevTimestamp.Unix(), lastReadIsZero, isExplicit, groupedMessages) + log.Debug(). + Bool("explicit", isExplicit). + Time("last_read", prevTimestamp). + Bool("last_read_is_zero", lastReadIsZero). + Any("receipts", groupedMessages). + Msg("Sending read receipts to WhatsApp") } for messageSender, ids := range groupedMessages { chatJID := portal.Key.JID @@ -4847,9 +5194,12 @@ func (portal *Portal) handleMatrixReadReceipt(sender *User, eventID id.EventID, chatJID = messageSender messageSender = portal.Key.JID } - err := sender.Client.MarkRead(ids, receiptTimestamp, chatJID, messageSender) + err = sender.Client.MarkRead(ids, receiptTimestamp, chatJID, messageSender) if err != nil { - portal.log.Warnfln("Failed to mark %v as read by %s: %v", ids, sender.JID, err) + log.Err(err). + Array("message_ids", exzerolog.ArrayOfStrs(ids)). + Stringer("target_user_jid", messageSender). + Msg("Failed to send read receipt") } } } @@ -4882,15 +5232,23 @@ func (portal *Portal) setTyping(userIDs []id.UserID, state types.ChatPresence) { if user == nil || !user.IsLoggedIn() { continue } - portal.log.Debugfln("Bridging typing change from %s to chat presence %s", state, user.MXID) + portal.zlog.Debug(). + Stringer("user_jid", user.JID). + Stringer("user_mxid", user.MXID). + Str("state", string(state)). + Msg("Bridging typing change to chat presence") err := user.Client.SendChatPresence(portal.Key.JID, state, types.ChatPresenceMediaText) if err != nil { - portal.log.Warnln("Error sending chat presence:", err) + portal.zlog.Err(err). + Stringer("user_jid", user.JID). + Stringer("user_mxid", user.MXID). + Str("state", string(state)). + Msg("Failed to send chat presence") } if portal.bridge.Config.Bridge.SendPresenceOnTyping { err = user.Client.SendPresence(types.PresenceAvailable) if err != nil { - user.log.Warnln("Failed to set presence:", err) + user.zlog.Warn().Err(err).Msg("Failed to set presence on typing") } } } @@ -4938,8 +5296,11 @@ func (portal *Portal) resetChildSpaceStatus() { } } -func (portal *Portal) Delete() { - portal.Portal.Delete() +func (portal *Portal) Delete(ctx context.Context) { + err := portal.Portal.Delete(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete portal from database") + } portal.bridge.portalsLock.Lock() delete(portal.bridge.portalsByJID, portal.Key) if len(portal.MXID) > 0 { @@ -4949,8 +5310,8 @@ func (portal *Portal) Delete() { portal.bridge.portalsLock.Unlock() } -func (portal *Portal) GetMatrixUsers() ([]id.UserID, error) { - members, err := portal.MainIntent().JoinedMembers(portal.MXID) +func (portal *Portal) GetMatrixUsers(ctx context.Context) ([]id.UserID, error) { + members, err := portal.MainIntent().JoinedMembers(ctx, portal.MXID) if err != nil { return nil, fmt.Errorf("failed to get member list: %w", err) } @@ -4964,35 +5325,36 @@ func (portal *Portal) GetMatrixUsers() ([]id.UserID, error) { return users, nil } -func (portal *Portal) CleanupIfEmpty() { - users, err := portal.GetMatrixUsers() +func (portal *Portal) CleanupIfEmpty(ctx context.Context) { + users, err := portal.GetMatrixUsers(ctx) if err != nil { - portal.log.Errorfln("Failed to get Matrix user list to determine if portal needs to be cleaned up: %v", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to get Matrix user list to determine if portal needs to be cleaned up") return } if len(users) == 0 { - portal.log.Infoln("Room seems to be empty, cleaning up...") - portal.Delete() - portal.Cleanup(false) + zerolog.Ctx(ctx).Info().Msg("Room seems to be empty, cleaning up...") + portal.Delete(ctx) + portal.Cleanup(ctx, false) } } -func (portal *Portal) Cleanup(puppetsOnly bool) { +func (portal *Portal) Cleanup(ctx context.Context, puppetsOnly bool) { if len(portal.MXID) == 0 { return } + log := zerolog.Ctx(ctx) intent := portal.MainIntent() if portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { - err := intent.BeeperDeleteRoom(portal.MXID) + err := intent.BeeperDeleteRoom(ctx, portal.MXID) if err == nil || errors.Is(err, mautrix.MNotFound) { return } - portal.log.Warnfln("Failed to delete %s using hungryserv yeet endpoint, falling back to normal behavior: %v", portal.MXID, err) + log.Warn().Err(err).Msg("Failed to delete room using beeper yeet endpoint, falling back to normal behavior") } - members, err := intent.JoinedMembers(portal.MXID) + members, err := intent.JoinedMembers(ctx, portal.MXID) if err != nil { - portal.log.Errorln("Failed to get portal members for cleanup:", err) + log.Err(err).Msg("Failed to get portal members for cleanup") return } for member := range members.Joined { @@ -5001,58 +5363,72 @@ func (portal *Portal) Cleanup(puppetsOnly bool) { } puppet := portal.bridge.GetPuppetByMXID(member) if puppet != nil { - _, err = puppet.DefaultIntent().LeaveRoom(portal.MXID) + _, err = puppet.DefaultIntent().LeaveRoom(ctx, portal.MXID) if err != nil { - portal.log.Errorln("Error leaving as puppet while cleaning up portal:", err) + log.Err(err).Stringer("puppet_mxid", puppet.MXID).Msg("Failed to leave room as puppet while cleaning up portal") } } else if !puppetsOnly { - _, err = intent.KickUser(portal.MXID, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) + _, err = intent.KickUser(ctx, portal.MXID, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) if err != nil { - portal.log.Errorln("Error kicking user while cleaning up portal:", err) + log.Err(err).Stringer("user_mxid", member).Msg("Failed to kick user while cleaning up portal") } } } - _, err = intent.LeaveRoom(portal.MXID) + _, err = intent.LeaveRoom(ctx, portal.MXID) if err != nil { - portal.log.Errorln("Error leaving with main intent while cleaning up portal:", err) + log.Err(err).Msg("Failed to leave room with main intent while cleaning up portal") } } -func (portal *Portal) HandleMatrixLeave(brSender bridge.User) { +func (portal *Portal) HandleMatrixLeave(brSender bridge.User, evt *event.Event) { + log := portal.zlog.With(). + Str("action", "handle matrix leave"). + Stringer("event_id", evt.ID). + Stringer("user_id", brSender.GetMXID()). + Logger() + ctx := log.WithContext(context.TODO()) sender := brSender.(*User) if portal.IsPrivateChat() { - portal.log.Debugln("User left private chat portal, cleaning up and deleting...") - portal.Delete() - portal.Cleanup(false) + log.Debug().Msg("User left private chat portal, cleaning up and deleting...") + portal.Delete(ctx) + portal.Cleanup(ctx, false) return } else if portal.bridge.Config.Bridge.BridgeMatrixLeave { err := sender.Client.LeaveGroup(portal.Key.JID) if err != nil { - portal.log.Errorfln("Failed to leave group as %s: %v", sender.MXID, err) + log.Err(err).Msg("Failed to leave group") return } //portal.log.Infoln("Leave response:", <-resp) } - portal.CleanupIfEmpty() + portal.CleanupIfEmpty(ctx) } -func (portal *Portal) HandleMatrixKick(brSender bridge.User, brTarget bridge.Ghost) { +func (portal *Portal) HandleMatrixKick(brSender bridge.User, brTarget bridge.Ghost, evt *event.Event) { sender := brSender.(*User) target := brTarget.(*Puppet) _, err := sender.Client.UpdateGroupParticipants(portal.Key.JID, []types.JID{target.JID}, whatsmeow.ParticipantChangeRemove) if err != nil { - portal.log.Errorfln("Failed to kick %s from group as %s: %v", target.JID, sender.MXID, err) + portal.zlog.Err(err). + Stringer("kicked_by_mxid", sender.MXID). + Stringer("kicked_by_jid", sender.JID). + Stringer("target_jid", target.JID). + Msg("Failed to kick user from group") return } //portal.log.Infoln("Kick %s response: %s", puppet.JID, <-resp) } -func (portal *Portal) HandleMatrixInvite(brSender bridge.User, brTarget bridge.Ghost) { +func (portal *Portal) HandleMatrixInvite(brSender bridge.User, brTarget bridge.Ghost, evt *event.Event) { sender := brSender.(*User) target := brTarget.(*Puppet) _, err := sender.Client.UpdateGroupParticipants(portal.Key.JID, []types.JID{target.JID}, whatsmeow.ParticipantChangeAdd) if err != nil { - portal.log.Errorfln("Failed to add %s to group as %s: %v", target.JID, sender.MXID, err) + portal.zlog.Err(err). + Stringer("inviter_mxid", sender.MXID). + Stringer("inviter_jid", sender.JID). + Stringer("target_jid", target.JID). + Msg("Failed to add user to group") return } //portal.log.Infofln("Add %s response: %s", puppet.JID, <-resp) @@ -5063,6 +5439,13 @@ func (portal *Portal) HandleMatrixMeta(brSender bridge.User, evt *event.Event) { if !sender.Whitelisted || !sender.IsLoggedIn() { return } + log := portal.zlog.With(). + Str("action", "handle matrix metadata"). + Str("event_type", evt.Type.Type). + Stringer("event_id", evt.ID). + Stringer("sender", sender.MXID). + Logger() + ctx := log.WithContext(context.TODO()) switch content := evt.Content.Parsed.(type) { case *event.RoomNameEventContent: @@ -5072,7 +5455,8 @@ func (portal *Portal) HandleMatrixMeta(brSender bridge.User, evt *event.Event) { portal.Name = content.Name err := sender.Client.SetGroupName(portal.Key.JID, content.Name) if err != nil { - portal.log.Errorln("Failed to update group name:", err) + log.Err(err).Msg("Failed to update group name") + return } case *event.TopicEventContent: if content.Topic == portal.Topic { @@ -5081,7 +5465,8 @@ func (portal *Portal) HandleMatrixMeta(brSender bridge.User, evt *event.Event) { portal.Topic = content.Topic err := sender.Client.SetGroupTopic(portal.Key.JID, "", "", content.Topic) if err != nil { - portal.log.Errorln("Failed to update group description:", err) + log.Err(err).Msg("Failed to update group topic") + return } case *event.RoomAvatarEventContent: portal.avatarLock.Lock() @@ -5092,24 +5477,30 @@ func (portal *Portal) HandleMatrixMeta(brSender bridge.User, evt *event.Event) { var data []byte var err error if !content.URL.IsEmpty() { - data, err = portal.MainIntent().DownloadBytes(content.URL) + data, err = portal.MainIntent().DownloadBytes(ctx, content.URL) if err != nil { - portal.log.Errorfln("Failed to download updated avatar %s: %v", content.URL, err) + log.Err(err).Stringer("mxc_uri", content.URL).Msg("Failed to download updated avatar") return } - portal.log.Debugfln("%s set the group avatar to %s", sender.MXID, content.URL) + log.Debug().Stringer("mxc_uri", content.URL).Msg("Updating group avatar") } else { - portal.log.Debugfln("%s removed the group avatar", sender.MXID) + log.Debug().Msg("Removing group avatar") } newID, err := sender.Client.SetGroupPhoto(portal.Key.JID, data) if err != nil { - portal.log.Errorfln("Failed to update group avatar: %v", err) + log.Err(err).Msg("Failed to update group avatar") return } - portal.log.Debugfln("Successfully updated group avatar to %s", newID) + log.Debug().Str("avatar_id", newID).Msg("Successfully updated group avatar") portal.Avatar = newID portal.AvatarURL = content.URL - portal.UpdateBridgeInfo() - portal.Update(nil) + default: + log.Debug().Type("content_type", content).Msg("Ignoring unknown metadata event type") + return + } + portal.UpdateBridgeInfo(ctx) + err := portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to update portal after handling metadata") } } diff --git a/provisioning.go b/provisioning.go index fe1d0f1..8a3b13a 100644 --- a/provisioning.go +++ b/provisioning.go @@ -17,41 +17,40 @@ package main import ( - "bufio" "context" "encoding/json" "errors" "fmt" - "net" "net/http" _ "net/http/pprof" "strings" "time" + "github.com/beeper/libserv/pkg/requestlog" "github.com/gorilla/mux" "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" "go.mau.fi/whatsmeow/appstate" "go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/id" ) type ProvisioningAPI struct { bridge *WABridge - log log.Logger + log zerolog.Logger } func (prov *ProvisioningAPI) Init() { - prov.log = prov.bridge.Log.Sub("Provisioning") - - prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix) + prov.log.Debug().Str("base_path", prov.bridge.Config.Bridge.Provisioning.Prefix).Msg("Enabling provisioning API") r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter() + r.Use(hlog.NewHandler(prov.log)) + r.Use(requestlog.AccessLogger(true)) r.Use(prov.AuthMiddleware) r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet) r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet) @@ -73,7 +72,7 @@ func (prov *ProvisioningAPI) Init() { prov.bridge.AS.Router.HandleFunc("/_matrix/app/com.beeper.bridge_state", prov.BridgeStatePing).Methods(http.MethodPost) if prov.bridge.Config.Bridge.Provisioning.DebugEndpoints { - prov.log.Debugln("Enabling debug API at /debug") + prov.log.Debug().Msg("Enabling debug API at /debug") r := prov.bridge.AS.Router.PathPrefix("/debug").Subrouter() r.Use(prov.AuthMiddleware) r.PathPrefix("/pprof").Handler(http.DefaultServeMux) @@ -83,26 +82,6 @@ func (prov *ProvisioningAPI) Init() { r.HandleFunc("/v1/delete_connection", prov.Disconnect).Methods(http.MethodPost) } -type responseWrap struct { - http.ResponseWriter - statusCode int -} - -var _ http.Hijacker = (*responseWrap)(nil) - -func (rw *responseWrap) WriteHeader(statusCode int) { - rw.ResponseWriter.WriteHeader(statusCode) - rw.statusCode = statusCode -} - -func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hijacker, ok := rw.ResponseWriter.(http.Hijacker) - if !ok { - return nil, nil, errors.New("response does not implement http.Hijacker") - } - return hijacker.Hijack() -} - func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") @@ -119,7 +98,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { auth = auth[len("Bearer "):] } if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret { - prov.log.Infof("Authentication token does not match shared secret") + hlog.FromRequest(r).Debug().Msg("Authentication token does not match shared secret") jsonResponse(w, http.StatusForbidden, map[string]interface{}{ "error": "Authentication token does not match shared secret", "errcode": "M_FORBIDDEN", @@ -128,11 +107,12 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { } userID := r.URL.Query().Get("user_id") user := prov.bridge.GetUserByMXID(id.UserID(userID)) - start := time.Now() - wWrap := &responseWrap{w, 200} - h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user))) - duration := time.Now().Sub(start).Seconds() - prov.log.Infofln("%s %s from %s took %.2f seconds and returned status %d", r.Method, r.URL.Path, user.MXID, duration, wWrap.statusCode) + if user != nil { + hlog.FromRequest(r).UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("user_id", user.MXID) + }) + } + h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "user", user))) }) } @@ -157,7 +137,7 @@ func (prov *ProvisioningAPI) DeleteSession(w http.ResponseWriter, r *http.Reques return } user.DeleteConnection() - user.DeleteSession() + user.DeleteSession(r.Context()) jsonResponse(w, http.StatusOK, Response{true, "Session information purged"}) user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut}) } @@ -245,7 +225,7 @@ func (prov *ProvisioningAPI) ListContacts(w http.ResponseWriter, r *http.Request ErrCode: "no session", }) } else if contacts, err := user.Session.Contacts.GetAllContacts(); err != nil { - prov.log.Errorfln("Failed to fetch %s's contacts: %v", user.MXID, err) + hlog.FromRequest(r).Err(err).Msg("Failed to fetch all contacts") jsonResponse(w, http.StatusInternalServerError, Error{ Error: "Internal server error while fetching contact list", ErrCode: "failed to get contacts", @@ -282,7 +262,7 @@ func (prov *ProvisioningAPI) ListGroups(w http.ResponseWriter, r *http.Request) if r.Method == http.MethodPost { err := user.ResyncGroups(r.URL.Query().Get("create_portals") == "true") if err != nil { - prov.log.Errorfln("Failed to resync %s's groups: %v", user.MXID, err) + hlog.FromRequest(r).Err(err).Msg("Failed to resync groups") jsonResponse(w, http.StatusInternalServerError, Error{ Error: "Internal server error while resyncing groups", ErrCode: "failed to sync groups", @@ -291,7 +271,7 @@ func (prov *ProvisioningAPI) ListGroups(w http.ResponseWriter, r *http.Request) } } if groups, err := user.getCachedGroupList(); err != nil { - prov.log.Errorfln("Failed to fetch %s's groups: %v", user.MXID, err) + hlog.FromRequest(r).Err(err).Msg("Failed to fetch group list") jsonResponse(w, http.StatusInternalServerError, Error{ Error: "Internal server error while fetching group list", ErrCode: "failed to get groups", @@ -368,17 +348,17 @@ func (prov *ProvisioningAPI) StartPM(w http.ResponseWriter, r *http.Request) { // resolveIdentifier already responded with an error return } - portal, puppet, justCreated, err := user.StartPM(jid, "provisioning API PM") + portal, puppet, justCreated, err := user.StartPM(r.Context(), jid, "provisioning API PM") if err != nil { jsonResponse(w, http.StatusInternalServerError, Error{ Error: fmt.Sprintf("Failed to create portal: %v", err), }) } - status := http.StatusOK + statusCode := http.StatusOK if justCreated { - status = http.StatusCreated + statusCode = http.StatusCreated } - jsonResponse(w, status, PortalInfo{ + jsonResponse(w, statusCode, PortalInfo{ RoomID: portal.MXID, OtherUser: &OtherUserInfo{ JID: puppet.JID, @@ -449,29 +429,30 @@ func (prov *ProvisioningAPI) OpenGroup(w http.ResponseWriter, r *http.Request) { ErrCode: "invalid group id", }) } else if info, err := user.Client.GetGroupInfo(jid); err != nil { + hlog.FromRequest(r).Err(err).Msg("Failed to get group info by JID") // TODO return better responses for different errors (like ErrGroupNotFound and ErrNotInGroup) jsonResponse(w, http.StatusInternalServerError, Error{ Error: fmt.Sprintf("Failed to get group info: %v", err), ErrCode: "error getting group info", }) } else { - prov.log.Debugln("Importing", jid, "for", user.MXID) + hlog.FromRequest(r).Debug().Stringer("chat_jid", jid).Msg("Importing group chat for user") portal := user.GetPortalByJID(info.JID) - status := http.StatusOK + statusCode := http.StatusOK if len(portal.MXID) == 0 { - err = portal.CreateMatrixRoom(user, info, nil, true, true) + err = portal.CreateMatrixRoom(r.Context(), user, info, nil, true, true) if err != nil { jsonResponse(w, http.StatusInternalServerError, Error{ Error: fmt.Sprintf("Failed to create portal: %v", err), }) return } - status = http.StatusCreated + statusCode = http.StatusCreated } - jsonResponse(w, status, PortalInfo{ + jsonResponse(w, statusCode, PortalInfo{ RoomID: portal.MXID, GroupInfo: info, - JustCreated: status == http.StatusCreated, + JustCreated: statusCode == http.StatusCreated, }) } } @@ -495,6 +476,7 @@ func (prov *ProvisioningAPI) resolveGroupInvite(w http.ResponseWriter, r *http.R ErrCode: "invalid invite link", }) } else { + hlog.FromRequest(r).Err(err).Msg("Failed to get group info from link") jsonResponse(w, http.StatusInternalServerError, Error{ Error: fmt.Sprintf("Failed to fetch group info with link: %v", err), ErrCode: "error getting group info", @@ -530,29 +512,30 @@ func (prov *ProvisioningAPI) JoinGroup(w http.ResponseWriter, r *http.Request) { }() inviteCode, _ := mux.Vars(r)["inviteCode"] if jid, err := user.Client.JoinGroupWithLink(inviteCode); err != nil { + hlog.FromRequest(r).Err(err).Msg("Failed to join group") jsonResponse(w, http.StatusInternalServerError, Error{ Error: fmt.Sprintf("Failed to join group: %v", err), ErrCode: "error joining group", }) } else { - prov.log.Debugln(user.MXID, "successfully joined group", jid) + hlog.FromRequest(r).Debug().Stringer("chat_jid", jid).Msg("Successfully joined group") portal := user.GetPortalByJID(jid) - status := http.StatusOK + statusCode := http.StatusOK if len(portal.MXID) == 0 { time.Sleep(500 * time.Millisecond) // Wait for incoming group info to create the portal automatically - err = portal.CreateMatrixRoom(user, info, nil, true, true) + err = portal.CreateMatrixRoom(r.Context(), user, info, nil, true, true) if err != nil { jsonResponse(w, http.StatusInternalServerError, Error{ Error: fmt.Sprintf("Failed to create portal: %v", err), }) return } - status = http.StatusCreated + statusCode = http.StatusCreated } - jsonResponse(w, status, PortalInfo{ + jsonResponse(w, statusCode, PortalInfo{ RoomID: portal.MXID, GroupInfo: info, - JustCreated: status == http.StatusCreated, + JustCreated: statusCode == http.StatusCreated, }) } } @@ -616,7 +599,7 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) { } else { err := user.Client.Logout() if err != nil { - user.log.Warnln("Error while logging out:", err) + hlog.FromRequest(r).Err(err).Msg("Unknown error while logging out") if !force { jsonResponse(w, http.StatusInternalServerError, Error{ Error: fmt.Sprintf("Unknown error while logging out: %v", err), @@ -632,7 +615,7 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) { user.bridge.Metrics.TrackConnectionState(user.JID, false) user.removeFromJIDMap(status.BridgeState{StateEvent: status.StateLoggedOut}) - user.DeleteSession() + user.DeleteSession(r.Context()) jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."}) } @@ -646,16 +629,17 @@ var upgrader = websocket.Upgrader{ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { userID := r.URL.Query().Get("user_id") user := prov.bridge.GetUserByMXID(id.UserID(userID)) + log := hlog.FromRequest(r) c, err := upgrader.Upgrade(w, r, nil) if err != nil { - prov.log.Errorln("Failed to upgrade connection to websocket:", err) + log.Err(err).Msg("Failed to upgrade connection to websocket") return } defer func() { err := c.Close() if err != nil { - user.log.Debugln("Error closing websocket:", err) + log.Debug().Err(err).Msg("Error closing websocket") } }() @@ -670,23 +654,26 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { }() ctx, cancel := context.WithCancel(context.Background()) c.SetCloseHandler(func(code int, text string) error { - user.log.Debugfln("Login websocket closed (%d), cancelling login", code) + log.Debug().Int("close_code", code).Msg("Login websocket closed, cancelling login") cancel() return nil }) if userTimezone := r.URL.Query().Get("tz"); userTimezone != "" { - user.log.Debug("Setting timezone to %s", userTimezone) + log.Debug().Str("timezone", userTimezone).Msg("Updating user timezone") user.Timezone = userTimezone - user.Update() + err = user.Update(r.Context()) + if err != nil { + log.Err(err).Msg("Failed to save user after updating timezone") + } } else { - user.log.Debug("No timezone provided in request") + log.Debug().Msg("No timezone provided in request") } qrChan, err := user.Login(ctx) expiryTime := time.Now().Add(160 * time.Second) if err != nil { - user.log.Errorln("Failed to log in from provisioning API:", err) + log.Err(err).Msg("Failed to log in via provisioning API") if errors.Is(err, ErrAlreadyLoggedIn) { go user.Connect() _ = c.WriteJSON(Error{ @@ -704,7 +691,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { if phoneNum != "" { pairingCode, err := user.Client.PairPhone(phoneNum, true, whatsmeow.PairClientChrome, "Chrome (Linux)") if err != nil { - user.zlog.Err(err).Msg("Failed to start phone code login") + log.Err(err).Msg("Failed to start phone code login") _ = c.WriteJSON(Error{ Error: "Failed to request pairing code", ErrCode: "code error", @@ -712,6 +699,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { go user.DeleteConnection() return } else { + log.Debug().Msg("Started phone number login") _ = c.WriteJSON(map[string]any{ "pairing_code": pairingCode, "timeout": int(time.Until(expiryTime).Seconds()), @@ -719,7 +707,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { } } - user.log.Debugln("Started login via provisioning API") + log.Debug().Msg("Started login via provisioning API") Analytics.Track(user.MXID, "$login_start") for { @@ -728,7 +716,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { switch evt.Event { case whatsmeow.QRChannelSuccess.Event: jid := user.Client.Store.ID - user.log.Debugln("Successful login as", jid, "via provisioning API") + log.Debug().Stringer("jid", jid).Msg("Successful login via provisioning API") Analytics.Track(user.MXID, "$login_success") _ = c.WriteJSON(map[string]interface{}{ "success": true, @@ -737,7 +725,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { "platform": user.Client.Store.Platform, }) case whatsmeow.QRChannelTimeout.Event: - user.log.Debugln("Login via provisioning API timed out") + log.Debug().Msg("Login via provisioning API timed out") errCode := "login timed out" Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) _ = c.WriteJSON(Error{ @@ -745,7 +733,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { ErrCode: errCode, }) case whatsmeow.QRChannelErrUnexpectedEvent.Event: - user.log.Debugln("Login via provisioning API failed due to unexpected event") + log.Debug().Msg("Login via provisioning API failed due to unexpected event") errCode := "unexpected event" Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) _ = c.WriteJSON(Error{ @@ -753,7 +741,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { ErrCode: errCode, }) case whatsmeow.QRChannelClientOutdated.Event: - user.log.Debugln("Login via provisioning API failed due to outdated client") + log.Debug().Msg("Login via provisioning API failed due to outdated client") errCode := "bridge outdated" Analytics.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) _ = c.WriteJSON(Error{ diff --git a/puppet.go b/puppet.go index 0c3f9ac..46ba62a 100644 --- a/puppet.go +++ b/puppet.go @@ -17,15 +17,15 @@ package main import ( + "context" "fmt" "regexp" "sync" "time" + "github.com/rs/zerolog" "go.mau.fi/whatsmeow/types" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridge" @@ -59,6 +59,7 @@ func (br *WABridge) GetPuppetByMXID(mxid id.UserID) *Puppet { } func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet { + ctx := context.TODO() jid = jid.ToNonAD() if jid.Server == types.LegacyUserServer { jid.Server = types.DefaultUserServer @@ -69,11 +70,19 @@ func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet { defer br.puppetsLock.Unlock() puppet, ok := br.puppets[jid] if !ok { - dbPuppet := br.DB.Puppet.Get(jid) + dbPuppet, err := br.DB.Puppet.Get(ctx, jid) + if err != nil { + br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to get puppet from database") + return nil + } if dbPuppet == nil { dbPuppet = br.DB.Puppet.New() dbPuppet.JID = jid - dbPuppet.Insert() + err = dbPuppet.Insert(ctx) + if err != nil { + br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to insert new puppet to database") + return nil + } } puppet = br.NewPuppet(dbPuppet) br.puppets[puppet.JID] = puppet @@ -89,7 +98,10 @@ func (br *WABridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet { defer br.puppetsLock.Unlock() puppet, ok := br.puppetsByCustomMXID[mxid] if !ok { - dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid) + dbPuppet, err := br.DB.Puppet.GetByCustomMXID(context.TODO(), mxid) + if err != nil { + br.ZLog.Err(err).Stringer("mxid", mxid).Msg("Failed to get puppet by custom mxid from database") + } if dbPuppet == nil { return nil } @@ -137,14 +149,18 @@ func (puppet *Puppet) GetMXID() id.UserID { } func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet { - return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID()) + return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID(context.TODO())) } func (br *WABridge) GetAllPuppets() []*Puppet { - return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll()) + return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll(context.TODO())) } -func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet { +func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet, err error) []*Puppet { + if err != nil { + br.ZLog.Err(err).Msg("Error getting puppets from database") + return nil + } br.puppetsLock.Lock() defer br.puppetsLock.Unlock() output := make([]*Puppet, len(dbPuppets)) @@ -175,7 +191,7 @@ func (br *WABridge) NewPuppet(dbPuppet *database.Puppet) *Puppet { return &Puppet{ Puppet: dbPuppet, bridge: br, - log: br.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)), + zlog: br.ZLog.With().Stringer("puppet_jid", dbPuppet.JID).Logger(), MXID: br.FormatPuppetMXID(dbPuppet.JID), } @@ -185,7 +201,7 @@ type Puppet struct { *database.Puppet bridge *WABridge - log log.Logger + zlog zerolog.Logger typingIn id.RoomID typingAt time.Time @@ -223,47 +239,47 @@ func (puppet *Puppet) DefaultIntent() *appservice.IntentAPI { return puppet.bridge.AS.Intent(puppet.MXID) } -func (puppet *Puppet) UpdateAvatar(source *User, forcePortalSync bool) bool { - changed := source.updateAvatar(puppet.JID, false, &puppet.Avatar, &puppet.AvatarURL, &puppet.AvatarSet, puppet.log, puppet.DefaultIntent()) +func (puppet *Puppet) UpdateAvatar(ctx context.Context, source *User, forcePortalSync bool) bool { + changed := source.updateAvatar(ctx, puppet.JID, false, &puppet.Avatar, &puppet.AvatarURL, &puppet.AvatarSet, puppet.DefaultIntent()) if !changed || puppet.Avatar == "unauthorized" { if forcePortalSync { - go puppet.updatePortalAvatar() + go puppet.updatePortalAvatar(ctx) } return changed } - err := puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL) + err := puppet.DefaultIntent().SetAvatarURL(ctx, puppet.AvatarURL) if err != nil { - puppet.log.Warnln("Failed to set avatar:", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to set avatar from puppet") } else { puppet.AvatarSet = true } - go puppet.updatePortalAvatar() + go puppet.updatePortalAvatar(ctx) return true } -func (puppet *Puppet) UpdateName(contact types.ContactInfo, forcePortalSync bool) bool { +func (puppet *Puppet) UpdateName(ctx context.Context, contact types.ContactInfo, forcePortalSync bool) bool { newName, quality := puppet.bridge.Config.Bridge.FormatDisplayname(puppet.JID, contact) if (puppet.Displayname != newName || !puppet.NameSet) && quality >= puppet.NameQuality { oldName := puppet.Displayname puppet.Displayname = newName puppet.NameQuality = quality puppet.NameSet = false - err := puppet.DefaultIntent().SetDisplayName(newName) + err := puppet.DefaultIntent().SetDisplayName(ctx, newName) if err == nil { - puppet.log.Debugln("Updated name", oldName, "->", newName) + puppet.zlog.Debug().Str("old_name", oldName).Str("new_name", newName).Msg("Updated name") puppet.NameSet = true - go puppet.updatePortalName() + go puppet.updatePortalName(ctx) } else { - puppet.log.Warnln("Failed to set display name:", err) + puppet.zlog.Err(err).Msg("Failed to set displayname") } return true } else if forcePortalSync { - go puppet.updatePortalName() + go puppet.updatePortalName(ctx) } return false } -func (puppet *Puppet) UpdateContactInfo() bool { +func (puppet *Puppet) UpdateContactInfo(ctx context.Context) bool { if !puppet.bridge.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { return false } @@ -281,9 +297,9 @@ func (puppet *Puppet) UpdateContactInfo() bool { "com.beeper.bridge.service": "whatsapp", "com.beeper.bridge.network": "whatsapp", } - err := puppet.DefaultIntent().BeeperUpdateProfile(contactInfo) + err := puppet.DefaultIntent().BeeperUpdateProfile(ctx, contactInfo) if err != nil { - puppet.log.Warnln("Failed to store custom contact info in profile:", err) + puppet.zlog.Err(err).Msg("Failed to store custom contact info in profile") return false } else { puppet.ContactInfoSet = true @@ -300,7 +316,7 @@ func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) { } } -func (puppet *Puppet) updatePortalAvatar() { +func (puppet *Puppet) updatePortalAvatar(ctx context.Context) { puppet.updatePortalMeta(func(portal *Portal) { if portal.Avatar == puppet.Avatar && portal.AvatarURL == puppet.AvatarURL && (portal.AvatarSet || !portal.shouldSetDMRoomMetadata()) { return @@ -308,28 +324,31 @@ func (puppet *Puppet) updatePortalAvatar() { portal.AvatarURL = puppet.AvatarURL portal.Avatar = puppet.Avatar portal.AvatarSet = false - defer portal.Update(nil) if len(portal.MXID) > 0 && !portal.shouldSetDMRoomMetadata() { - portal.UpdateBridgeInfo() + portal.UpdateBridgeInfo(ctx) } else if len(portal.MXID) > 0 { - _, err := portal.MainIntent().SetRoomAvatar(portal.MXID, puppet.AvatarURL) + _, err := portal.MainIntent().SetRoomAvatar(ctx, portal.MXID, puppet.AvatarURL) if err != nil { - portal.log.Warnln("Failed to set avatar:", err) + portal.zlog.Err(err).Msg("Failed to set avatar from puppet") } else { portal.AvatarSet = true - portal.UpdateBridgeInfo() + portal.UpdateBridgeInfo(ctx) } } + err := portal.Update(ctx) + if err != nil { + portal.zlog.Err(err).Msg("Failed to save portal after updating avatar from puppet") + } }) } -func (puppet *Puppet) updatePortalName() { +func (puppet *Puppet) updatePortalName(ctx context.Context) { puppet.updatePortalMeta(func(portal *Portal) { - portal.UpdateName(puppet.Displayname, types.EmptyJID, true) + portal.UpdateName(ctx, puppet.Displayname, types.EmptyJID, true) }) } -func (puppet *Puppet) SyncContact(source *User, onlyIfNoName, shouldHavePushName bool, reason string) { +func (puppet *Puppet) SyncContact(ctx context.Context, source *User, onlyIfNoName, shouldHavePushName bool, reason string) { if puppet == nil { return } @@ -337,39 +356,67 @@ func (puppet *Puppet) SyncContact(source *User, onlyIfNoName, shouldHavePushName source.EnqueuePuppetResync(puppet) return } + log := zerolog.Ctx(ctx).With(). + Str("method", "Puppet.SyncContact"). + Stringer("puppet_jid", puppet.JID). + Stringer("source_user_jid", source.JID). + Stringer("source_user_mxid", source.MXID). + Logger() + ctx = log.WithContext(ctx) contact, err := source.Client.Store.Contacts.GetContact(puppet.JID) if err != nil { - puppet.log.Warnfln("Failed to get contact info through %s in SyncContact: %v (sync reason: %s)", source.MXID, reason) + log.Err(err). + Stringer("source_mxid", source.MXID). + Str("sync_reason", reason). + Msg("Failed to get contact info through user in SyncContact") } else if !contact.Found { - puppet.log.Warnfln("No contact info found through %s in SyncContact (sync reason: %s)", source.MXID, reason) + log.Warn(). + Stringer("source_mxid", source.MXID). + Str("sync_reason", reason). + Msg("No contact info found through user in SyncContact") } - puppet.Sync(source, &contact, false, false) + puppet.syncInternal(ctx, source, &contact, false, false) } -func (puppet *Puppet) Sync(source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) { +func (puppet *Puppet) Sync(ctx context.Context, source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) { + log := zerolog.Ctx(ctx).With(). + Str("method", "Puppet.Sync"). + Stringer("puppet_jid", puppet.JID). + Stringer("source_user_jid", source.JID). + Stringer("source_user_mxid", source.MXID). + Logger() + ctx = log.WithContext(ctx) + puppet.syncInternal(ctx, source, contact, forceAvatarSync, forcePortalSync) +} + +func (puppet *Puppet) syncInternal(ctx context.Context, source *User, contact *types.ContactInfo, forceAvatarSync, forcePortalSync bool) { + log := zerolog.Ctx(ctx) puppet.syncLock.Lock() defer puppet.syncLock.Unlock() - err := puppet.DefaultIntent().EnsureRegistered() + err := puppet.DefaultIntent().EnsureRegistered(ctx) if err != nil { - puppet.log.Errorln("Failed to ensure registered:", err) + log.Err(err).Msg("Failed to ensure registered") } - puppet.log.Debugfln("Syncing info through %s", source.JID) + log.Debug().Stringer("source_jid", source.JID).Msg("Syncing info through user") update := false if contact != nil { if puppet.JID.User == source.JID.User { contact.PushName = source.Client.Store.PushName } - update = puppet.UpdateName(*contact, forcePortalSync) || update + update = puppet.UpdateName(ctx, *contact, forcePortalSync) || update } if len(puppet.Avatar) == 0 || forceAvatarSync || puppet.bridge.Config.Bridge.UserAvatarSync { - update = puppet.UpdateAvatar(source, forcePortalSync) || update + update = puppet.UpdateAvatar(ctx, source, forcePortalSync) || update } - update = puppet.UpdateContactInfo() || update + update = puppet.UpdateContactInfo(ctx) || update if update || puppet.LastSync.Add(24*time.Hour).Before(time.Now()) { puppet.LastSync = time.Now() - puppet.Update() + err = puppet.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save puppet after sync") + } } } diff --git a/urlpreview.go b/urlpreview.go index 9715152..192464c 100644 --- a/urlpreview.go +++ b/urlpreview.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -19,7 +19,6 @@ package main import ( "bytes" "context" - "encoding/json" "image" "net/http" "net/url" @@ -27,33 +26,26 @@ import ( "strings" "time" - "github.com/tidwall/gjson" + "github.com/rs/zerolog" "golang.org/x/net/idna" "google.golang.org/protobuf/proto" "go.mau.fi/whatsmeow" waProto "go.mau.fi/whatsmeow/binary/proto" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/event" ) -type BeeperLinkPreview struct { - mautrix.RespPreviewURL - MatchedURL string `json:"matched_url"` - ImageEncryption *event.EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` -} - -func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, source *User, msg *waProto.ExtendedTextMessage) []*BeeperLinkPreview { +func (portal *Portal) convertURLPreviewToBeeper(ctx context.Context, intent *appservice.IntentAPI, source *User, msg *waProto.ExtendedTextMessage) []*event.BeeperLinkPreview { if msg.GetMatchedText() == "" { - return []*BeeperLinkPreview{} + return []*event.BeeperLinkPreview{} } - output := &BeeperLinkPreview{ + output := &event.BeeperLinkPreview{ MatchedURL: msg.GetMatchedText(), - RespPreviewURL: mautrix.RespPreviewURL{ + LinkPreview: event.LinkPreview{ CanonicalURL: msg.GetCanonicalUrl(), Title: msg.GetTitle(), Description: msg.GetDescription(), @@ -68,7 +60,7 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so var err error thumbnailData, err = source.Client.DownloadThumbnail(msg) if err != nil { - portal.log.Warnfln("Failed to download thumbnail for link preview: %v", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to download thumbnail for link preview") } } if thumbnailData == nil && msg.JpegThumbnail != nil { @@ -93,9 +85,9 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so uploadMime = "application/octet-stream" output.ImageEncryption = &event.EncryptedFileInfo{EncryptedFile: *crypto} } - resp, err := intent.UploadBytes(uploadData, uploadMime) + resp, err := intent.UploadBytes(ctx, uploadData, uploadMime) if err != nil { - portal.log.Warnfln("Failed to reupload thumbnail for link preview: %v", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload thumbnail for link preview") } else { if output.ImageEncryption != nil { output.ImageEncryption.URL = resp.ContentURI.CUString() @@ -108,36 +100,37 @@ func (portal *Portal) convertURLPreviewToBeeper(intent *appservice.IntentAPI, so output.Type = "video.other" } - return []*BeeperLinkPreview{output} + return []*event.BeeperLinkPreview{output} } var URLRegex = regexp.MustCompile(`https?://[^\s/_*]+(?:/\S*)?`) -func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *User, evt *event.Event, dest *waProto.ExtendedTextMessage) bool { - var preview *BeeperLinkPreview +func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *User, content *event.MessageEventContent, dest *waProto.ExtendedTextMessage) bool { + log := zerolog.Ctx(ctx) + var preview *event.BeeperLinkPreview - rawPreview := gjson.GetBytes(evt.Content.VeryRaw, `com\.beeper\.linkpreviews`) - if rawPreview.Exists() && rawPreview.IsArray() { - var previews []BeeperLinkPreview - if err := json.Unmarshal([]byte(rawPreview.Raw), &previews); err != nil || len(previews) == 0 { + if content.BeeperLinkPreviews != nil { + // Note: this check explicitly happens after checking for nil: empty arrays are treated as no previews, + // but omitting the field means the bridge may look for URLs in the message text. + if len(content.BeeperLinkPreviews) == 0 { return false } // WhatsApp only supports a single preview. - preview = &previews[0] + preview = content.BeeperLinkPreviews[0] } else if portal.bridge.Config.Bridge.URLPreviews { - if matchedURL := URLRegex.FindString(evt.Content.AsMessage().Body); len(matchedURL) == 0 { + if matchedURL := URLRegex.FindString(content.Body); len(matchedURL) == 0 { return false } else if parsed, err := url.Parse(matchedURL); err != nil { return false } else if parsed.Host, err = idna.ToASCII(parsed.Host); err != nil { return false - } else if mxPreview, err := portal.MainIntent().GetURLPreview(parsed.String()); err != nil { - portal.log.Warnfln("Failed to fetch preview for %s: %v", matchedURL, err) + } else if mxPreview, err := portal.MainIntent().GetURLPreview(ctx, parsed.String()); err != nil { + log.Err(err).Str("url", matchedURL).Msg("Failed to fetch URL preview") return false } else { - preview = &BeeperLinkPreview{ - RespPreviewURL: *mxPreview, - MatchedURL: matchedURL, + preview = &event.BeeperLinkPreview{ + LinkPreview: *mxPreview, + MatchedURL: matchedURL, } } } @@ -163,22 +156,22 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U imageMXC = preview.ImageEncryption.URL.ParseOrIgnore() } if !imageMXC.IsEmpty() { - data, err := portal.MainIntent().DownloadBytesContext(ctx, imageMXC) + data, err := portal.MainIntent().DownloadBytes(ctx, imageMXC) if err != nil { - portal.log.Errorfln("Failed to download URL preview image %s in %s: %v", preview.ImageURL, evt.ID, err) + log.Err(err).Str("image_url", string(preview.ImageURL)).Msg("Failed to download URL preview image") return true } if preview.ImageEncryption != nil { err = preview.ImageEncryption.DecryptInPlace(data) if err != nil { - portal.log.Errorfln("Failed to decrypt URL preview image in %s: %v", evt.ID, err) + log.Err(err).Msg("Failed to decrypt URL preview image") return true } } dest.MediaKeyTimestamp = proto.Int64(time.Now().Unix()) uploadResp, err := sender.Client.Upload(ctx, data, whatsmeow.MediaLinkThumbnail) if err != nil { - portal.log.Errorfln("Failed to upload URL preview thumbnail in %s: %v", evt.ID, err) + log.Err(err).Msg("Failed to reupload URL preview thumbnail") return true } dest.ThumbnailSha256 = uploadResp.FileSHA256 @@ -188,7 +181,7 @@ func (portal *Portal) convertURLPreviewToWhatsApp(ctx context.Context, sender *U var width, height int dest.JpegThumbnail, width, height, err = createThumbnailAndGetSize(data, false) if err != nil { - portal.log.Warnfln("Failed to create JPEG thumbnail for URL preview in %s: %v", evt.ID, err) + log.Err(err).Msg("Failed to create JPEG thumbnail for URL preview") } if preview.ImageHeight > 0 && preview.ImageWidth > 0 { dest.ThumbnailWidth = proto.Uint32(uint32(preview.ImageWidth)) diff --git a/user.go b/user.go index a6f9f68..83890c5 100644 --- a/user.go +++ b/user.go @@ -1,5 +1,5 @@ // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan +// Copyright (C) 2024 Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -33,9 +33,14 @@ import ( "time" "github.com/rs/zerolog" - "maunium.net/go/maulogger/v2" - "maunium.net/go/maulogger/v2/maulogadapt" - + "go.mau.fi/util/exzerolog" + "go.mau.fi/whatsmeow" + "go.mau.fi/whatsmeow/appstate" + waProto "go.mau.fi/whatsmeow/binary/proto" + "go.mau.fi/whatsmeow/store" + "go.mau.fi/whatsmeow/types" + "go.mau.fi/whatsmeow/types/events" + waLog "go.mau.fi/whatsmeow/util/log" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridge" @@ -46,14 +51,6 @@ import ( "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" - "go.mau.fi/whatsmeow" - "go.mau.fi/whatsmeow/appstate" - waProto "go.mau.fi/whatsmeow/binary/proto" - "go.mau.fi/whatsmeow/store" - "go.mau.fi/whatsmeow/types" - "go.mau.fi/whatsmeow/types/events" - waLog "go.mau.fi/whatsmeow/util/log" - "maunium.net/go/mautrix-whatsapp/database" ) @@ -64,8 +61,6 @@ type User struct { bridge *WABridge zlog zerolog.Logger - // Deprecated - log maulogger.Logger Admin bool Whitelisted bool @@ -118,7 +113,13 @@ func (br *WABridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User { if onlyIfExists { userIDPtr = nil } - return br.loadDBUser(br.DB.User.GetByMXID(userID), userIDPtr) + ctx := context.TODO() + dbUser, err := br.DB.User.GetByMXID(ctx, userID) + if err != nil { + br.ZLog.Err(err).Stringer("mxid", userID).Msg("Failed to get user by MXID from database") + return nil + } + return br.loadDBUser(ctx, dbUser, userIDPtr) } return user } @@ -160,7 +161,13 @@ func (br *WABridge) GetUserByJID(jid types.JID) *User { defer br.usersLock.Unlock() user, ok := br.usersByUsername[jid.User] if !ok { - return br.loadDBUser(br.DB.User.GetByUsername(jid.User), nil) + ctx := context.TODO() + dbUser, err := br.DB.User.GetByUsername(ctx, jid.User) + if err != nil { + br.ZLog.Err(err).Stringer("jid", jid).Msg("Failed to get user by JID from database") + return nil + } + return br.loadDBUser(ctx, dbUser, nil) } return user } @@ -185,26 +192,35 @@ func (user *User) removeFromJIDMap(state status.BridgeState) { func (br *WABridge) GetAllUsers() []*User { br.usersLock.Lock() defer br.usersLock.Unlock() - dbUsers := br.DB.User.GetAll() + ctx := context.TODO() + dbUsers, err := br.DB.User.GetAll(ctx) + if err != nil { + br.ZLog.Error().Err(err).Msg("Failed to get all users from database") + return nil + } output := make([]*User, len(dbUsers)) for index, dbUser := range dbUsers { user, ok := br.usersByMXID[dbUser.MXID] if !ok { - user = br.loadDBUser(dbUser, nil) + user = br.loadDBUser(ctx, dbUser, nil) } output[index] = user } return output } -func (br *WABridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User { +func (br *WABridge) loadDBUser(ctx context.Context, dbUser *database.User, mxid *id.UserID) *User { if dbUser == nil { if mxid == nil { return nil } dbUser = br.DB.User.New() dbUser.MXID = *mxid - dbUser.Insert() + err := dbUser.Insert(ctx) + if err != nil { + br.ZLog.Error().Err(err).Msg("Failed to insert new user into database") + return nil + } } user := br.NewUser(dbUser) br.usersByMXID[user.MXID] = user @@ -212,13 +228,16 @@ func (br *WABridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User { var err error user.Session, err = br.WAContainer.GetDevice(user.JID) if err != nil { - user.log.Errorfln("Failed to load user's whatsapp session: %v", err) + user.zlog.Err(err).Msg("Failed to load user's whatsapp session") } else if user.Session == nil { - user.log.Warnfln("Didn't find session data for %s, treating user as logged out", user.JID) + user.zlog.Warn().Stringer("jid", user.JID).Msg("Didn't find session data for user's JID, treating user as logged out") user.JID = types.EmptyJID - user.Update() + err = user.Update(ctx) + if err != nil { + user.zlog.Err(err).Msg("Failed to save user after clearing JID") + } } else { - user.Session.Log = &waLogger{user.log.Sub("Session")} + user.Session.Log = waLog.Zerolog(user.zlog.With().Str("component", "whatsmeow").Str("db_section", "whatsmeow").Logger()) br.usersByUsername[user.JID.User] = user } } @@ -239,7 +258,6 @@ func (br *WABridge) NewUser(dbUser *database.User) *User { resyncQueue: make(map[types.JID]resyncQueueItem), } - user.log = maulogadapt.ZeroAsMau(&user.zlog) user.PermissionLevel = user.bridge.Config.Bridge.Permissions.Get(user.MXID) user.RelayWhitelisted = user.PermissionLevel >= bridgeconfig.PermissionLevelRelay @@ -271,7 +289,10 @@ func (user *User) EnqueuePuppetResync(puppet *Puppet) { user.resyncQueueLock.Lock() if _, exists := user.resyncQueue[puppet.JID]; !exists { user.resyncQueue[puppet.JID] = resyncQueueItem{puppet: puppet} - user.log.Debugfln("Enqueued resync for %s (next sync in %s)", puppet.JID, user.nextResync.Sub(time.Now())) + user.zlog.Debug(). + Stringer("jid", puppet.JID). + Str("next_resync", time.Until(user.nextResync).String()). + Msg("Enqueued resync for puppet") } user.resyncQueueLock.Unlock() } @@ -283,7 +304,10 @@ func (user *User) EnqueuePortalResync(portal *Portal) { user.resyncQueueLock.Lock() if _, exists := user.resyncQueue[portal.Key.JID]; !exists { user.resyncQueue[portal.Key.JID] = resyncQueueItem{portal: portal} - user.log.Debugfln("Enqueued resync for %s (next sync in %s)", portal.Key.JID, user.nextResync.Sub(time.Now())) + user.zlog.Debug(). + Stringer("jid", portal.Key.JID). + Str("next_resync", time.Until(user.nextResync).String()). + Msg("Enqueued resync for portal") } user.resyncQueueLock.Unlock() } @@ -297,6 +321,8 @@ func (user *User) doPuppetResync() { user.resyncQueueLock.Unlock() return } + log := user.zlog.With().Str("action", "puppet resync").Logger() + ctx := log.WithContext(context.TODO()) queue := user.resyncQueue user.resyncQueue = make(map[types.JID]resyncQueueItem) user.resyncQueueLock.Unlock() @@ -311,7 +337,10 @@ func (user *User) doPuppetResync() { lastSync = item.portal.LastSync } if lastSync.Add(resyncMinInterval).After(time.Now()) { - user.log.Debugfln("Not resyncing %s, last sync was %s ago", jid, time.Now().Sub(lastSync)) + log.Debug(). + Stringer("jid", jid). + Str("last_sync", time.Since(lastSync).String()). + Msg("Not resyncing, last sync was too recent") continue } if item.puppet != nil { @@ -324,39 +353,39 @@ func (user *User) doPuppetResync() { for _, portal := range portals { groupInfo, err := user.Client.GetGroupInfo(portal.Key.JID) if err != nil { - user.log.Warnfln("Failed to get group info for %s to do background sync: %v", portal.Key.JID, err) + log.Warn().Err(err).Stringer("jid", portal.Key.JID).Msg("Failed to get group info for background sync") } else { - user.log.Debugfln("Doing background sync for %s", portal.Key.JID) - portal.UpdateMatrixRoom(user, groupInfo, nil) + log.Debug().Stringer("jid", portal.Key.JID).Msg("Doing background sync for group") + portal.UpdateMatrixRoom(ctx, user, groupInfo, nil) } } if len(puppetJIDs) == 0 { return } - user.log.Debugfln("Doing background sync for users: %+v", puppetJIDs) + log.Debug().Array("jids", exzerolog.ArrayOfStringers(puppetJIDs)).Msg("Doing background sync for users") infos, err := user.Client.GetUserInfo(puppetJIDs) if err != nil { - user.log.Errorfln("Error getting user info for background sync: %v", err) + log.Err(err).Msg("Failed to get user info for background sync") return } for _, puppet := range puppets { info, ok := infos[puppet.JID] if !ok { - user.log.Warnfln("Didn't get info for %s in background sync", puppet.JID) + log.Warn().Stringer("jid", puppet.JID).Msg("Didn't get info for puppet in background sync") continue } var contactPtr *types.ContactInfo contact, err := user.Session.Contacts.GetContact(puppet.JID) if err != nil { - user.log.Warnfln("Failed to get contact info for %s in background sync: %v", puppet.JID, err) + log.Err(err).Stringer("jid", puppet.JID).Msg("Failed to get contact info for puppet in background sync") } else if contact.Found { contactPtr = &contact } - puppet.Sync(user, contactPtr, info.PictureID != "" && info.PictureID != puppet.Avatar, true) + puppet.Sync(ctx, user, contactPtr, info.PictureID != "" && info.PictureID != puppet.Avatar, true) } } -func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, isDirect bool) (ok bool) { +func (user *User) ensureInvited(ctx context.Context, intent *appservice.IntentAPI, roomID id.RoomID, isDirect bool) (ok bool) { extraContent := make(map[string]interface{}) if isDirect { extraContent["is_direct"] = true @@ -365,22 +394,25 @@ func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, if customPuppet != nil && customPuppet.CustomIntent() != nil { extraContent["fi.mau.will_auto_accept"] = true } - _, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{UserID: user.MXID}, extraContent) + _, err := intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{UserID: user.MXID}, extraContent) var httpErr mautrix.HTTPError if err != nil && errors.As(err, &httpErr) && httpErr.RespError != nil && strings.Contains(httpErr.RespError.Err, "is already in the room") { - user.bridge.StateStore.SetMembership(roomID, user.MXID, event.MembershipJoin) + err = user.bridge.StateStore.SetMembership(ctx, roomID, user.MXID, event.MembershipJoin) + if err != nil { + user.zlog.Err(err).Stringer("room_id", roomID).Msg("Failed to update membership to join in state store after invite failed") + } ok = true return } else if err != nil { - user.log.Warnfln("Failed to invite user to %s: %v", roomID, err) + user.zlog.Err(err).Stringer("room_id", roomID).Msg("Failed to invite user to room") } else { ok = true } if customPuppet != nil && customPuppet.CustomIntent() != nil { - err = customPuppet.CustomIntent().EnsureJoined(roomID, appservice.EnsureJoinedParams{IgnoreCache: true}) + err = customPuppet.CustomIntent().EnsureJoined(ctx, roomID, appservice.EnsureJoinedParams{IgnoreCache: true}) if err != nil { - user.log.Warnfln("Failed to auto-join %s: %v", roomID, err) + user.zlog.Err(err).Stringer("room_id", roomID).Msg("Failed to auto-join room") ok = false } else { ok = true @@ -389,7 +421,7 @@ func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, return } -func (user *User) GetSpaceRoom() id.RoomID { +func (user *User) GetSpaceRoom(ctx context.Context) id.RoomID { if !user.bridge.Config.Bridge.PersonalFilteringSpaces { return "" } @@ -401,7 +433,7 @@ func (user *User) GetSpaceRoom() id.RoomID { return user.SpaceRoom } - resp, err := user.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{ + resp, err := user.bridge.Bot.CreateRoom(ctx, &mautrix.ReqCreateRoom{ Visibility: "private", Name: "WhatsApp", Topic: "Your WhatsApp bridged chats", @@ -425,21 +457,24 @@ func (user *User) GetSpaceRoom() id.RoomID { }) if err != nil { - user.log.Errorln("Failed to auto-create space room:", err) + user.zlog.Err(err).Msg("Failed to auto-create space room") } else { user.SpaceRoom = resp.RoomID - user.Update() - user.ensureInvited(user.bridge.Bot, user.SpaceRoom, false) + err = user.Update(ctx) + if err != nil { + user.zlog.Err(err).Msg("Failed to save user after creating space room") + } + user.ensureInvited(ctx, user.bridge.Bot, user.SpaceRoom, false) } - } else if !user.spaceMembershipChecked && !user.bridge.StateStore.IsInRoom(user.SpaceRoom, user.MXID) { - user.ensureInvited(user.bridge.Bot, user.SpaceRoom, false) + } else if !user.spaceMembershipChecked && !user.bridge.StateStore.IsInRoom(ctx, user.SpaceRoom, user.MXID) { + user.ensureInvited(ctx, user.bridge.Bot, user.SpaceRoom, false) } user.spaceMembershipChecked = true return user.SpaceRoom } -func (user *User) GetManagementRoom() id.RoomID { +func (user *User) GetManagementRoom(ctx context.Context) id.RoomID { if len(user.ManagementRoom) == 0 { user.mgmtCreateLock.Lock() defer user.mgmtCreateLock.Unlock() @@ -450,13 +485,13 @@ func (user *User) GetManagementRoom() id.RoomID { if !user.bridge.Config.Bridge.FederateRooms { creationContent["m.federate"] = false } - resp, err := user.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{ + resp, err := user.bridge.Bot.CreateRoom(ctx, &mautrix.ReqCreateRoom{ Topic: "WhatsApp bridge notices", IsDirect: true, CreationContent: creationContent, }) if err != nil { - user.log.Errorln("Failed to auto-create management room:", err) + user.zlog.Err(err).Msg("Failed to auto-create management room") } else { user.SetManagementRoom(resp.RoomID) } @@ -465,25 +500,27 @@ func (user *User) GetManagementRoom() id.RoomID { } func (user *User) SetManagementRoom(roomID id.RoomID) { + ctx := context.TODO() + existingUser, ok := user.bridge.managementRooms[roomID] if ok { existingUser.ManagementRoom = "" - existingUser.Update() + err := existingUser.Update(ctx) + if err != nil { + user.zlog.Err(err). + Stringer("other_user_mxid", existingUser.MXID). + Msg("Failed to save previous user after removing from old management room") + } } user.ManagementRoom = roomID user.bridge.managementRooms[user.ManagementRoom] = user - user.Update() + err := user.Update(ctx) + if err != nil { + user.zlog.Err(err).Msg("Failed to save user after setting management room") + } } -type waLogger struct{ l maulogger.Logger } - -func (w *waLogger) Debugf(msg string, args ...interface{}) { w.l.Debugfln(msg, args...) } -func (w *waLogger) Infof(msg string, args ...interface{}) { w.l.Infofln(msg, args...) } -func (w *waLogger) Warnf(msg string, args ...interface{}) { w.l.Warnfln(msg, args...) } -func (w *waLogger) Errorf(msg string, args ...interface{}) { w.l.Errorfln(msg, args...) } -func (w *waLogger) Sub(module string) waLog.Logger { return &waLogger{l: w.l.Sub(module)} } - var ErrAlreadyLoggedIn = errors.New("already logged in") func (user *User) obfuscateJID(jid types.JID) string { @@ -493,7 +530,7 @@ func (user *User) obfuscateJID(jid types.JID) string { } func (user *User) createClient(sess *store.Device) { - user.Client = whatsmeow.NewClient(sess, &waLogger{user.log.Sub("Client")}) + user.Client = whatsmeow.NewClient(sess, waLog.Zerolog(user.zlog.With().Str("component", "whatsmeow").Logger())) user.Client.AddEventHandler(user.HandleEvent) user.Client.SetForceActiveDeliveryReceipts(user.bridge.Config.Bridge.ForceActiveDeliveryReceipts) user.Client.AutomaticMessageRerequestFromPhone = true @@ -525,7 +562,7 @@ func (user *User) Login(ctx context.Context) (<-chan whatsmeow.QRChannelItem, er user.unlockedDeleteConnection() } newSession := user.bridge.WAContainer.NewDevice() - newSession.Log = &waLogger{user.log.Sub("Session")} + newSession.Log = waLog.Zerolog(user.zlog.With().Str("component", "whatsmeow session").Logger()) user.createClient(newSession) qrChan, err := user.Client.GetQRChannel(ctx) if err != nil { @@ -546,12 +583,12 @@ func (user *User) Connect() bool { } else if user.Session == nil { return false } - user.log.Debugln("Connecting to WhatsApp") + user.zlog.Debug().Msg("Connecting to WhatsApp") user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnecting, Error: WAConnecting}) user.createClient(user.Session) err := user.Client.Connect() if err != nil { - user.log.Warnln("Error connecting to WhatsApp:", err) + user.zlog.Err(err).Msg("Error connecting to WhatsApp") user.BridgeState.Send(status.BridgeState{ StateEvent: status.StateUnknownError, Error: WAConnectionFailed, @@ -584,24 +621,40 @@ func (user *User) HasSession() bool { return user.Session != nil } -func (user *User) DeleteSession() { +func (user *User) DeleteSession(ctx context.Context) { + log := zerolog.Ctx(ctx) if user.Session != nil { err := user.Session.Delete() if err != nil { - user.log.Warnln("Failed to delete session:", err) + log.Err(err).Msg("Failed to delete session") } user.Session = nil } if !user.JID.IsEmpty() { user.JID = types.EmptyJID - user.Update() + err := user.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save user after clearing JID") + } } // Delete all of the backfill and history sync data. - user.bridge.DB.Backfill.DeleteAll(user.MXID) - user.bridge.DB.HistorySync.DeleteAllConversations(user.MXID) - user.bridge.DB.HistorySync.DeleteAllMessages(user.MXID) - user.bridge.DB.MediaBackfillRequest.DeleteAllMediaBackfillRequests(user.MXID) + err := user.bridge.DB.BackfillQueue.DeleteAll(ctx, user.MXID) + if err != nil { + log.Err(err).Msg("Failed to delete backfill queue data") + } + err = user.bridge.DB.HistorySync.DeleteAllConversations(ctx, user.MXID) + if err != nil { + log.Err(err).Msg("Failed to delete historical conversation list") + } + err = user.bridge.DB.HistorySync.DeleteAllMessages(ctx, user.MXID) + if err != nil { + log.Err(err).Msg("Failed to delete historical messages") + } + err = user.bridge.DB.MediaBackfillRequest.DeleteAllMediaBackfillRequests(ctx, user.MXID) + if err != nil { + log.Err(err).Msg("Failed to delete media backfill requests") + } } func (user *User) IsConnected() bool { @@ -612,15 +665,15 @@ func (user *User) IsLoggedIn() bool { return user.IsConnected() && user.Client.IsLoggedIn() } -func (user *User) sendMarkdownBridgeAlert(formatString string, args ...interface{}) { +func (user *User) sendMarkdownBridgeAlert(ctx context.Context, formatString string, args ...interface{}) { if user.bridge.Config.Bridge.DisableBridgeAlerts { return } notice := fmt.Sprintf(formatString, args...) content := format.RenderMarkdown(notice, true, false) - _, err := user.bridge.Bot.SendMessageEvent(user.GetManagementRoom(), event.EventMessage, content) + _, err := user.bridge.Bot.SendMessageEvent(ctx, user.GetManagementRoom(ctx), event.EventMessage, content) if err != nil { - user.log.Warnf("Failed to send bridge alert \"%s\": %v", notice, err) + user.zlog.Warn().Err(err).Str("notice", notice).Msg("Failed to send bridge alert") } } @@ -653,19 +706,19 @@ const PhoneDisconnectWarningTime = 12 * 24 * time.Hour // 12 days const PhoneDisconnectPingTime = 10 * 24 * time.Hour const PhoneMinPingInterval = 24 * time.Hour -func (user *User) sendHackyPhonePing() { +func (user *User) sendHackyPhonePing(ctx context.Context) { user.PhoneLastPinged = time.Now() msgID := user.Client.GenerateMessageID() keyIDs := make([]*waProto.AppStateSyncKeyId, 0, 1) - lastKeyID, err := user.GetLastAppStateKeyID() + lastKeyID, err := user.GetLastAppStateKeyID(ctx) if lastKeyID != nil { keyIDs = append(keyIDs, &waProto.AppStateSyncKeyId{ KeyId: lastKeyID, }) } else { - user.log.Warnfln("Failed to get last app state key ID to send hacky phone ping: %v - sending empty request", err) + user.zlog.Warn().Err(err).Msg("Failed to get last app state key ID to send hacky phone ping - sending empty request") } - resp, err := user.Client.SendMessage(context.Background(), user.JID.ToNonAD(), &waProto.Message{ + resp, err := user.Client.SendMessage(ctx, user.JID.ToNonAD(), &waProto.Message{ ProtocolMessage: &waProto.ProtocolMessage{ Type: waProto.ProtocolMessage_APP_STATE_SYNC_KEY_REQUEST.Enum(), AppStateSyncKeyRequest: &waProto.AppStateSyncKeyRequest{ @@ -674,18 +727,24 @@ func (user *User) sendHackyPhonePing() { }, }, whatsmeow.SendRequestExtra{Peer: true, ID: msgID}) if err != nil { - user.log.Warnfln("Failed to send hacky phone ping: %v", err) + user.zlog.Err(err).Msg("Failed to send hacky phone ping") } else { - user.log.Debugfln("Sent hacky phone ping %s/%s because phone has been offline for >10 days", msgID, resp.Timestamp.Unix()) + user.zlog.Debug(). + Str("message_id", msgID). + Int64("message_ts", resp.Timestamp.Unix()). + Msg("Sent hacky phone ping because phone has been offline for >10 days") user.PhoneLastPinged = resp.Timestamp - user.Update() + err = user.Update(ctx) + if err != nil { + user.zlog.Err(err).Msg("Failed to save user after sending hacky phone ping") + } } } func (user *User) PhoneRecentlySeen(doPing bool) bool { if doPing && !user.PhoneLastSeen.IsZero() && user.PhoneLastSeen.Add(PhoneDisconnectPingTime).Before(time.Now()) && user.PhoneLastPinged.Add(PhoneMinPingInterval).Before(time.Now()) { // Over 10 days since the phone was seen and over a day since the last somewhat hacky ping, send a new ping. - go user.sendHackyPhonePing() + go user.sendHackyPhonePing(context.TODO()) } return user.PhoneLastSeen.IsZero() || user.PhoneLastSeen.Add(PhoneDisconnectWarningTime).After(time.Now()) } @@ -699,14 +758,22 @@ func (user *User) phoneSeen(ts time.Time) { return } else if !user.PhoneRecentlySeen(false) { if user.BridgeState.GetPrev().Error == WAPhoneOffline && user.IsConnected() { - user.log.Debugfln("Saw phone after current bridge state said it has been offline, switching state back to connected") + user.zlog.Debug().Msg("Saw phone after current bridge state said it has been offline, switching state back to connected") user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected}) } else { - user.log.Debugfln("Saw phone after current bridge state said it has been offline, not sending new bridge state (prev: %s, connected: %t)", user.BridgeState.GetPrev().Error, user.IsConnected()) + user.zlog.Debug(). + Bool("is_connected", user.IsConnected()). + Str("prev_error", string(user.BridgeState.GetPrev().Error)). + Msg("Saw phone after current bridge state said it has been offline, not sending new bridge state") } } user.PhoneLastSeen = ts - go user.Update() + go func() { + err := user.Update(context.TODO()) + if err != nil { + user.zlog.Err(err).Msg("Failed to save user after updating phone last seen") + } + }() } func formatDisconnectTime(dur time.Duration) string { @@ -721,20 +788,25 @@ func formatDisconnectTime(dur time.Duration) string { } } -func (user *User) sendPhoneOfflineWarning() { +func (user *User) sendPhoneOfflineWarning(ctx context.Context) { if user.lastPhoneOfflineWarning.Add(12 * time.Hour).After(time.Now()) { // Don't spam the warning too much return } user.lastPhoneOfflineWarning = time.Now() timeSinceSeen := time.Now().Sub(user.PhoneLastSeen) - user.sendMarkdownBridgeAlert("Your phone hasn't been seen in %s. The server will force the bridge to log out if the phone is not active at least every 2 weeks.", formatDisconnectTime(timeSinceSeen)) + user.sendMarkdownBridgeAlert(ctx, "Your phone hasn't been seen in %s. The server will force the bridge to log out if the phone is not active at least every 2 weeks.", formatDisconnectTime(timeSinceSeen)) } func (user *User) HandleEvent(event interface{}) { + ctx := user.zlog.With(). + Str("action", "handle whatsapp event"). + Type("wa_event_type", event). + Logger(). + WithContext(context.TODO()) switch v := event.(type) { case *events.LoggedOut: - go user.handleLoggedOut(v.OnConnect, v.Reason) + go user.handleLoggedOut(ctx, v.OnConnect, v.Reason) case *events.Connected: user.bridge.Metrics.TrackConnectionState(user.JID, true) user.bridge.Metrics.TrackLoginState(user.JID, true) @@ -742,7 +814,7 @@ func (user *User) HandleEvent(event interface{}) { go func() { err := user.Client.SendPresence(user.lastPresence) if err != nil { - user.log.Warnln("Failed to send initial presence:", err) + user.zlog.Warn().Err(err).Msg("Failed to send initial presence after connecting") } }() } @@ -753,18 +825,25 @@ func (user *User) HandleEvent(event interface{}) { user.historySyncLoopsStarted = true } case *events.OfflineSyncPreview: - user.log.Infofln("Server says it's going to send %d messages and %d receipts that were missed during downtime", v.Messages, v.Receipts) + user.zlog.Info(). + Int("message_count", v.Messages). + Int("receipt_count", v.Receipts). + Int("notification_count", v.Notifications). + Int("app_data_change_count", v.AppDataChanges). + Msg("Server sent number of events that were missed during downtime") user.BridgeState.Send(status.BridgeState{ StateEvent: status.StateBackfilling, Message: fmt.Sprintf("backfilling %d messages and %d receipts", v.Messages, v.Receipts), }) case *events.OfflineSyncCompleted: if !user.PhoneRecentlySeen(true) { - user.log.Infofln("Offline sync completed, but phone last seen date is still %s - sending phone offline bridge status", user.PhoneLastSeen) + user.zlog.Info(). + Time("phone_last_seen", user.PhoneLastSeen). + Msg("Offline sync completed, but phone last seen date is still old - sending phone offline bridge status") user.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Error: WAPhoneOffline}) } else { if user.BridgeState.GetPrev().StateEvent == status.StateBackfilling { - user.log.Infoln("Offline sync completed") + user.zlog.Info().Msg("Offline sync completed") } user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected}) } @@ -772,13 +851,13 @@ func (user *User) HandleEvent(event interface{}) { if len(user.Client.Store.PushName) > 0 && v.Name == appstate.WAPatchCriticalBlock { err := user.Client.SendPresence(user.lastPresence) if err != nil { - user.log.Warnln("Failed to send presence after app state sync:", err) + user.zlog.Warn().Err(err).Msg("Failed to send presence after app state sync") } } else if v.Name == appstate.WAPatchCriticalUnblockLow { go func() { err := user.ResyncContacts(false) if err != nil { - user.log.Errorln("Failed to resync puppets: %v", err) + user.zlog.Err(err).Msg("Failed to resync contacts after app state sync") } }() } @@ -787,11 +866,11 @@ func (user *User) HandleEvent(event interface{}) { // This makes sure that outgoing messages always have the right pushname. err := user.Client.SendPresence(user.lastPresence) if err != nil { - user.log.Warnln("Failed to send presence after push name update:", err) + user.zlog.Warn().Err(err).Msg("Failed to send presence after push name update") } _, _, err = user.Client.Store.Contacts.PutPushName(user.JID.ToNonAD(), v.Action.GetName()) if err != nil { - user.log.Warnln("Failed to update push name in store:", err) + user.zlog.Err(err).Msg("Failed to update push name in store") } go user.syncPuppet(user.JID.ToNonAD(), "push name setting") case *events.PairSuccess: @@ -799,7 +878,10 @@ func (user *User) HandleEvent(event interface{}) { user.Session = user.Client.Store user.JID = v.ID user.addToJIDMap() - user.Update() + err := user.Update(ctx) + if err != nil { + user.zlog.Err(err).Msg("Failed to save user after pair success") + } case *events.StreamError: var message string if v.Code != "" { @@ -813,19 +895,19 @@ func (user *User) HandleEvent(event interface{}) { user.bridge.Metrics.TrackConnectionState(user.JID, false) case *events.StreamReplaced: if user.bridge.Config.Bridge.CrashOnStreamReplaced { - user.log.Infofln("Stopping bridge due to StreamReplaced event") + user.zlog.Info().Msg("Stopping bridge due to StreamReplaced event") user.bridge.ManualStop(60) } else { user.BridgeState.Send(status.BridgeState{StateEvent: status.StateUnknownError, Message: "Stream replaced"}) user.bridge.Metrics.TrackConnectionState(user.JID, false) - user.sendMarkdownBridgeAlert("The bridge was started in another location. Use `reconnect` to reconnect this one.") + user.sendMarkdownBridgeAlert(ctx, "The bridge was started in another location. Use `reconnect` to reconnect this one.") } case *events.ConnectFailure: user.BridgeState.Send(status.BridgeState{StateEvent: status.StateUnknownError, Message: fmt.Sprintf("Unknown connection failure: %s (%s)", v.Reason, v.Message)}) user.bridge.Metrics.TrackConnectionState(user.JID, false) user.bridge.Metrics.TrackConnectionFailure(fmt.Sprintf("status-%d", v.Reason)) case *events.ClientOutdated: - user.log.Errorfln("Got a client outdated connect failure. The bridge is likely out of date, please update immediately.") + user.zlog.Error().Msg("Got a client outdated connect failure. The bridge is likely out of date, please update immediately.") user.BridgeState.Send(status.BridgeState{StateEvent: status.StateUnknownError, Message: "Connect failure: 405 client outdated"}) user.bridge.Metrics.TrackConnectionState(user.JID, false) user.bridge.Metrics.TrackConnectionFailure("client-outdated") @@ -857,14 +939,14 @@ func (user *User) HandleEvent(event interface{}) { case *events.NewsletterLeave: go user.handleNewsletterLeave(v) case *events.Picture: - go user.handlePictureUpdate(v) + go user.handlePictureUpdate(ctx, v) case *events.Receipt: if v.IsFromMe && v.Sender.Device == 0 { user.phoneSeen(v.Timestamp) } go user.handleReceipt(v) case *events.ChatPresence: - go user.handleChatPresence(v) + go user.handleChatPresence(ctx, v) case *events.Message: portal := user.GetPortalByMessageSource(v.Info.MessageSource) portal.events <- &PortalEvent{ @@ -922,45 +1004,45 @@ func (user *User) HandleEvent(event interface{}) { if v.Action.GetMuted() { mutedUntil = time.Unix(v.Action.GetMuteEndTimestamp(), 0) } - go user.updateChatMute(nil, portal, mutedUntil) + go user.updateChatMute(ctx, nil, portal, mutedUntil) } case *events.Archive: portal := user.GetPortalByJID(v.JID) if portal != nil { - go user.updateChatTag(nil, portal, user.bridge.Config.Bridge.ArchiveTag, v.Action.GetArchived()) + go user.updateChatTag(ctx, nil, portal, user.bridge.Config.Bridge.ArchiveTag, v.Action.GetArchived()) } case *events.Pin: portal := user.GetPortalByJID(v.JID) if portal != nil { - go user.updateChatTag(nil, portal, user.bridge.Config.Bridge.PinnedTag, v.Action.GetPinned()) + go user.updateChatTag(ctx, nil, portal, user.bridge.Config.Bridge.PinnedTag, v.Action.GetPinned()) } case *events.AppState: // Ignore case *events.KeepAliveTimeout: user.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Error: WAKeepaliveTimeout}) case *events.KeepAliveRestored: - user.log.Infof("Keepalive restored after timeouts, sending connected event") + user.zlog.Info().Msg("Keepalive restored after timeouts, sending connected event") user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected}) case *events.MarkChatAsRead: if user.bridge.Config.Bridge.SyncManualMarkedUnread { - user.markUnread(user.GetPortalByJID(v.JID), !v.Action.GetRead()) + user.markUnread(ctx, user.GetPortalByJID(v.JID), !v.Action.GetRead()) } case *events.DeleteForMe: portal := user.GetPortalByJID(v.ChatJID) if portal != nil { - portal.deleteForMe(user, v) + portal.deleteForMe(ctx, user, v) } case *events.DeleteChat: portal := user.GetPortalByJID(v.JID) if portal != nil { - portal.HandleWhatsAppDeleteChat(user) + portal.HandleWhatsAppDeleteChat(ctx, user) } default: - user.log.Debugfln("Unknown type of event in HandleEvent: %T", v) + user.zlog.Debug().Type("event_type", v).Msg("Unknown type of event in HandleEvent") } } -func (user *User) updateChatMute(intent *appservice.IntentAPI, portal *Portal, mutedUntil time.Time) { +func (user *User) updateChatMute(ctx context.Context, intent *appservice.IntentAPI, portal *Portal, mutedUntil time.Time) { if len(portal.MXID) == 0 || !user.bridge.Config.Bridge.MuteBridging { return } else if intent == nil { @@ -972,16 +1054,22 @@ func (user *User) updateChatMute(intent *appservice.IntentAPI, portal *Portal, m } var err error if mutedUntil.IsZero() && mutedUntil.Before(time.Now()) { - user.log.Debugfln("Portal %s is muted until %s, unmuting...", portal.MXID, mutedUntil) - err = intent.DeletePushRule("global", pushrules.RoomRule, string(portal.MXID)) + user.zlog.Debug(). + Stringer("portal_mxid", portal.MXID). + Time("muted_until", mutedUntil). + Msg("Portal muted until time is in the past, unmuting") + err = intent.DeletePushRule(ctx, "global", pushrules.RoomRule, string(portal.MXID)) } else { - user.log.Debugfln("Portal %s is muted until %s, muting...", portal.MXID, mutedUntil) - err = intent.PutPushRule("global", pushrules.RoomRule, string(portal.MXID), &mautrix.ReqPutPushRule{ + user.zlog.Debug(). + Stringer("portal_mxid", portal.MXID). + Time("muted_until", mutedUntil). + Msg("Portal muted until time is in the future, muting") + err = intent.PutPushRule(ctx, "global", pushrules.RoomRule, string(portal.MXID), &mautrix.ReqPutPushRule{ Actions: []pushrules.PushActionType{pushrules.ActionDontNotify}, }) } if err != nil && !errors.Is(err, mautrix.MNotFound) { - user.log.Warnfln("Failed to update push rule for %s through double puppet: %v", portal.MXID, err) + user.zlog.Err(err).Stringer("portal_mxid", portal.MXID).Msg("Failed to update push rule through double puppet") } } @@ -994,7 +1082,7 @@ type CustomTagEventContent struct { Tags map[string]CustomTagData `json:"tags"` } -func (user *User) updateChatTag(intent *appservice.IntentAPI, portal *Portal, tag string, active bool) { +func (user *User) updateChatTag(ctx context.Context, intent *appservice.IntentAPI, portal *Portal, tag string, active bool) { if len(portal.MXID) == 0 || len(tag) == 0 { return } else if intent == nil { @@ -1005,23 +1093,23 @@ func (user *User) updateChatTag(intent *appservice.IntentAPI, portal *Portal, ta intent = doublePuppet.CustomIntent() } var existingTags CustomTagEventContent - err := intent.GetTagsWithCustomData(portal.MXID, &existingTags) + err := intent.GetTagsWithCustomData(ctx, portal.MXID, &existingTags) if err != nil && !errors.Is(err, mautrix.MNotFound) { - user.log.Warnfln("Failed to get tags of %s: %v", portal.MXID, err) + user.zlog.Err(err).Stringer("portal_mxid", portal.MXID).Msg("Failed to get tags through double puppet") } currentTag, ok := existingTags.Tags[tag] if active && !ok { - user.log.Debugln("Adding tag", tag, "to", portal.MXID) + user.zlog.Debug().Stringer("portal_mxid", portal.MXID).Str("tag", tag).Msg("Adding tag to portal") data := CustomTagData{Order: "0.5", DoublePuppet: user.bridge.Name} - err = intent.AddTagWithCustomData(portal.MXID, tag, &data) + err = intent.AddTagWithCustomData(ctx, portal.MXID, tag, &data) } else if !active && ok && currentTag.DoublePuppet == user.bridge.Name { - user.log.Debugln("Removing tag", tag, "from", portal.MXID) - err = intent.RemoveTag(portal.MXID, tag) + user.zlog.Debug().Stringer("portal_mxid", portal.MXID).Str("tag", tag).Msg("Removing tag from portal") + err = intent.RemoveTag(ctx, portal.MXID, tag) } else { err = nil } if err != nil { - user.log.Warnfln("Failed to update tag %s for %s through double puppet: %v", tag, portal.MXID, err) + user.zlog.Err(err).Stringer("portal_mxid", portal.MXID).Str("tag", tag).Msg("Failed to update tag through double puppet") } } @@ -1036,7 +1124,7 @@ type CustomReadMarkers struct { FullyReadExtra CustomReadReceipt `json:"com.beeper.fully_read.extra"` } -func (user *User) syncChatDoublePuppetDetails(portal *Portal, justCreated bool) { +func (user *User) syncChatDoublePuppetDetails(ctx context.Context, portal *Portal, justCreated bool) { doublePuppet := portal.bridge.GetPuppetByCustomMXID(user.MXID) if doublePuppet == nil { return @@ -1047,30 +1135,34 @@ func (user *User) syncChatDoublePuppetDetails(portal *Portal, justCreated bool) if justCreated || !user.bridge.Config.Bridge.TagOnlyOnCreate { chat, err := user.Client.Store.ChatSettings.GetChatSettings(portal.Key.JID) if err != nil { - user.log.Warnfln("Failed to get settings of %s: %v", portal.Key.JID, err) + user.zlog.Err(err).Stringer("portal_jid", portal.Key.JID).Msg("Failed to get chat settings from store") return } intent := doublePuppet.CustomIntent() if portal.Key.JID == types.StatusBroadcastJID && justCreated { if user.bridge.Config.Bridge.MuteStatusBroadcast { - user.updateChatMute(intent, portal, time.Now().Add(365*24*time.Hour)) + user.updateChatMute(ctx, intent, portal, time.Now().Add(365*24*time.Hour)) } if len(user.bridge.Config.Bridge.StatusBroadcastTag) > 0 { - user.updateChatTag(intent, portal, user.bridge.Config.Bridge.StatusBroadcastTag, true) + user.updateChatTag(ctx, intent, portal, user.bridge.Config.Bridge.StatusBroadcastTag, true) } return } else if !chat.Found { return } - user.updateChatMute(intent, portal, chat.MutedUntil) - user.updateChatTag(intent, portal, user.bridge.Config.Bridge.ArchiveTag, chat.Archived) - user.updateChatTag(intent, portal, user.bridge.Config.Bridge.PinnedTag, chat.Pinned) + user.updateChatMute(ctx, intent, portal, chat.MutedUntil) + user.updateChatTag(ctx, intent, portal, user.bridge.Config.Bridge.ArchiveTag, chat.Archived) + user.updateChatTag(ctx, intent, portal, user.bridge.Config.Bridge.PinnedTag, chat.Pinned) } } -func (user *User) getDirectChats() map[id.UserID][]id.RoomID { +func (user *User) getDirectChats(ctx context.Context) map[id.UserID][]id.RoomID { res := make(map[id.UserID][]id.RoomID) - privateChats := user.bridge.DB.Portal.FindPrivateChats(user.JID.ToNonAD()) + privateChats, err := user.bridge.DB.Portal.FindPrivateChats(ctx, user.JID.ToNonAD()) + if err != nil { + user.zlog.Err(err).Msg("Failed to get private chats of user") + return res + } for _, portal := range privateChats { if len(portal.MXID) > 0 { res[user.bridge.FormatPuppetMXID(portal.Key.JID)] = []id.RoomID{portal.MXID} @@ -1079,7 +1171,7 @@ func (user *User) getDirectChats() map[id.UserID][]id.RoomID { return res } -func (user *User) UpdateDirectChats(chats map[id.UserID][]id.RoomID) { +func (user *User) UpdateDirectChats(ctx context.Context, chats map[id.UserID][]id.RoomID) { if !user.bridge.Config.Bridge.SyncDirectChatList { return } @@ -1090,14 +1182,14 @@ func (user *User) UpdateDirectChats(chats map[id.UserID][]id.RoomID) { intent := puppet.CustomIntent() method := http.MethodPatch if chats == nil { - chats = user.getDirectChats() + chats = user.getDirectChats(ctx) method = http.MethodPut } - user.log.Debugln("Updating m.direct list on homeserver") + user.zlog.Debug().Msg("Updating m.direct list on homeserver") var err error if user.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareAsmux { urlPath := intent.BuildClientURL("unstable", "com.beeper.asmux", "dms") - _, err = intent.MakeFullRequest(mautrix.FullRequest{ + _, err = intent.MakeFullRequest(ctx, mautrix.FullRequest{ Method: method, URL: urlPath, Headers: http.Header{"X-Asmux-Auth": {user.bridge.AS.Registration.AppToken}}, @@ -1105,9 +1197,9 @@ func (user *User) UpdateDirectChats(chats map[id.UserID][]id.RoomID) { }) } else { existingChats := make(map[id.UserID][]id.RoomID) - err = intent.GetAccountData(event.AccountDataDirectChats.Type, &existingChats) + err = intent.GetAccountData(ctx, event.AccountDataDirectChats.Type, &existingChats) if err != nil { - user.log.Warnln("Failed to get m.direct list to update it:", err) + user.zlog.Err(err).Msg("Failed to get m.direct list to update it") return } for userID, rooms := range existingChats { @@ -1119,14 +1211,14 @@ func (user *User) UpdateDirectChats(chats map[id.UserID][]id.RoomID) { chats[userID] = rooms } } - err = intent.SetAccountData(event.AccountDataDirectChats.Type, &chats) + err = intent.SetAccountData(ctx, event.AccountDataDirectChats.Type, &chats) } if err != nil { - user.log.Warnln("Failed to update m.direct list:", err) + user.zlog.Err(err).Msg("Failed to update m.direct list") } } -func (user *User) handleLoggedOut(onConnect bool, reason events.ConnectFailureReason) { +func (user *User) handleLoggedOut(ctx context.Context, onConnect bool, reason events.ConnectFailureReason) { errorCode := WAUnknownLogout if reason == events.ConnectFailureLoggedOut { errorCode = WALoggedOut @@ -1137,11 +1229,14 @@ func (user *User) handleLoggedOut(onConnect bool, reason events.ConnectFailureRe user.DeleteConnection() user.Session = nil user.JID = types.EmptyJID - user.Update() + err := user.Update(ctx) + if err != nil { + user.zlog.Err(err).Msg("Failed to save user after getting logged out") + } if onConnect { - user.sendMarkdownBridgeAlert("Connecting to WhatsApp failed as the device was unlinked (error %s). Please link the bridge to your phone again.", reason) + user.sendMarkdownBridgeAlert(ctx, "Connecting to WhatsApp failed as the device was unlinked (error %s). Please link the bridge to your phone again.", reason) } else { - user.sendMarkdownBridgeAlert("You were logged out from another device. Please link the bridge to your phone again.") + user.sendMarkdownBridgeAlert(ctx, "You were logged out from another device. Please link the bridge to your phone again.") } } @@ -1165,7 +1260,7 @@ func (user *User) GetPortalByJID(jid types.JID) *Portal { } func (user *User) syncPuppet(jid types.JID, reason string) { - user.bridge.GetPuppetByJID(jid).SyncContact(user, false, false, reason) + user.bridge.GetPuppetByJID(jid).SyncContact(user.zlog.WithContext(context.TODO()), user, false, false, reason) } func (user *User) ResyncContacts(forceAvatarSync bool) error { @@ -1173,13 +1268,14 @@ func (user *User) ResyncContacts(forceAvatarSync bool) error { if err != nil { return fmt.Errorf("failed to get cached contacts: %w", err) } - user.log.Infofln("Resyncing displaynames with %d contacts", len(contacts)) + user.zlog.Info().Int("contact_count", len(contacts)).Msg("Resyncing displaynames with contact info") + ctx := user.zlog.With().Str("action", "resync contacts").Logger().WithContext(context.TODO()) for jid, contact := range contacts { puppet := user.bridge.GetPuppetByJID(jid) if puppet != nil { - puppet.Sync(user, &contact, forceAvatarSync, true) + puppet.Sync(ctx, user, &contact, forceAvatarSync, true) } else { - user.log.Warnfln("Got a nil puppet for %s while syncing contacts", jid) + user.zlog.Warn().Stringer("jid", jid).Msg("Got a nil puppet while syncing contacts") } } return nil @@ -1194,17 +1290,18 @@ func (user *User) ResyncGroups(createPortals bool) error { user.groupListCache = groups user.groupListCacheTime = time.Now() user.groupListCacheLock.Unlock() + ctx := user.zlog.With().Str("method", "ResyncGroups").Logger().WithContext(context.TODO()) for _, group := range groups { portal := user.GetPortalByJID(group.JID) if len(portal.MXID) == 0 { if createPortals { - err = portal.CreateMatrixRoom(user, group, nil, true, true) + err = portal.CreateMatrixRoom(ctx, user, group, nil, true, true) if err != nil { return fmt.Errorf("failed to create room for %s: %w", group.JID, err) } } } else { - portal.UpdateMatrixRoom(user, group, nil) + portal.UpdateMatrixRoom(ctx, user, group, nil) } } return nil @@ -1212,7 +1309,7 @@ func (user *User) ResyncGroups(createPortals bool) error { const WATypingTimeout = 15 * time.Second -func (user *User) handleChatPresence(presence *events.ChatPresence) { +func (user *User) handleChatPresence(ctx context.Context, presence *events.ChatPresence) { puppet := user.bridge.GetPuppetByJID(presence.Sender) if puppet == nil { return @@ -1226,13 +1323,13 @@ func (user *User) handleChatPresence(presence *events.ChatPresence) { if puppet.typingIn == portal.MXID { return } - _, _ = puppet.IntentFor(portal).UserTyping(puppet.typingIn, false, 0) + _, _ = puppet.IntentFor(portal).UserTyping(ctx, puppet.typingIn, false, 0) } - _, _ = puppet.IntentFor(portal).UserTyping(portal.MXID, true, WATypingTimeout) + _, _ = puppet.IntentFor(portal).UserTyping(ctx, portal.MXID, true, WATypingTimeout) puppet.typingIn = portal.MXID puppet.typingAt = time.Now() } else { - _, _ = puppet.IntentFor(portal).UserTyping(portal.MXID, false, 0) + _, _ = puppet.IntentFor(portal).UserTyping(ctx, portal.MXID, false, 0) puppet.typingIn = "" } } @@ -1265,64 +1362,75 @@ func (user *User) makeReadMarkerContent(eventID id.EventID, doublePuppet bool) C } } -func (user *User) markSelfReadFull(portal *Portal) { +func (user *User) markSelfReadFull(ctx context.Context, portal *Portal) { puppet := user.bridge.GetPuppetByCustomMXID(user.MXID) if puppet == nil || puppet.CustomIntent() == nil { return } - lastMessage := user.bridge.DB.Message.GetLastInChat(portal.Key) - if lastMessage == nil { + lastMessage, err := user.bridge.DB.Message.GetLastInChat(ctx, portal.Key) + if err != nil { + user.zlog.Err(err).Msg("Failed to get last message in chat to mark as read") + return + } else if lastMessage == nil { return } - user.SetLastReadTS(portal.Key, lastMessage.Timestamp) - err := puppet.CustomIntent().SetReadMarkers(portal.MXID, user.makeReadMarkerContent(lastMessage.MXID, true)) + user.SetLastReadTS(ctx, portal.Key, lastMessage.Timestamp) + err = puppet.CustomIntent().SetReadMarkers(ctx, portal.MXID, user.makeReadMarkerContent(lastMessage.MXID, true)) if err != nil { - user.log.Warnfln("Failed to mark %s (last message) in %s as read: %v", lastMessage.MXID, portal.MXID, err) + user.zlog.Err(err). + Stringer("portal_mxid", portal.MXID). + Stringer("last_message_mxid", lastMessage.MXID). + Msg("Failed to mark last message in chat as read") } else { - user.log.Debugfln("Marked %s (last message) in %s as read", lastMessage.MXID, portal.MXID) + user.zlog.Debug(). + Stringer("portal_mxid", portal.MXID). + Stringer("last_message_mxid", lastMessage.MXID). + Msg("Marked last message in chat as read") } } -func (user *User) markUnread(portal *Portal, unread bool) { +func (user *User) markUnread(ctx context.Context, portal *Portal, unread bool) { puppet := user.bridge.GetPuppetByCustomMXID(user.MXID) if puppet == nil || puppet.CustomIntent() == nil { return } - err := puppet.CustomIntent().SetRoomAccountData(portal.MXID, "m.marked_unread", + err := puppet.CustomIntent().SetRoomAccountData(ctx, portal.MXID, "m.marked_unread", map[string]bool{"unread": unread}) if err != nil { - user.log.Warnfln("Failed to mark %s as unread via m.marked_unread: %v", portal.MXID, err) + user.zlog.Err(err).Stringer("portal_mxid", portal.MXID).Msg("Failed to mark room as unread (m.marked_unread)") } else { - user.log.Debugfln("Marked %s as unread via m.marked_unread: %v", portal.MXID, err) + user.zlog.Debug().Stringer("portal_mxid", portal.MXID).Msg("Marked room as unread (m.marked_unread)") } - err = puppet.CustomIntent().SetRoomAccountData(portal.MXID, "com.famedly.marked_unread", + err = puppet.CustomIntent().SetRoomAccountData(ctx, portal.MXID, "com.famedly.marked_unread", map[string]bool{"unread": unread}) if err != nil { - user.log.Warnfln("Failed to mark %s as unread via com.famedly.marked_unread: %v", portal.MXID, err) + user.zlog.Err(err).Stringer("portal_mxid", portal.MXID).Msg("Failed to mark room as unread (com.famedly.marked_unread)") } else { - user.log.Debugfln("Marked %s as unread via com.famedly.marked_unread: %v", portal.MXID, err) + user.zlog.Debug().Stringer("portal_mxid", portal.MXID).Msg("Marked room as unread (com.famedly.marked_unread)") } } func (user *User) handleGroupCreate(evt *events.JoinedGroup) { + log := user.zlog.With().Str("whatsapp_event", "JoinedGroup").Logger() + ctx := log.WithContext(context.TODO()) portal := user.GetPortalByJID(evt.JID) if evt.CreateKey == "" && len(portal.MXID) == 0 && portal.Key.JID != user.skipGroupCreateDelay { - user.log.Debugfln("Delaying handling group create with empty key to avoid race conditions") + log.Debug().Msg("Delaying handling group create with empty key to avoid race conditions") time.Sleep(5 * time.Second) } if len(portal.MXID) == 0 { if user.createKeyDedup != "" && evt.CreateKey == user.createKeyDedup { - user.log.Debugfln("Ignoring group create event with key %s", evt.CreateKey) + log.Debug().Str("create_key", evt.CreateKey).Msg("Ignoring group create event with cached create key") return } - err := portal.CreateMatrixRoom(user, &evt.GroupInfo, nil, true, true) + err := portal.CreateMatrixRoom(ctx, user, &evt.GroupInfo, nil, true, true) if err != nil { - user.log.Errorln("Failed to create Matrix room after join notification: %v", err) + log.Err(err).Msg("Failed to create Matrix room after join notification") } } else { - portal.UpdateMatrixRoom(user, &evt.GroupInfo, nil) + portal.UpdateMatrixRoom(ctx, user, &evt.GroupInfo, nil) } } @@ -1343,104 +1451,116 @@ func (user *User) handleGroupUpdate(evt *events.GroupInfo) { log.Debug().Str("sender", evt.Sender.String()).Msg("Ignoring group info update from @lid user") return } + ctx := log.WithContext(context.TODO()) switch { case evt.Announce != nil: log.Debug().Msg("Group announcement mode (message send permission) changed") - portal.RestrictMessageSending(evt.Announce.IsAnnounce) + portal.RestrictMessageSending(ctx, evt.Announce.IsAnnounce) case evt.Locked != nil: log.Debug().Msg("Group locked mode (metadata change permission) changed") - portal.RestrictMetadataChanges(evt.Locked.IsLocked) + portal.RestrictMetadataChanges(ctx, evt.Locked.IsLocked) case evt.Name != nil: log.Debug().Msg("Group name changed") - portal.UpdateName(evt.Name.Name, evt.Name.NameSetBy, true) + portal.UpdateName(ctx, evt.Name.Name, evt.Name.NameSetBy, true) case evt.Topic != nil: log.Debug().Msg("Group topic changed") - portal.UpdateTopic(evt.Topic.Topic, evt.Topic.TopicSetBy, true) + portal.UpdateTopic(ctx, evt.Topic.Topic, evt.Topic.TopicSetBy, true) case evt.Leave != nil: log.Debug().Msg("Someone left the group") if evt.Sender != nil && !evt.Sender.IsEmpty() { - portal.HandleWhatsAppKick(user, *evt.Sender, evt.Leave) + portal.HandleWhatsAppKick(ctx, user, *evt.Sender, evt.Leave) } case evt.Join != nil: log.Debug().Msg("Someone joined the group") - portal.HandleWhatsAppInvite(user, evt.Sender, evt.Join) + portal.HandleWhatsAppInvite(ctx, user, evt.Sender, evt.Join) case evt.Promote != nil: log.Debug().Msg("Someone was promoted to admin") - portal.ChangeAdminStatus(evt.Promote, true) + portal.ChangeAdminStatus(ctx, evt.Promote, true) case evt.Demote != nil: log.Debug().Msg("Someone was demoted from admin") - portal.ChangeAdminStatus(evt.Demote, false) + portal.ChangeAdminStatus(ctx, evt.Demote, false) case evt.Ephemeral != nil: log.Debug().Msg("Group ephemeral mode (disappearing message timer) changed") - portal.UpdateGroupDisappearingMessages(evt.Sender, evt.Timestamp, evt.Ephemeral.DisappearingTimer) + portal.UpdateGroupDisappearingMessages(ctx, evt.Sender, evt.Timestamp, evt.Ephemeral.DisappearingTimer) case evt.Link != nil: log.Debug().Msg("Group parent changed") if evt.Link.Type == types.GroupLinkChangeTypeParent { - portal.UpdateParentGroup(user, evt.Link.Group.JID, true) + portal.UpdateParentGroup(ctx, user, evt.Link.Group.JID, true) } case evt.Unlink != nil: log.Debug().Msg("Group parent removed") if evt.Unlink.Type == types.GroupLinkChangeTypeParent && portal.ParentGroup == evt.Unlink.Group.JID { - portal.UpdateParentGroup(user, types.EmptyJID, true) + portal.UpdateParentGroup(ctx, user, types.EmptyJID, true) } case evt.Delete != nil: log.Debug().Msg("Group deleted") - portal.Delete() - portal.Cleanup(false) + portal.Delete(ctx) + portal.Cleanup(ctx, false) default: log.Warn().Msg("Unhandled group info update") } } func (user *User) handleNewsletterJoin(evt *events.NewsletterJoin) { + ctx := user.zlog.With().Str("whatsapp_event", "NewsletterJoin").Logger().WithContext(context.TODO()) portal := user.GetPortalByJID(evt.ID) if portal.MXID == "" { - err := portal.CreateMatrixRoom(user, nil, &evt.NewsletterMetadata, true, false) + err := portal.CreateMatrixRoom(ctx, user, nil, &evt.NewsletterMetadata, true, false) if err != nil { user.zlog.Err(err).Msg("Failed to create room on newsletter join event") } } else { - portal.UpdateMatrixRoom(user, nil, &evt.NewsletterMetadata) + portal.UpdateMatrixRoom(ctx, user, nil, &evt.NewsletterMetadata) } } func (user *User) handleNewsletterLeave(evt *events.NewsletterLeave) { + ctx := user.zlog.With().Str("whatsapp_event", "NewsletterLeave").Logger().WithContext(context.TODO()) portal := user.GetPortalByJID(evt.ID) if portal.MXID != "" { - portal.HandleWhatsAppKick(user, user.JID, []types.JID{user.JID}) + portal.HandleWhatsAppKick(ctx, user, user.JID, []types.JID{user.JID}) } } -func (user *User) handlePictureUpdate(evt *events.Picture) { +func (user *User) handlePictureUpdate(ctx context.Context, evt *events.Picture) { if evt.JID.Server == types.DefaultUserServer { puppet := user.bridge.GetPuppetByJID(evt.JID) - user.log.Debugfln("Received picture update for puppet %s (current: %s, new: %s)", evt.JID, puppet.Avatar, evt.PictureID) + user.zlog.Debug(). + Stringer("jid", evt.JID). + Str("current_avatar", puppet.Avatar). + Str("new_avatar", evt.PictureID). + Msg("Received picture update for puppet") if puppet.Avatar != evt.PictureID { - puppet.Sync(user, nil, true, false) + puppet.Sync(ctx, user, nil, true, false) } } else if portal := user.GetPortalByJID(evt.JID); portal != nil { - user.log.Debugfln("Received picture update for portal %s (current: %s, new: %s)", evt.JID, portal.Avatar, evt.PictureID) + user.zlog.Debug(). + Stringer("jid", evt.JID). + Str("current_avatar", portal.Avatar). + Str("new_avatar", evt.PictureID). + Msg("Received picture update for portal") if portal.Avatar != evt.PictureID { - portal.UpdateAvatar(user, evt.Author, true) + portal.UpdateAvatar(ctx, user, evt.Author, true) } } } -func (user *User) StartPM(jid types.JID, reason string) (*Portal, *Puppet, bool, error) { - user.log.Debugln("Starting PM with", jid, "from", reason) +func (user *User) StartPM(ctx context.Context, jid types.JID, reason string) (*Portal, *Puppet, bool, error) { + zerolog.Ctx(ctx).Debug().Stringer("jid", jid).Str("source", reason).Msg("Starting PM with user") puppet := user.bridge.GetPuppetByJID(jid) - puppet.SyncContact(user, true, false, reason) + puppet.SyncContact(ctx, user, true, false, reason) portal := user.GetPortalByJID(puppet.JID) if len(portal.MXID) > 0 { - ok := portal.ensureUserInvited(user) + ok := portal.ensureUserInvited(ctx, user) if !ok { - portal.log.Warnfln("ensureUserInvited(%s) returned false, creating new portal", user.MXID) + zerolog.Ctx(ctx).Warn().Msg("Failed to ensure user is invited to room in StartPM, creating new portal") portal.MXID = "" + portal.updateLogger() } else { return portal, puppet, false, nil } } - err := portal.CreateMatrixRoom(user, nil, nil, false, true) + err := portal.CreateMatrixRoom(ctx, user, nil, nil, false, true) return portal, puppet, true, err }