diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index a6de8ac84..7efad7af6 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "sort" "strings" "github.com/matrix-org/gomatrixserverlib" @@ -159,7 +160,7 @@ func GetMembershipsAtState( ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, ) ([]types.Event, error) { - var eventNIDs []types.EventNID + var eventNIDs types.EventNIDs for _, entry := range stateEntries { // Filter the events to retrieve to only keep the membership events if entry.EventTypeNID == types.MRoomMemberNID { @@ -167,6 +168,14 @@ func GetMembershipsAtState( } } + // There are no events to get, don't bother asking the database + if len(eventNIDs) == 0 { + return []types.Event{}, nil + } + + sort.Sort(eventNIDs) + util.Unique(eventNIDs) + // Get all of the events in this state stateEvents, err := db.Events(ctx, eventNIDs) if err != nil { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 0db046a86..8850e5c46 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -239,16 +239,42 @@ func (r *Queryer) QueryMembershipAtEvent( return fmt.Errorf("unable to get state before event: %w", err) } + // If we only have one or less state entries, we can short circuit the below + // loop and avoid hitting the database + allStateEventNIDs := make(map[types.EventNID]types.StateEntry) + for _, eventID := range request.EventIDs { + stateEntry := stateEntries[eventID] + for _, s := range stateEntry { + allStateEventNIDs[s.EventNID] = s + } + } + + var canShortCircuit bool + if len(allStateEventNIDs) <= 1 { + canShortCircuit = true + } + + var memberships []types.Event for _, eventID := range request.EventIDs { stateEntry, ok := stateEntries[eventID] - if !ok { + if !ok || len(stateEntry) == 0 { response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{} continue } - memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + + // If we can short circuit, e.g. we only have 0 or 1 membership events, we only get the memberships + // once. If we have more than one membership event, we need to get the state for each state entry. + if canShortCircuit { + if len(memberships) == 0 { + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + } + } else { + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + } if err != nil { return fmt.Errorf("unable to get memberships at state: %w", err) } + res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) for i := range memberships { diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 018348466..1cfde5e4b 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -18,17 +18,17 @@ package state import ( "context" - "database/sql" "fmt" "sort" "sync" "time" - "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" + + "github.com/matrix-org/dendrite/roomserver/types" ) type StateResolutionStorage interface { @@ -37,6 +37,7 @@ type StateResolutionStorage interface { StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) @@ -130,21 +131,10 @@ func (v *StateResolution) LoadMembershipAtEvent( span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent") defer span.Finish() - // De-dupe snapshotNIDs - snapshotNIDMap := make(map[types.StateSnapshotNID][]string) // map from snapshot NID to eventIDs - for i := range eventIDs { - eventID := eventIDs[i] - snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) - if err != nil && err != sql.ErrNoRows { - return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err) - } - if snapshotNID == 0 { - // If we don't know a state snapshot for this event then we can't calculate - // memberships at the time of the event, so skip over it. This means that - // it isn't guaranteed that the response map will contain every single event. - continue - } - snapshotNIDMap[snapshotNID] = append(snapshotNIDMap[snapshotNID], eventID) + // Get a mapping from snapshotNID -> eventIDs + snapshotNIDMap, err := v.db.BulkSelectSnapshotsFromEventIDs(ctx, eventIDs) + if err != nil { + return nil, err } snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap)) @@ -157,24 +147,45 @@ func (v *StateResolution) LoadMembershipAtEvent( return nil, err } + var wantStateBlocks []types.StateBlockNID + for _, x := range stateBlockNIDLists { + wantStateBlocks = append(wantStateBlocks, x.StateBlockNIDs...) + } + + stateEntryLists, err := v.db.StateEntriesForTuples(ctx, uniqueStateBlockNIDs(wantStateBlocks), []types.StateKeyTuple{ + { + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }, + }) + if err != nil { + return nil, err + } + + stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) + stateEntriesMap := stateEntryListMap(stateEntryLists) + result := make(map[string][]types.StateEntry) for _, stateBlockNIDList := range stateBlockNIDLists { - // Query the membership event for the user at the given stateblocks - stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{ - { - EventTypeNID: types.MRoomMemberNID, - EventStateKeyNID: stateKeyNID, - }, - }) - if err != nil { - return nil, err + stateBlockNIDs, ok := stateBlockNIDsMap.lookup(stateBlockNIDList.StateSnapshotNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + return nil, fmt.Errorf("corrupt DB: Missing state snapshot numeric ID %d", stateBlockNIDList.StateSnapshotNID) } - evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID] + for _, stateBlockNID := range stateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + return nil, fmt.Errorf("corrupt DB: Missing state block numeric ID %d", stateBlockNID) + } - for _, evID := range evIDs { - for _, x := range stateEntryLists { - result[evID] = append(result[evID], x.StateEntries...) + evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID] + + for _, evID := range evIDs { + result[evID] = append(result[evID], entries...) } } } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 094537948..c39a8cbba 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -72,6 +72,7 @@ type Database interface { Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. StoreEvent( ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID, diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 1e7ca7669..9b5ed6eda 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -22,11 +22,12 @@ import ( "sort" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const eventsSchema = ` @@ -80,6 +81,9 @@ const insertEventSQL = "" + const selectEventSQL = "" + "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" +const bulkSelectSnapshotsForEventIDsSQL = "" + + "SELECT event_id, state_snapshot_nid FROM roomserver_events WHERE event_id = ANY($1)" + // Bulk lookup of events by string ID. // Sort by the numeric IDs for event type and state key. // This means we can use binary search to lookup entries by type and state key. @@ -150,6 +154,7 @@ const selectEventRejectedSQL = "" + type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt + bulkSelectSnapshotsForEventIDsStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt bulkSelectStateEventByNIDStmt *sql.Stmt @@ -179,6 +184,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, + {&s.bulkSelectSnapshotsForEventIDsStmt, bulkSelectSnapshotsForEventIDsSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, {&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL}, {&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL}, @@ -230,6 +236,29 @@ func (s *eventStatements) SelectEvent( return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } +func (s *eventStatements) BulkSelectSnapshotsFromEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (map[types.StateSnapshotNID][]string, error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectSnapshotsForEventIDsStmt) + + rows, err := stmt.QueryContext(ctx, pq.Array(eventIDs)) + if err != nil { + return nil, err + } + + var eventID string + var stateNID types.StateSnapshotNID + result := make(map[types.StateSnapshotNID][]string) + for rows.Next() { + if err := rows.Scan(&eventID, &stateNID); err != nil { + return nil, err + } + result[stateNID] = append(result[stateNID], eventID) + } + + return result, rows.Err() +} + // bulkSelectStateEventByID lookups a list of state events by event ID. // If not excluding rejected events, and any of the requested events are missing from // the database it returns a types.MissingEventError. If excluding rejected events, diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 42c0c8f2d..cc880a6c8 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -5,8 +5,9 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/roomserver/types" ) type RoomUpdater struct { @@ -186,6 +187,10 @@ func (u *RoomUpdater) EventIDs( return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) } +func (u *RoomUpdater) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) { + return u.d.EventsTable.BulkSelectSnapshotsFromEventIDs(ctx, u.txn, eventIDs) +} + func (u *RoomUpdater) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index ed86280bf..4455ec3bf 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -469,6 +469,23 @@ func (d *Database) events( eventNIDs = append(eventNIDs, nid) } } + // If we don't need to get any events from the database, short circuit now + if len(eventNIDs) == 0 { + results := make([]types.Event, 0, len(inputEventNIDs)) + for _, nid := range inputEventNIDs { + event, ok := events[nid] + if !ok || event == nil { + return nil, fmt.Errorf("event %d missing", nid) + } + results = append(results, types.Event{ + EventNID: nid, + Event: event, + }) + } + if !redactionsArePermanent { + d.applyRedactions(results) + } + } eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { return nil, err @@ -534,6 +551,12 @@ func (d *Database) events( return results, nil } +func (d *Database) BulkSelectSnapshotsFromEventIDs( + ctx context.Context, eventIDs []string, +) (map[types.StateSnapshotNID][]string, error) { + return d.EventsTable.BulkSelectSnapshotsFromEventIDs(ctx, nil, eventIDs) +} + func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 950d03b03..f39b9902d 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -23,11 +23,12 @@ import ( "sort" "strings" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const eventsSchema = ` @@ -57,6 +58,9 @@ const insertEventSQL = ` const selectEventSQL = "" + "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" +const bulkSelectSnapshotsForEventIDsSQL = "" + + "SELECT event_id, state_snapshot_nid FROM roomserver_events WHERE event_id IN ($1)" + // Bulk lookup of events by string ID. // Sort by the numeric IDs for event type and state key. // This means we can use binary search to lookup entries by type and state key. @@ -124,6 +128,7 @@ type eventStatements struct { db *sql.DB insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt + bulkSelectSnapshotsForEventIDsStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt @@ -153,6 +158,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, + {&s.bulkSelectSnapshotsForEventIDsStmt, bulkSelectSnapshotsForEventIDsSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, {&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, @@ -203,6 +209,40 @@ func (s *eventStatements) SelectEvent( return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } +func (s *eventStatements) BulkSelectSnapshotsFromEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (map[types.StateSnapshotNID][]string, error) { + qry := strings.Replace(bulkSelectSnapshotsForEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) + stmt, err := s.db.Prepare(qry) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "BulkSelectSnapshotsFromEventIDs: stmt.close() failed") + + params := make([]interface{}, len(eventIDs)) + for i := range eventIDs { + params[i] = eventIDs[i] + } + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "BulkSelectSnapshotsFromEventIDs: rows.close() failed") + + var eventID string + var stateNID types.StateSnapshotNID + result := make(map[types.StateSnapshotNID][]string) + for rows.Next() { + if err := rows.Scan(&eventID, &stateNID); err != nil { + return nil, err + } + result[stateNID] = append(result[stateNID], eventID) + } + + return result, rows.Err() +} + // bulkSelectStateEventByID lookups a list of state events by event ID. // If not excluding rejected events, and any of the requested events are missing from // the database it returns a types.MissingEventError. If excluding rejected events, diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 8d6ca324c..50d27c756 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -44,6 +44,7 @@ type Events interface { referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) + BulkSelectSnapshotsFromEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[types.StateSnapshotNID][]string, error) // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)