diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index d567408be..21b52bc3c 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -68,14 +68,15 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called when the appservice component receives a new event from // the room server output log. -func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON var output api.OutputEvent if err := json.Unmarshal(msg.Data, &output); err != nil { diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 6d3cf0e46..f3314bc98 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -67,14 +67,15 @@ func NewKeyChangeConsumer( // Start consuming from key servers func (t *KeyChangeConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, - nats.DeliverAll(), nats.ManualAck(), + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *KeyChangeConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var m api.DeviceMessage if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index a65d2aa04..e76103cd3 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -69,14 +69,15 @@ func (t *OutputPresenceConsumer) Start() error { return nil } return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } // onMessage is called in response to a message received on the presence // events topic from the client api. -func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // only send presence events which originated from us userID := msg.Header.Get(jetstream.UserID) _, serverName, err := gomatrixserverlib.SplitID('@', userID) diff --git a/federationapi/consumers/receipts.go b/federationapi/consumers/receipts.go index 2c9d79bcb..366cb264e 100644 --- a/federationapi/consumers/receipts.go +++ b/federationapi/consumers/receipts.go @@ -65,14 +65,15 @@ func NewOutputReceiptConsumer( // Start consuming from the clientapi func (t *OutputReceiptConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } // onMessage is called in response to a message received on the receipt // events topic from the client api. -func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called receipt := syncTypes.OutputReceiptEvent{ UserID: msg.Header.Get(jetstream.UserID), RoomID: msg.Header.Get(jetstream.RoomID), diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 2622ecb3f..349b50b05 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -68,8 +68,8 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } @@ -77,7 +77,8 @@ func (s *OutputRoomEventConsumer) Start() error { // It is unsafe to call this with messages for the same room in multiple gorountines // because updates it will likely fail with a types.EventIDMismatchError when it // realises that it cannot update the room state using the deltas. -func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON var output api.OutputEvent if err := json.Unmarshal(msg.Data, &output); err != nil { diff --git a/federationapi/consumers/sendtodevice.go b/federationapi/consumers/sendtodevice.go index f99a895e0..e44bad723 100644 --- a/federationapi/consumers/sendtodevice.go +++ b/federationapi/consumers/sendtodevice.go @@ -63,14 +63,15 @@ func NewOutputSendToDeviceConsumer( // Start consuming from the client api func (t *OutputSendToDeviceConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, - nats.DeliverAll(), nats.ManualAck(), + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called in response to a message received on the // send-to-device events topic from the client api. -func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // only send send-to-device events which originated from us sender := msg.Header.Get("sender") _, originServerName, err := gomatrixserverlib.SplitID('@', sender) diff --git a/federationapi/consumers/typing.go b/federationapi/consumers/typing.go index 428e1a867..9c7379136 100644 --- a/federationapi/consumers/typing.go +++ b/federationapi/consumers/typing.go @@ -62,14 +62,15 @@ func NewOutputTypingConsumer( // Start consuming from the clientapi func (t *OutputTypingConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } // onMessage is called in response to a message received on the typing // events topic from the client api. -func (t *OutputTypingConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Extract the typing event from msg. roomID := msg.Header.Get(jetstream.RoomID) userID := msg.Header.Get(jetstream.UserID) diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go index f4f246280..d15f94267 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/keyserver/consumers/devicelistupdate.go @@ -55,14 +55,15 @@ func NewDeviceListUpdateConsumer( // Start consuming from key servers func (t *DeviceListUpdateConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, - nats.DeliverAll(), nats.ManualAck(), + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var m gomatrixserverlib.DeviceListUpdateEvent if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("Failed to read from device list update input topic") diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 1c07583e9..f47637c69 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -9,9 +9,16 @@ import ( "github.com/sirupsen/logrus" ) +// JetStreamConsumer starts a durable consumer on the given subject with the +// given durable name. The function will be called when one or more messages +// is available, up to the maximum batch size specified. If the batch is set to +// 1 then messages will be delivered one at a time. If the function is called, +// the messages array is guaranteed to be at least 1 in size. Any provided NATS +// options will be passed through to the pull subscriber creation. The consumer +// will continue to run until the context expires, at which point it will stop. func JetStreamConsumer( - ctx context.Context, js nats.JetStreamContext, subj, durable string, - f func(ctx context.Context, msg *nats.Msg) bool, + ctx context.Context, js nats.JetStreamContext, subj, durable string, batch int, + f func(ctx context.Context, msgs []*nats.Msg) bool, opts ...nats.SubOpt, ) error { defer func() { @@ -27,6 +34,14 @@ func JetStreamConsumer( } }() + // If the batch size is greater than 1, we will want to acknowledge all + // received messages in the batch. Below we will send an acknowledgement + // for the most recent message in the batch and AckAll will ensure that + // all messages that came before it are also acknowledged implicitly. + if batch > 1 { + opts = append(opts, nats.AckAll()) + } + name := durable + "Pull" sub, err := js.PullSubscribe(subj, name, opts...) if err != nil { @@ -50,7 +65,7 @@ func JetStreamConsumer( // enforce its own deadline (roughly 5 seconds by default). Therefore // it is our responsibility to check whether our context expired or // not when a context error is returned. Footguns. Footguns everywhere. - msgs, err := sub.Fetch(1, nats.Context(ctx)) + msgs, err := sub.Fetch(batch, nats.Context(ctx)) if err != nil { if err == context.Canceled || err == context.DeadlineExceeded { // Work out whether it was the JetStream context that expired @@ -74,13 +89,13 @@ func JetStreamConsumer( if len(msgs) < 1 { continue } - msg := msgs[0] + msg := msgs[len(msgs)-1] // most recent message, in case of AckAll if err = msg.InProgress(nats.Context(ctx)); err != nil { logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) sentry.CaptureException(err) continue } - if f(ctx, msg) { + if f(ctx, msgs) { if err = msg.AckSync(nats.Context(ctx)); err != nil { logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err)) sentry.CaptureException(err) diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 02633b567..f0588cab8 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -75,15 +75,16 @@ func NewOutputClientDataConsumer( // Start consuming from room servers func (s *OutputClientDataConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called when the sync server receives a new event from the client API server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON userID := msg.Header.Get(jetstream.UserID) var output eventutil.AccountData diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index c8d88ddac..c42e71971 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -75,12 +75,13 @@ func NewOutputKeyChangeEventConsumer( // Start consuming from the key server func (s *OutputKeyChangeEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var m api.DeviceMessage if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index db7d67fa6..61bdc13de 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -128,12 +128,13 @@ func (s *PresenceConsumer) Start() error { return nil } return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.presenceTopic, s.durable, s.onMessage, + s.ctx, s.jetstream, s.presenceTopic, s.durable, 1, s.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } -func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *PresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called userID := msg.Header.Get(jetstream.UserID) presence := msg.Header.Get("presence") timestamp := msg.Header.Get("last_active_ts") diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index 83156cf93..a18244c44 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -74,12 +74,13 @@ func NewOutputReceiptEventConsumer( // Start consuming receipts events. func (s *OutputReceiptEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called output := types.OutputReceiptEvent{ UserID: msg.Header.Get(jetstream.UserID), RoomID: msg.Header.Get(jetstream.RoomID), diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index f77b1673b..6979eb484 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -79,15 +79,16 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called when the sync server receives a new event from the room server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON var err error var output api.OutputEvent diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 0b9153fcd..89b01d7e5 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -68,12 +68,13 @@ func NewOutputSendToDeviceEventConsumer( // Start consuming send-to-device events. func (s *OutputSendToDeviceEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called userID := msg.Header.Get(jetstream.UserID) _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { diff --git a/syncapi/consumers/typing.go b/syncapi/consumers/typing.go index 48e484ec5..88db80f8c 100644 --- a/syncapi/consumers/typing.go +++ b/syncapi/consumers/typing.go @@ -64,12 +64,13 @@ func NewOutputTypingEventConsumer( // Start consuming typing events. func (s *OutputTypingEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called roomID := msg.Header.Get(jetstream.RoomID) userID := msg.Header.Get(jetstream.UserID) typing, err := strconv.ParseBool(msg.Header.Get("typing")) diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go index 010fa7c8e..227823522 100644 --- a/syncapi/consumers/userapi.go +++ b/syncapi/consumers/userapi.go @@ -67,8 +67,8 @@ func NewOutputNotificationDataConsumer( // Start starts consumption. func (s *OutputNotificationDataConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } @@ -76,7 +76,8 @@ func (s *OutputNotificationDataConsumer) Start() error { // the push server. It is not safe for this function to be called from // multiple goroutines, or else the sync stream position may race and // be incorrectly calculated. -func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called userID := string(msg.Header.Get(jetstream.UserID)) // Parse out the event JSON diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go index 067f93330..54654f757 100644 --- a/userapi/consumers/syncapi_readupdate.go +++ b/userapi/consumers/syncapi_readupdate.go @@ -56,15 +56,16 @@ func NewOutputReadUpdateConsumer( func (s *OutputReadUpdateConsumer) Start() error { if err := jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ); err != nil { return err } return nil } -func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var read types.ReadUpdate if err := json.Unmarshal(msg.Data, &read); err != nil { log.WithError(err).Error("userapi clientapi consumer: message parse failure") diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go index ec351ef7e..3ac6f58d0 100644 --- a/userapi/consumers/syncapi_streamevent.go +++ b/userapi/consumers/syncapi_streamevent.go @@ -65,15 +65,16 @@ func NewOutputStreamEventConsumer( func (s *OutputStreamEventConsumer) Start() error { if err := jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ); err != nil { return err } return nil } -func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var output types.StreamedEvent output.Event = &gomatrixserverlib.HeaderedEvent{} if err := json.Unmarshal(msg.Data, &output); err != nil {