diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index eba166599..144d22795 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -21,16 +21,15 @@ import ( "fmt" "github.com/getsentry/sentry-go" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) // updateLatestEvents updates the list of latest events for this room in the database and writes the @@ -71,7 +70,6 @@ func (r *Inputer) updateLatestEvents( defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) u := latestEventsUpdater{ - ctx: ctx, api: r, updater: updater, stateAtEvent: stateAtEvent, @@ -80,12 +78,12 @@ func (r *Inputer) updateLatestEvents( } var updates []api.OutputEvent - updates, err = u.doUpdateLatestEvents(roomInfo) + updates, err = u.doUpdateLatestEvents(ctx, roomInfo) if err != nil { return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } - update, err := u.makeOutputNewRoomEvent(transactionID, sendAsServer, updater.LastEventIDSent(), historyVisibility) + update, err := u.makeOutputNewRoomEvent(ctx, transactionID, sendAsServer, updater.LastEventIDSent(), historyVisibility) if err != nil { return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) } @@ -118,7 +116,6 @@ func (r *Inputer) updateLatestEvents( // The state could be passed using function arguments, but it becomes impractical // when there are so many variables to pass around. type latestEventsUpdater struct { - ctx context.Context api *Inputer updater *shared.RoomUpdater stateAtEvent types.StateAtEvent @@ -140,7 +137,7 @@ type latestEventsUpdater struct { newStateNID types.StateSnapshotNID } -func (u *latestEventsUpdater) doUpdateLatestEvents(roomInfo *types.RoomInfo) ([]api.OutputEvent, error) { +func (u *latestEventsUpdater) doUpdateLatestEvents(ctx context.Context, roomInfo *types.RoomInfo) ([]api.OutputEvent, error) { // If we are doing a regular event update then we will get the // previous latest events to use as a part of the calculation. If // we are overwriting the latest events because we have a complete @@ -164,6 +161,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents(roomInfo *types.RoomInfo) ([] // Work out what the latest events are. This will include the new // event if it is not already referenced. extremitiesChanged, err := u.calculateLatest( + ctx, u.oldLatest, u.event, types.StateAtEventAndReference{ EventID: u.event.EventID(), @@ -178,13 +176,13 @@ func (u *latestEventsUpdater) doUpdateLatestEvents(roomInfo *types.RoomInfo) ([] // latest state. var membershipUpdates []api.OutputEvent if extremitiesChanged || u.rewritesState { - if err = u.latestState(roomInfo); err != nil { + if err = u.latestState(ctx, roomInfo); err != nil { return nil, fmt.Errorf("u.latestState: %w", err) } // If we need to generate any output events then here's where we do it. // TODO: Move this! - if membershipUpdates, err = u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added); err != nil { + if membershipUpdates, err = u.api.updateMemberships(ctx, u.updater, u.removed, u.added); err != nil { return nil, fmt.Errorf("u.api.updateMemberships: %w", err) } } else { @@ -198,8 +196,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents(roomInfo *types.RoomInfo) ([] return membershipUpdates, nil } -func (u *latestEventsUpdater) latestState(roomInfo *types.RoomInfo) error { - trace, ctx := internal.StartRegion(u.ctx, "processEventWithMissingState") +func (u *latestEventsUpdater) latestState(ctx context.Context, roomInfo *types.RoomInfo) error { + trace, ctx := internal.StartRegion(ctx, "processEventWithMissingState") defer trace.EndRegion() var err error @@ -315,11 +313,12 @@ func (u *latestEventsUpdater) latestState(roomInfo *types.RoomInfo) error { // calculateLatest works out the new set of forward extremities. Returns // true if the new event is included in those extremites, false otherwise. func (u *latestEventsUpdater) calculateLatest( + ctx context.Context, oldLatest []types.StateAtEventAndReference, newEvent gomatrixserverlib.PDU, newStateAndRef types.StateAtEventAndReference, ) (bool, error) { - trace, _ := internal.StartRegion(u.ctx, "calculateLatest") + trace, _ := internal.StartRegion(ctx, "calculateLatest") defer trace.EndRegion() // First of all, get a list of all of the events in our current @@ -377,6 +376,7 @@ func (u *latestEventsUpdater) calculateLatest( } func (u *latestEventsUpdater) makeOutputNewRoomEvent( + ctx context.Context, transactionID *api.TransactionID, sendAsServer string, lastEventIDSent string, @@ -397,7 +397,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent( HistoryVisibility: historyVisibility, } - eventIDMap, err := u.stateEventMap() + eventIDMap, err := u.stateEventMap(ctx) if err != nil { return nil, err } @@ -421,7 +421,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent( } // retrieve an event nid -> event ID map for all events that need updating -func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) { +func (u *latestEventsUpdater) stateEventMap(ctx context.Context) (map[types.EventNID]string, error) { cap := len(u.added) + len(u.removed) + len(u.stateBeforeEventRemoves) + len(u.stateBeforeEventAdds) stateEventNIDs := make(types.EventNIDs, 0, cap) allStateEntries := make([]types.StateEntry, 0, cap) @@ -433,5 +433,5 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) stateEventNIDs = append(stateEventNIDs, entry.EventNID) } stateEventNIDs = stateEventNIDs[:util.SortAndUnique(stateEventNIDs)] - return u.updater.EventIDs(u.ctx, stateEventNIDs) + return u.updater.EventIDs(ctx, stateEventNIDs) }