diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index d8ce9727f..fc712f47b 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -122,7 +122,7 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) + _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) if err != nil { return "", fmt.Errorf("r.DB.StoreEvent: %w", err) } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 8c2477dee..e198f67d8 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -546,6 +546,7 @@ func joinEventsFromHistoryVisibility( func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { var roomNID types.RoomNID + var eventNID types.EventNID backfilledEventMap := make(map[string]types.Event) for j, ev := range events { nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs()) @@ -559,10 +560,9 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs authNids[i] = nid i++ } - var stateAtEvent types.StateAtEvent var redactedEventID string var redactionEvent *gomatrixserverlib.Event - roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), authNids, false) + eventNID, roomNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), authNids, false) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") continue @@ -581,7 +581,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs events[j] = ev } backfilledEventMap[ev.EventID()] = types.Event{ - EventNID: stateAtEvent.StateEntry.EventNID, + EventNID: eventNID, Event: ev.Unwrap(), } } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7f6b98557..15764366b 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -70,7 +70,7 @@ type Database interface { StoreEvent( ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID, isRejected bool, - ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) + ) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database // Returns a types.MissingEventError if the event IDs aren't in the database. diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index dbf706e5d..f49536f4e 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -461,7 +461,7 @@ func (d *Database) GetLatestEventsForUpdate( func (d *Database) StoreEvent( ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID, isRejected bool, -) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { +) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { var ( roomNID types.RoomNID eventTypeNID types.EventTypeNID @@ -538,7 +538,7 @@ func (d *Database) StoreEvent( return nil }) if err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) } // We should attempt to update the previous events table with any @@ -551,10 +551,10 @@ func (d *Database) StoreEvent( if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { roomInfo, err = d.RoomInfo(ctx, event.RoomID()) if err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) } if roomInfo == nil && len(prevEvents) > 0 { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) } // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This @@ -563,7 +563,7 @@ func (d *Database) StoreEvent( // to do writes however then this will need to go inside `Writer.Do`. updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo) if err != nil { - return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err) + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err) } // Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents // and EndTransaction in a writer then it's possible for a new write txn to be made between the two @@ -580,11 +580,11 @@ func (d *Database) StoreEvent( return err }) if err != nil { - return 0, types.StateAtEvent{}, nil, "", err + return 0, 0, types.StateAtEvent{}, nil, "", err } } - return roomNID, types.StateAtEvent{ + return eventNID, roomNID, types.StateAtEvent{ BeforeStateSnapshotNID: stateNID, StateEntry: types.StateEntry{ StateKeyTuple: types.StateKeyTuple{ diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index b7fe7ee4f..3127eb17d 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -49,7 +49,8 @@ const eventsSchema = ` const insertEventSQL = ` INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - ON CONFLICT DO NOTHING; + ON CONFLICT DO NOTHING + RETURNING event_nid, state_snapshot_nid; ` const selectEventSQL = "" + @@ -161,20 +162,13 @@ func (s *eventStatements) InsertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { // attempt to insert: the last_row_id is the event NID var eventNID int64 + var stateNID int64 insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - result, err := insertStmt.ExecContext( + err := insertStmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, - ) - if err != nil { - return 0, 0, err - } - modified, err := result.RowsAffected() - if modified == 0 && err == nil { - return 0, 0, sql.ErrNoRows - } - eventNID, err = result.LastInsertId() - return types.EventNID(eventNID), 0, err + ).Scan(&eventNID, &stateNID) + return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } func (s *eventStatements) SelectEvent(