From 1e84a169f9fedded85549d9dcc855e75eb66e162 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Apr 2024 17:27:49 +0300 Subject: [PATCH] Handle media retries asynchronously --- config/bridge.go | 1 + config/upgrade.go | 1 + example-config.yaml | 2 ++ go.mod | 1 + go.sum | 2 ++ portal.go | 15 +++++++-------- user.go | 9 ++++++--- 7 files changed, 20 insertions(+), 11 deletions(-) diff --git a/config/bridge.go b/config/bridge.go index 538908f..a27d976 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -78,6 +78,7 @@ type BridgeConfig struct { AutoRequestMedia bool `yaml:"auto_request_media"` RequestMethod MediaRequestMethod `yaml:"request_method"` RequestLocalTime int `yaml:"request_local_time"` + MaxAsyncHandle int64 `yaml:"max_async_handle"` } `yaml:"media_requests"` Deferred []DeferredConfig `yaml:"deferred"` diff --git a/config/upgrade.go b/config/upgrade.go index b1999c8..48d728c 100644 --- a/config/upgrade.go +++ b/config/upgrade.go @@ -54,6 +54,7 @@ func DoUpgrade(helper *up.Helper) { helper.Copy(up.Bool, "bridge", "history_sync", "media_requests", "auto_request_media") helper.Copy(up.Str, "bridge", "history_sync", "media_requests", "request_method") helper.Copy(up.Int, "bridge", "history_sync", "media_requests", "request_local_time") + helper.Copy(up.Int, "bridge", "history_sync", "media_requests", "max_async_handle") helper.Copy(up.Int, "bridge", "history_sync", "max_initial_conversations") helper.Copy(up.Int, "bridge", "history_sync", "message_count") helper.Copy(up.Int, "bridge", "history_sync", "unread_hours_threshold") diff --git a/example-config.yaml b/example-config.yaml index 058f3e1..1acd68e 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -173,6 +173,8 @@ bridge: request_method: immediate # If request_method is "local_time", what time should the requests be sent (in minutes after midnight)? request_local_time: 120 + # Maximum number of media request responses to handle in parallel per user. + max_async_handle: 2 # Settings for immediate backfills. These backfills should generally be small and their main purpose is # to populate each of the initial chats (as configured by max_initial_conversations) with a few messages # so that you can continue conversations without losing context. diff --git a/go.mod b/go.mod index b8a002a..20db9f2 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( golang.org/x/exp v0.0.0-20240314144324-c7f7c6466f7f golang.org/x/image v0.15.0 golang.org/x/net v0.22.0 + golang.org/x/sync v0.3.0 google.golang.org/protobuf v1.33.0 maunium.net/go/mautrix v0.18.0 ) diff --git a/go.sum b/go.sum index f5af216..fb94bec 100644 --- a/go.sum +++ b/go.sum @@ -85,6 +85,8 @@ golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/portal.go b/portal.go index a8df04e..8cbfbbf 100644 --- a/portal.go +++ b/portal.go @@ -269,7 +269,6 @@ type fakeMessage struct { type PortalEvent struct { Message *PortalMessage MatrixMessage *PortalMatrixMessage - MediaRetry *PortalMediaRetry } type PortalMessage struct { @@ -286,11 +285,6 @@ type PortalMatrixMessage struct { receivedAt time.Time } -type PortalMediaRetry struct { - evt *events.MediaRetry - source *User -} - type recentlyHandledWrapper struct { id types.MessageID err database.MessageErrorType @@ -572,8 +566,6 @@ func (portal *Portal) handleOneMessageLoopItem() { portal.handleWhatsAppMessageLoopItem(msg.Message) } else if msg.MatrixMessage != nil { portal.handleMatrixMessageLoopItem(msg.MatrixMessage) - } else if msg.MediaRetry != nil { - portal.handleMediaRetry(msg.MediaRetry.evt, msg.MediaRetry.source) } else { portal.zlog.Warn().Msg("Unexpected PortalEvent with no data") } @@ -3801,6 +3793,13 @@ func (portal *Portal) handleMediaRetry(retry *events.MediaRetry, source *User) { Str("retry_message_id", retry.MessageID). Logger() ctx := log.WithContext(context.TODO()) + err := source.mediaRetryLock.Acquire(ctx, 1) + if err != nil { + log.Err(err).Msg("Failed to acquire media retry semaphore") + return + } + defer source.mediaRetryLock.Release(1) + msg, err := portal.bridge.DB.Message.GetByJID(ctx, portal.Key, retry.MessageID) if msg == nil { log.Warn().Msg("Dropping media retry notification for unknown message") diff --git a/user.go b/user.go index 83890c5..a10be81 100644 --- a/user.go +++ b/user.go @@ -41,6 +41,7 @@ import ( "go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types/events" waLog "go.mau.fi/whatsmeow/util/log" + "golang.org/x/sync/semaphore" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridge" @@ -74,6 +75,8 @@ type User struct { historySyncs chan *events.HistorySync lastPresence types.Presence + mediaRetryLock *semaphore.Weighted + historySyncLoopsStarted bool enqueueBackfillsTimer *time.Timer spaceMembershipChecked bool @@ -257,6 +260,8 @@ func (br *WABridge) NewUser(dbUser *database.User) *User { lastPresence: types.PresenceUnavailable, resyncQueue: make(map[types.JID]resyncQueueItem), + + mediaRetryLock: semaphore.NewWeighted(br.Config.Bridge.HistorySync.MediaRequests.MaxAsyncHandle), } user.PermissionLevel = user.bridge.Config.Bridge.Permissions.Get(user.MXID) @@ -955,9 +960,7 @@ func (user *User) HandleEvent(event interface{}) { case *events.MediaRetry: user.phoneSeen(v.Timestamp) portal := user.GetPortalByJID(v.ChatID) - portal.events <- &PortalEvent{ - MediaRetry: &PortalMediaRetry{evt: v, source: user}, - } + go portal.handleMediaRetry(v, user) case *events.CallOffer: user.handleCallStart(v.CallCreator, v.CallID, "", v.Timestamp) case *events.CallOfferNotice: