diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index d26c95d1c..adc25661d 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -38,9 +38,9 @@ type RoomEventDatabase interface { // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // Returns the latest events in the room and the last eventID sent to the log along with an updater. // If this returns an error then no further action is required. - GetLatestEventsForUpdate(roomNID types.RoomNID) ( - latestEvents []types.StateAtEventAndReference, lastEventIDSent string, updater types.RoomRecentEventsUpdater, err error, - ) + GetLatestEventsForUpdate(roomNID types.RoomNID) (updater types.RoomRecentEventsUpdater, err error) + // Lookup the string event IDs for a list of numeric event IDs + EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) } // OutputRoomEventWriter has the APIs needed to write an event to the output logs. @@ -91,7 +91,7 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api. } } else { // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = calculateAndStoreState(db, event, roomNID); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = calculateAndStoreStateBeforeEvent(db, event, roomNID); err != nil { return err } } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go b/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go index 55b712d82..feaeccdb0 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go @@ -26,7 +26,7 @@ import ( func updateLatestEvents( db RoomEventDatabase, ow OutputRoomEventWriter, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, ) (err error) { - oldLatest, lastEventIDSent, updater, err := db.GetLatestEventsForUpdate(roomNID) + updater, err := db.GetLatestEventsForUpdate(roomNID) if err != nil { return } @@ -44,16 +44,19 @@ func updateLatestEvents( } }() - err = doUpdateLatestEvents(updater, ow, oldLatest, lastEventIDSent, roomNID, stateAtEvent, event) + err = doUpdateLatestEvents(db, updater, ow, roomNID, stateAtEvent, event) return } func doUpdateLatestEvents( - updater types.RoomRecentEventsUpdater, ow OutputRoomEventWriter, oldLatest []types.StateAtEventAndReference, lastEventIDSent string, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, + db RoomEventDatabase, updater types.RoomRecentEventsUpdater, ow OutputRoomEventWriter, roomNID types.RoomNID, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, ) error { var err error var prevEvents []gomatrixserverlib.EventReference prevEvents = event.PrevEvents() + oldLatest := updater.LatestEvents() + lastEventIDSent := updater.LastEventIDSent() + oldStateNID := updater.CurrentStateSnapshotNID() if hasBeenSent, err := updater.HasEventBeenSent(stateAtEvent.EventNID); err != nil { return err @@ -78,6 +81,20 @@ func doUpdateLatestEvents( StateAtEvent: stateAtEvent, }) + latestStateAtEvents := make([]types.StateAtEvent, len(newLatest)) + for i := range newLatest { + latestStateAtEvents[i] = newLatest[i].StateAtEvent + } + newStateNID, err := calculateAndStoreStateAfterEvents(db, roomNID, latestStateAtEvents) + if err != nil { + return err + } + + removed, added, err := differenceBetweeenStateSnapshots(db, oldStateNID, newStateNID) + if err != nil { + return err + } + // Send the event to the output logs. // We do this inside the database transaction to ensure that we only mark an event as sent if we sent it. // (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but @@ -86,11 +103,11 @@ func doUpdateLatestEvents( // send the event asynchronously but we would need to ensure that 1) the events are written to the log in // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // necessary bookkeeping we'll keep the event sending synchronous for now. - if err = writeEvent(ow, lastEventIDSent, event, newLatest); err != nil { + if err = writeEvent(db, ow, lastEventIDSent, event, newLatest, removed, added); err != nil { return err } - if err = updater.SetLatestEvents(roomNID, newLatest, stateAtEvent.EventNID); err != nil { + if err = updater.SetLatestEvents(roomNID, newLatest, stateAtEvent.EventNID, newStateNID); err != nil { return err } @@ -134,18 +151,41 @@ func calculateLatest(oldLatest []types.StateAtEventAndReference, alreadyReferenc return newLatest } -func writeEvent(ow OutputRoomEventWriter, lastEventIDSent string, event gomatrixserverlib.Event, latest []types.StateAtEventAndReference) error { +func writeEvent( + db RoomEventDatabase, ow OutputRoomEventWriter, lastEventIDSent string, + event gomatrixserverlib.Event, latest []types.StateAtEventAndReference, + removed, added []types.StateEntry, +) error { latestEventIDs := make([]string, len(latest)) for i := range latest { latestEventIDs[i] = latest[i].EventID } - // TODO: Fill out AddsStateEventIDs and RemovesStateEventIDs - // TODO: Fill out VisibilityStateIDs - return ow.WriteOutputRoomEvent(api.OutputRoomEvent{ + ore := api.OutputRoomEvent{ Event: event.JSON(), LastSentEventID: lastEventIDSent, LatestEventIDs: latestEventIDs, - }) + } + + var stateEventNIDs []types.EventNID + for _, entry := range added { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + for _, entry := range removed { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + eventIDMap, err := db.EventIDs(stateEventNIDs) + if err != nil { + return err + } + for _, entry := range added { + ore.AddsStateEventIDs = append(ore.AddsStateEventIDs, eventIDMap[entry.EventNID]) + } + for _, entry := range removed { + ore.RemovesStateEventIDs = append(ore.RemovesStateEventIDs, eventIDMap[entry.EventNID]) + } + + // TODO: Fill out VisibilityStateIDs + return ow.WriteOutputRoomEvent(ore) } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/state.go b/src/github.com/matrix-org/dendrite/roomserver/input/state.go index c46dc6e14..36ab43b1c 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/state.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/state.go @@ -9,8 +9,8 @@ import ( // calculateAndStoreState calculates a snapshot of the state of a room before an event. // Stores the snapshot of the state in the database. -// Returns a numeric ID for that snapshot. -func calculateAndStoreState( +// Returns a numeric ID for the snapshot of the state before the event. +func calculateAndStoreStateBeforeEvent( db RoomEventDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID, ) (types.StateSnapshotNID, error) { // Load the state at the prev events. @@ -25,6 +25,13 @@ func calculateAndStoreState( return 0, err } + // The state before this event will be the state after the events that came before it. + return calculateAndStoreStateAfterEvents(db, roomNID, prevStates) +} + +// calculateAndStoreStateAfterEvents finds the room state after the given events. +// Stores the resulting state in the database and returns a numeric ID for that snapshot. +func calculateAndStoreStateAfterEvents(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) { if len(prevStates) == 0 { // 2) There weren't any prev_events for this event so the state is // empty. @@ -55,9 +62,9 @@ func calculateAndStoreState( ) } // If there are too many deltas then we need to calculate the full state - // So fall through to calculateAndStoreStateMany + // So fall through to calculateAndStoreStateAfterManyEvents } - return calculateAndStoreStateMany(db, roomNID, prevStates) + return calculateAndStoreStateAfterManyEvents(db, roomNID, prevStates) } // maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. @@ -67,10 +74,10 @@ func calculateAndStoreState( // TODO: Tune this to get the right balance between size and lookup performance. const maxStateBlockNIDs = 64 -// calculateAndStoreStateMany calculates the state of the room before an event -// using the states at each of the event's prev events. +// calculateAndStoreStateAfterManyEvents finds the room state after the given events. +// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event. // Stores the resulting state and returns a numeric ID for the snapshot. -func calculateAndStoreStateMany(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) { +func calculateAndStoreStateAfterManyEvents(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) { // Conflict resolution. // First stage: load the state after each of the prev events. combined, err := loadCombinedStateAfterEvents(db, prevStates) @@ -107,6 +114,98 @@ func calculateAndStoreStateMany(db RoomEventDatabase, roomNID types.RoomNID, pre return db.AddState(roomNID, nil, state) } +// differenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots. +func differenceBetweeenStateSnapshots(db RoomEventDatabase, oldStateNID, newStateNID types.StateSnapshotNID) ( + removed, added []types.StateEntry, err error, +) { + if oldStateNID == newStateNID { + // If the snapshot NIDs are the same then nothing has changed + return nil, nil, nil + } + + var oldEntries []types.StateEntry + var newEntries []types.StateEntry + if oldStateNID != 0 { + oldEntries, err = loadStateAtSnapshot(db, oldStateNID) + if err != nil { + return nil, nil, err + } + } + if newStateNID != 0 { + newEntries, err = loadStateAtSnapshot(db, newStateNID) + if err != nil { + return nil, nil, err + } + } + + var oldI int + var newI int + for { + switch { + case oldI == len(oldEntries): + // We've reached the end of the old entries. + // The rest of the new list must have been newly added. + added = append(added, newEntries[newI:]...) + return + case newI == len(newEntries): + // We've reached the end of the new entries. + // The rest of the old list must be have been removed. + removed = append(removed, oldEntries[oldI:]...) + return + case oldEntries[oldI] == newEntries[newI]: + // The entry is in both lists so skip over it. + oldI++ + newI++ + case oldEntries[oldI].LessThan(newEntries[newI]): + // The lists are sorted so the old entry being less than the new entry means that it only appears in the old list. + removed = append(removed, oldEntries[oldI]) + oldI++ + default: + // Reaching the default case implies that the new entry is less than the old entry. + // Since the lists are sorted this means that it only appears in the new list. + added = append(added, newEntries[newI]) + newI++ + } + } +} + +// loadStateAtSnapshot loads the full state of a room at a particular snapshot. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func loadStateAtSnapshot(db RoomEventDatabase, stateNID types.StateSnapshotNID) ([]types.StateEntry, error) { + stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combined all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + // loadCombinedStateAfterEvents loads a snapshot of the state after each of the events // and combines those snapshots together into a single list. func loadCombinedStateAfterEvents(db RoomEventDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) { @@ -146,18 +245,18 @@ func loadCombinedStateAfterEvents(db RoomEventDatabase, prevStates []types.State if !ok { // This should only get hit if the database is corrupt. // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state numeric ID %d", prevState.BeforeStateSnapshotNID)) + panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) } // Combined all the state entries for this snapshot. - // The order of state data NIDs in the list tells us the order to combine them in. + // The order of state block NIDs in the list tells us the order to combine them in. var fullState []types.StateEntry for _, stateBlockNID := range stateBlockNIDs { entries, ok := stateEntriesMap.lookup(stateBlockNID) if !ok { // This should only get hit if the database is corrupt. // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state numeric ID %d", prevState.BeforeStateSnapshotNID)) + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) } fullState = append(fullState, entries...) } diff --git a/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go b/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go index b8766ca87..e7cd2d1fa 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go +++ b/src/github.com/matrix-org/dendrite/roomserver/roomserver-integration-tests/main.go @@ -356,7 +356,7 @@ func main() { }, "VisibilityEventIDs":null, "LatestEventIDs":["$1463671339126270PnVwC:matrix.org"], - "AddsStateEventIDs":null, + "AddsStateEventIDs":["$1463671337126266wrSBX:matrix.org", "$1463671339126270PnVwC:matrix.org"], "RemovesStateEventIDs":null, "LastSentEventID":"" }`, diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go index 471bdcfbb..63cfa266b 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go @@ -87,6 +87,9 @@ const bulkSelectStateAtEventAndReferenceSQL = "" + const bulkSelectEventReferenceSQL = "" + "SELECT event_id, reference_sha256 FROM events WHERE event_nid = ANY($1)" +const bulkSelectEventIDSQL = "" + + "SELECT event_nid, event_id FROM events WHERE event_nid = ANY($1)" + type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt @@ -98,6 +101,7 @@ type eventStatements struct { selectEventIDStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt + bulkSelectEventIDStmt *sql.Stmt } func (s *eventStatements) prepare(db *sql.DB) (err error) { @@ -105,36 +109,30 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { - return + + statements := []struct { + statement **sql.Stmt + sql string + }{ + {&s.insertEventStmt, insertEventSQL}, + {&s.selectEventStmt, selectEventSQL}, + {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, + {&s.updateEventStateStmt, updateEventStateSQL}, + {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, + {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, + {&s.selectEventIDStmt, selectEventIDSQL}, + {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, + {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, + {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, } - if s.selectEventStmt, err = db.Prepare(selectEventSQL); err != nil { - return - } - if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil { - return - } - if s.bulkSelectStateAtEventByIDStmt, err = db.Prepare(bulkSelectStateAtEventByIDSQL); err != nil { - return - } - if s.updateEventStateStmt, err = db.Prepare(updateEventStateSQL); err != nil { - return - } - if s.updateEventSentToOutputStmt, err = db.Prepare(updateEventSentToOutputSQL); err != nil { - return - } - if s.selectEventSentToOutputStmt, err = db.Prepare(selectEventSentToOutputSQL); err != nil { - return - } - if s.selectEventIDStmt, err = db.Prepare(selectEventIDSQL); err != nil { - return - } - if s.bulkSelectStateAtEventAndReferenceStmt, err = db.Prepare(bulkSelectStateAtEventAndReferenceSQL); err != nil { - return - } - if s.bulkSelectEventReferenceStmt, err = db.Prepare(bulkSelectEventReferenceSQL); err != nil { - return + + for _, statement := range statements { + if *statement.statement, err = db.Prepare(statement.sql); err != nil { + return + } } + return } @@ -297,6 +295,29 @@ func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) ( return results, nil } +// bulkSelectEventID returns a map from numeric event ID to string event ID. +func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + rows, err := s.bulkSelectEventIDStmt.Query(eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() + results := make(map[types.EventNID]string, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + var eventNID int64 + var eventID string + if err = rows.Scan(&eventNID, &eventID); err != nil { + return nil, err + } + results[types.EventNID(eventNID)] = eventID + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { nids := make([]int64, len(eventNIDs)) for i := range eventNIDs { diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go index a6be8fcb5..93a65e473 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/rooms_table.go @@ -18,7 +18,10 @@ CREATE TABLE IF NOT EXISTS rooms ( -- (The server will be in that state while it stores the events for the initial state of the room) latest_event_nids BIGINT[] NOT NULL DEFAULT '{}'::BIGINT[], -- The last event written to the output log for this room. - last_event_sent_nid BIGINT NOT NULL DEFAULT 0 + last_event_sent_nid BIGINT NOT NULL DEFAULT 0, + -- The state of the room after the current set of latest events. + -- This will be 0 if there are no latest events in the room. + state_snapshot_nid BIGINT NOT NULL DEFAULT 0 ); ` @@ -35,10 +38,10 @@ const selectLatestEventNIDsSQL = "" + "SELECT latest_event_nids FROM rooms WHERE room_nid = $1" const selectLatestEventNIDsForUpdateSQL = "" + - "SELECT latest_event_nids, last_event_sent_nid FROM rooms WHERE room_nid = $1 FOR UPDATE" + "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM rooms WHERE room_nid = $1 FOR UPDATE" const updateLatestEventNIDsSQL = "" + - "UPDATE rooms SET latest_event_nids = $2, last_event_sent_nid = $3 WHERE room_nid = $1" + "UPDATE rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1" type roomStatements struct { insertRoomNIDStmt *sql.Stmt @@ -96,21 +99,29 @@ func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.E return eventNIDs, nil } -func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, error) { +func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) ( + []types.EventNID, types.EventNID, types.StateSnapshotNID, error, +) { var nids pq.Int64Array var lastEventSentNID int64 - err := txn.Stmt(s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID) + var stateSnapshotNID int64 + err := txn.Stmt(s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) if err != nil { - return nil, 0, err + return nil, 0, 0, err } eventNIDs := make([]types.EventNID, len(nids)) for i := range nids { eventNIDs[i] = types.EventNID(nids[i]) } - return eventNIDs, types.EventNID(lastEventSentNID), nil + return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil } -func (s *roomStatements) updateLatestEventNIDs(txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID) error { - _, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec(roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID)) +func (s *roomStatements) updateLatestEventNIDs( + txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, + stateSnapshotNID types.StateSnapshotNID, +) error { + _, err := txn.Stmt(s.updateLatestEventNIDsStmt).Exec( + roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID), + ) return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index bede936ad..989d91b0b 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -205,38 +205,62 @@ func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.S return d.statements.bulkSelectStateDataEntries(stateBlockNIDs) } +// EventIDs implements input.RoomEventDatabase +func (d *Database) EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + return d.statements.bulkSelectEventID(eventNIDs) +} + // GetLatestEventsForUpdate implements input.EventDatabase -func (d *Database) GetLatestEventsForUpdate(roomNID types.RoomNID) ([]types.StateAtEventAndReference, string, types.RoomRecentEventsUpdater, error) { +func (d *Database) GetLatestEventsForUpdate(roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) { txn, err := d.db.Begin() if err != nil { - return nil, "", nil, err + return nil, err } - eventNIDs, lastEventNIDSent, err := d.statements.selectLatestEventsNIDsForUpdate(txn, roomNID) + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := d.statements.selectLatestEventsNIDsForUpdate(txn, roomNID) if err != nil { txn.Rollback() - return nil, "", nil, err + return nil, err } stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(txn, eventNIDs) if err != nil { txn.Rollback() - return nil, "", nil, err + return nil, err } var lastEventIDSent string if lastEventNIDSent != 0 { lastEventIDSent, err = d.statements.selectEventID(txn, lastEventNIDSent) if err != nil { txn.Rollback() - return nil, "", nil, err + return nil, err } } - return stateAndRefs, lastEventIDSent, &roomRecentEventsUpdater{txn, d}, nil + return &roomRecentEventsUpdater{txn, d, stateAndRefs, lastEventIDSent, currentStateSnapshotNID}, nil } type roomRecentEventsUpdater struct { - txn *sql.Tx - d *Database + txn *sql.Tx + d *Database + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID } +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { for _, ref := range previousEventReferences { if err := u.d.statements.insertPreviousEvent(u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { @@ -246,6 +270,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p return nil } +// IsReferenced implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { err := u.d.statements.selectPreviousEventExists(u.txn, eventReference.EventID, eventReference.EventSHA256) if err == nil { @@ -257,26 +282,34 @@ func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib. return false, err } -func (u *roomRecentEventsUpdater) SetLatestEvents(roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID) error { +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { eventNIDs := make([]types.EventNID, len(latest)) for i := range latest { eventNIDs[i] = latest[i].EventNID } - return u.d.statements.updateLatestEventNIDs(u.txn, roomNID, eventNIDs, lastEventNIDSent) + return u.d.statements.updateLatestEventNIDs(u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) } +// HasEventBeenSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { return u.d.statements.selectEventSentToOutput(u.txn, eventNID) } +// MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { return u.d.statements.updateEventSentToOutput(u.txn, eventNID) } +// Commit implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) Commit() error { return u.txn.Commit() } +// Rollback implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) Rollback() error { return u.txn.Rollback() } diff --git a/src/github.com/matrix-org/dendrite/roomserver/types/types.go b/src/github.com/matrix-org/dendrite/roomserver/types/types.go index 0a90d40d1..1d268e1f4 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/types/types.go +++ b/src/github.com/matrix-org/dendrite/roomserver/types/types.go @@ -133,6 +133,12 @@ type StateEntryList struct { // (On postgresql this wraps a database transaction that holds a "FOR UPDATE" // lock on the row holding the latest events for the room.) type RoomRecentEventsUpdater interface { + // The latest event IDs and state in the room. + LatestEvents() []StateAtEventAndReference + // The event ID of the latest event written to the output log in the room. + LastEventIDSent() string + // The current state of the room. + CurrentStateSnapshotNID() StateSnapshotNID // Store the previous events referenced by an event. // This adds the event NID to an entry in the database for each of the previous events. // If there isn't an entry for one of previous events then an entry is created. @@ -143,7 +149,10 @@ type RoomRecentEventsUpdater interface { IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) // Set the list of latest events for the room. // This replaces the current list stored in the database with the given list - SetLatestEvents(roomNID RoomNID, latest []StateAtEventAndReference, lastEventNIDSent EventNID) error + SetLatestEvents( + roomNID RoomNID, latest []StateAtEventAndReference, lastEventNIDSent EventNID, + currentStateSnapshotNID StateSnapshotNID, + ) error // Check if the event has already be written to the output logs. HasEventBeenSent(eventNID EventNID) (bool, error) // Mark the event as having been sent to the output logs.