diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 57fc3f33a..af43064fe 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -342,8 +342,7 @@ func createRoom( } // send events to the room server - _, err = roomserverAPI.SendEvents(req.Context(), rsAPI, builtEvents, cfg.Matrix.ServerName, nil) - if err != nil { + if err = roomserverAPI.SendEvents(req.Context(), rsAPI, builtEvents, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index cba19a24b..202662ab6 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -75,13 +75,12 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us return jsonerror.InternalServerError() } - _, err = roomserverAPI.SendEvents( + if err = roomserverAPI.SendEvents( ctx, rsAPI, []gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, cfg.Matrix.ServerName, nil, - ) - if err != nil { + ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 4c7895bd3..bc51b0b51 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -171,7 +171,7 @@ func SetAvatarURL( return jsonerror.InternalServerError() } - if _, err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -289,7 +289,7 @@ func SetDisplayName( return jsonerror.InternalServerError() } - if _, err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index a825da64d..178bfafc9 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -122,8 +122,7 @@ func SendRedaction( JSON: jsonerror.NotFound("Room does not exist"), } } - _, err = roomserverAPI.SendEvents(context.Background(), rsAPI, []gomatrixserverlib.HeaderedEvent{*e}, cfg.Matrix.ServerName, nil) - if err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, []gomatrixserverlib.HeaderedEvent{*e}, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index a25979ea0..9744a5640 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -90,27 +90,26 @@ func SendEvent( // pass the new event to the roomserver and receive the correct event ID // event ID in case of duplicate transaction is discarded - eventID, err := api.SendEvents( + if err := api.SendEvents( req.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, txnAndSessionID, - ) - if err != nil { + ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } util.GetLogger(req.Context()).WithFields(logrus.Fields{ - "event_id": eventID, + "event_id": e.EventID(), "room_id": roomID, "room_version": verRes.RoomVersion, }).Info("Sent event to roomserver") res := util.JSONResponse{ Code: http.StatusOK, - JSON: sendEventResponse{eventID}, + JSON: sendEventResponse{e.EventID()}, } // Add response to transactionsCache if txnID != nil { diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 2ffb6bb09..b9575a284 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -359,7 +359,7 @@ func emit3PIDInviteEvent( return err } - _, err = api.SendEvents( + return api.SendEvents( ctx, rsAPI, []gomatrixserverlib.HeaderedEvent{ (*event).Headered(queryRes.RoomVersion), @@ -367,5 +367,4 @@ func emit3PIDInviteEvent( cfg.Matrix.ServerName, nil, ) - return err } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 6cac12451..36afe30ab 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -266,15 +266,14 @@ func SendJoin( // We are responsible for notifying other servers that the user has joined // the room, so set SendAsServer to cfg.Matrix.ServerName if !alreadyJoined { - _, err = api.SendEvents( + if err = api.SendEvents( httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ event.Headered(stateAndAuthChainResponse.RoomVersion), }, cfg.Matrix.ServerName, nil, - ) - if err != nil { + ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 511623445..8bb0a8a94 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -247,15 +247,14 @@ func SendLeave( // Send the events to the room server. // We are responsible for notifying other servers that the user has left // the room, so set SendAsServer to cfg.Matrix.ServerName - _, err = api.SendEvents( + if err = api.SendEvents( httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ event.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, nil, - ) - if err != nil { + ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendEvents failed") return jsonerror.InternalServerError() } diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index cad779219..570062adc 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -382,7 +382,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro } // pass the event to the roomserver - _, err := api.SendEvents( + return api.SendEvents( t.context, t.rsAPI, []gomatrixserverlib.HeaderedEvent{ e.Headered(stateResp.RoomVersion), @@ -390,7 +390,6 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro api.DoNotSendToOtherServers, nil, ) - return err } func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error { diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index e8d9a9397..ec6cc1488 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -89,7 +89,7 @@ func CreateInvitesFrom3PIDInvites( } // Send all the events - if _, err := api.SendEvents(req.Context(), rsAPI, evs, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, evs, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -172,7 +172,7 @@ func ExchangeThirdPartyInvite( } // Send the event to the roomserver - if _, err = api.SendEvents( + if err = api.SendEvents( httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ signedEvent.Event.Headered(verRes.RoomVersion), diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 05c981df4..73c4994a7 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -83,5 +83,4 @@ type InputRoomEventsRequest struct { // InputRoomEventsResponse is a response to InputRoomEvents type InputRoomEventsResponse struct { - EventID string `json:"event_id"` } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 207c12c8f..16f5e8e18 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -26,7 +26,7 @@ import ( func SendEvents( ctx context.Context, rsAPI RoomserverInternalAPI, events []gomatrixserverlib.HeaderedEvent, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, -) (string, error) { +) error { ires := make([]InputRoomEvent, len(events)) for i, event := range events { ires[i] = InputRoomEvent{ @@ -77,19 +77,16 @@ func SendEventWithState( StateEventIDs: stateEventIDs, }) - _, err = SendInputRoomEvents(ctx, rsAPI, ires) - return err + return SendInputRoomEvents(ctx, rsAPI, ires) } // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent, -) (eventID string, err error) { +) error { request := InputRoomEventsRequest{InputRoomEvents: ires} var response InputRoomEventsResponse - err = rsAPI.InputRoomEvents(ctx, &request, &response) - eventID = response.EventID - return + return rsAPI.InputRoomEvents(ctx, &request, &response) } // SendInvite event to the roomserver. diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 87bdc5dbf..7a44ff42c 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,12 +19,14 @@ import ( "context" "encoding/json" "sync" + "time" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" + "go.uber.org/atomic" ) type Inputer struct { @@ -33,7 +35,36 @@ type Inputer struct { ServerName gomatrixserverlib.ServerName OutputRoomEventTopic string - mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent + workers sync.Map // room ID -> *inputWorker +} + +type inputTask struct { + ctx context.Context + event *api.InputRoomEvent + wg *sync.WaitGroup + err error // written back by worker, only safe to read when all tasks are done +} + +type inputWorker struct { + r *Inputer + running atomic.Bool + input chan *inputTask +} + +func (w *inputWorker) start() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + for { + select { + case task := <-w.input: + _, task.err = w.r.processRoomEvent(task.ctx, task.event) + task.wg.Done() + case <-time.After(time.Second * 5): + return + } + } } // WriteOutputEvents implements OutputRoomEventWriter @@ -73,19 +104,54 @@ func (r *Inputer) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) (err error) { +) error { + // Create a wait group. Each task that we dispatch will call Done on + // this wait group so that we know when all of our events have been + // processed. + wg := &sync.WaitGroup{} + wg.Add(len(request.InputRoomEvents)) + tasks := make([]*inputTask, len(request.InputRoomEvents)) + for i, e := range request.InputRoomEvents { + // Work out if we are running per-room workers or if we're just doing + // it on a global basis (e.g. SQLite). roomID := "global" if r.DB.SupportsConcurrentRoomInputs() { roomID = e.Event.RoomID() } - mutex, _ := r.mutexes.LoadOrStore(roomID, &sync.Mutex{}) - mutex.(*sync.Mutex).Lock() - if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil { - mutex.(*sync.Mutex).Unlock() - return err + + // Look up the worker, or create it if it doesn't exist. This channel + // is buffered to reduce the chance that we'll be blocked by another + // room - the channel will be quite small as it's just pointer types. + w, _ := r.workers.LoadOrStore(roomID, &inputWorker{ + r: r, + input: make(chan *inputTask, 10), + }) + worker := w.(*inputWorker) + + // Create a task. This contains the input event and a reference to + // the wait group, so that the worker can notify us when this specific + // task has been finished. + tasks[i] = &inputTask{ + ctx: ctx, + event: &request.InputRoomEvents[i], + wg: wg, + } + + // Send the task to the worker. + go worker.start() + worker.input <- tasks[i] + } + + // Wait for all of the workers to return results about our tasks. + wg.Wait() + + // If any of the tasks returned an error, we should probably report + // that back to the caller. + for _, task := range tasks { + if task.err != nil { + return task.err } - mutex.(*sync.Mutex).Unlock() } return nil } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 69f51f4b8..6ee679da6 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -38,7 +38,7 @@ import ( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, - input api.InputRoomEvent, + input *api.InputRoomEvent, ) (eventID string, err error) { // Parse and validate the event JSON headered := input.Event @@ -143,7 +143,7 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) calculateAndSetState( ctx context.Context, - input api.InputRoomEvent, + input *api.InputRoomEvent, roomInfo types.RoomInfo, stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 0deb7acb1..786d4f31f 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -114,8 +114,7 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}) hevents := mustLoadEvents(t, ver, events) - _, err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil) - if err != nil { + if err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil { t.Errorf("failed to SendEvents: %s", err) } return rsAPI, dp, hevents