diff --git a/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go b/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go index 1a4c3a46b..82a77f2f9 100644 --- a/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/dendrite-monolith-server/main.go @@ -15,6 +15,7 @@ package main import ( + "context" "flag" "net/http" "os" @@ -264,13 +265,13 @@ func (m *monolith) setupProducers() { } func (m *monolith) setupNotifiers() { - pos, err := m.syncAPIDB.SyncStreamPosition() + pos, err := m.syncAPIDB.SyncStreamPosition(context.Background()) if err != nil { log.Panicf("startup: failed to get latest sync stream position : %s", err) } m.syncAPINotifier = syncapi_sync.NewNotifier(syncapi_types.StreamPosition(pos)) - if err = m.syncAPINotifier.Load(m.syncAPIDB); err != nil { + if err = m.syncAPINotifier.Load(context.Background(), m.syncAPIDB); err != nil { log.Panicf("startup: failed to set up notifier: %s", err) } } diff --git a/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go b/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go index 7e9e4c128..b57786313 100644 --- a/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go @@ -15,6 +15,7 @@ package main import ( + "context" "flag" "net/http" "os" @@ -67,13 +68,13 @@ func main() { log.Panicf("startup: failed to create account database with data source %s : %s", cfg.Database.Account, err) } - pos, err := db.SyncStreamPosition() + pos, err := db.SyncStreamPosition(context.Background()) if err != nil { log.Panicf("startup: failed to get latest sync stream position : %s", err) } n := sync.NewNotifier(types.StreamPosition(pos)) - if err = n.Load(db); err != nil { + if err = n.Load(context.Background(), db); err != nil { log.Panicf("startup: failed to set up notifier: %s", err) } diff --git a/src/github.com/matrix-org/dendrite/syncapi/consumers/clientapi.go b/src/github.com/matrix-org/dendrite/syncapi/consumers/clientapi.go index 7cc38b815..60ed01b66 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/consumers/clientapi.go +++ b/src/github.com/matrix-org/dendrite/syncapi/consumers/clientapi.go @@ -15,6 +15,7 @@ package consumers import ( + "context" "encoding/json" log "github.com/Sirupsen/logrus" @@ -77,7 +78,9 @@ func (s *OutputClientData) onMessage(msg *sarama.ConsumerMessage) error { "room_id": output.RoomID, }).Info("received data from client API server") - syncStreamPos, err := s.db.UpsertAccountData(string(msg.Key), output.RoomID, output.Type) + syncStreamPos, err := s.db.UpsertAccountData( + context.TODO(), string(msg.Key), output.RoomID, output.Type, + ) if err != nil { log.WithFields(log.Fields{ "type": output.Type, diff --git a/src/github.com/matrix-org/dendrite/syncapi/consumers/roomserver.go b/src/github.com/matrix-org/dendrite/syncapi/consumers/roomserver.go index 364a91a87..fc547e92b 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/consumers/roomserver.go +++ b/src/github.com/matrix-org/dendrite/syncapi/consumers/roomserver.go @@ -122,7 +122,11 @@ func (s *OutputRoomEvent) onMessage(msg *sarama.ConsumerMessage) error { } syncStreamPos, err := s.db.WriteEvent( - &ev, addsStateEvents, output.NewRoomEvent.AddsStateEventIDs, output.NewRoomEvent.RemovesStateEventIDs, + context.TODO(), + &ev, + addsStateEvents, + output.NewRoomEvent.AddsStateEventIDs, + output.NewRoomEvent.RemovesStateEventIDs, ) if err != nil { @@ -157,7 +161,7 @@ func (s *OutputRoomEvent) lookupStateEvents( // Check if this is re-adding a state events that we previously processed // If we have previously received a state event it may still be in // our event database. - result, err := s.db.Events(addsStateEventIDs) + result, err := s.db.Events(context.TODO(), addsStateEventIDs) if err != nil { return nil, err } @@ -205,7 +209,9 @@ func (s *OutputRoomEvent) updateStateEvent(event gomatrixserverlib.Event) (gomat stateKey = *event.StateKey() } - prevEvent, err := s.db.GetStateEvent(event.Type(), event.RoomID(), stateKey) + prevEvent, err := s.db.GetStateEvent( + context.TODO(), event.Type(), event.RoomID(), stateKey, + ) if err != nil { return event, err } diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/account_data_table.go b/src/github.com/matrix-org/dendrite/syncapi/storage/account_data_table.go index 183e0d192..4a23d6977 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/account_data_table.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/account_data_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "github.com/matrix-org/dendrite/syncapi/types" @@ -71,14 +72,18 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - pos types.StreamPosition, userID string, roomID string, dataType string, + ctx context.Context, + pos types.StreamPosition, + userID, roomID, dataType string, ) (err error) { - _, err = s.insertAccountDataStmt.Exec(pos, userID, roomID, dataType) + _, err = s.insertAccountDataStmt.ExecContext(ctx, pos, userID, roomID, dataType) return } func (s *accountDataStatements) selectAccountDataInRange( - userID string, oldPos types.StreamPosition, newPos types.StreamPosition, + ctx context.Context, + userID string, + oldPos, newPos types.StreamPosition, ) (data map[string][]string, err error) { data = make(map[string][]string) @@ -89,7 +94,7 @@ func (s *accountDataStatements) selectAccountDataInRange( oldPos-- } - rows, err := s.selectAccountDataInRangeStmt.Query(userID, oldPos, newPos) + rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos) if err != nil { return } diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go b/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go index 10933e965..307af4338 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/current_room_state_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "github.com/lib/pq" @@ -114,8 +115,10 @@ func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) { } // JoinedMemberLists returns a map of room ID to a list of joined user IDs. -func (s *currentRoomStateStatements) selectJoinedUsers() (map[string][]string, error) { - rows, err := s.selectJoinedUsersStmt.Query() +func (s *currentRoomStateStatements) selectJoinedUsers( + ctx context.Context, +) (map[string][]string, error) { + rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) if err != nil { return nil, err } @@ -136,8 +139,11 @@ func (s *currentRoomStateStatements) selectJoinedUsers() (map[string][]string, e } // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. -func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) { - rows, err := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt).Query(userID, membership) +func (s *currentRoomStateStatements) selectRoomIDsWithMembership( + ctx context.Context, txn *sql.Tx, userID, membership string, +) ([]string, error) { + stmt := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) + rows, err := stmt.QueryContext(ctx, userID, membership) if err != nil { return nil, err } @@ -155,8 +161,11 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership(txn *sql.Tx, us } // CurrentState returns all the current state events for the given room. -func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID string) ([]gomatrixserverlib.Event, error) { - rows, err := common.TxStmt(txn, s.selectCurrentStateStmt).Query(roomID) +func (s *currentRoomStateStatements) selectCurrentState( + ctx context.Context, txn *sql.Tx, roomID string, +) ([]gomatrixserverlib.Event, error) { + stmt := common.TxStmt(txn, s.selectCurrentStateStmt) + rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return nil, err } @@ -165,22 +174,37 @@ func (s *currentRoomStateStatements) selectCurrentState(txn *sql.Tx, roomID stri return rowsToEvents(rows) } -func (s *currentRoomStateStatements) deleteRoomStateByEventID(txn *sql.Tx, eventID string) error { - _, err := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt).Exec(eventID) +func (s *currentRoomStateStatements) deleteRoomStateByEventID( + ctx context.Context, txn *sql.Tx, eventID string, +) error { + stmt := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + _, err := stmt.ExecContext(ctx, eventID) return err } func (s *currentRoomStateStatements) upsertRoomState( - txn *sql.Tx, event gomatrixserverlib.Event, membership *string, addedAt int64, + ctx context.Context, txn *sql.Tx, + event gomatrixserverlib.Event, membership *string, addedAt int64, ) error { - _, err := common.TxStmt(txn, s.upsertRoomStateStmt).Exec( - event.RoomID(), event.EventID(), event.Type(), *event.StateKey(), event.JSON(), membership, addedAt, + stmt := common.TxStmt(txn, s.upsertRoomStateStmt) + _, err := stmt.ExecContext( + ctx, + event.RoomID(), + event.EventID(), + event.Type(), + *event.StateKey(), + event.JSON(), + membership, + addedAt, ) return err } -func (s *currentRoomStateStatements) selectEventsWithEventIDs(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { - rows, err := common.TxStmt(txn, s.selectEventsWithEventIDsStmt).Query(pq.StringArray(eventIDs)) +func (s *currentRoomStateStatements) selectEventsWithEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) ([]streamEvent, error) { + stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -205,11 +229,18 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) { return result, nil } -func (s *currentRoomStateStatements) selectStateEvent(evType string, roomID string, stateKey string) (*gomatrixserverlib.Event, error) { +func (s *currentRoomStateStatements) selectStateEvent( + ctx context.Context, evType string, roomID string, stateKey string, +) (*gomatrixserverlib.Event, error) { + stmt := s.selectStateEventStmt var res []byte - if err := s.selectStateEventStmt.QueryRow(evType, roomID, stateKey).Scan(&res); err == sql.ErrNoRows { + err := stmt.QueryRowContext(ctx, evType, roomID, stateKey).Scan(&res) + if err == sql.ErrNoRows { return nil, nil } + if err != nil { + return nil, err + } ev, err := gomatrixserverlib.NewEventFromTrustedJSON(res, false) return &ev, err } diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go b/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go index 2472754db..7ae24990a 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/output_room_events_table.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" log "github.com/Sirupsen/logrus" @@ -104,9 +105,11 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) selectStateInRange( - txn *sql.Tx, oldPos, newPos types.StreamPosition, + ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, ) (map[string]map[string]bool, map[string]streamEvent, error) { - rows, err := common.TxStmt(txn, s.selectStateInRangeStmt).Query(oldPos, newPos) + stmt := common.TxStmt(txn, s.selectStateInRangeStmt) + + rows, err := stmt.QueryContext(ctx, oldPos, newPos) if err != nil { return nil, nil, err } @@ -167,9 +170,12 @@ func (s *outputRoomEventsStatements) selectStateInRange( // MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied, // then this function should only ever be used at startup, as it will race with inserting events if it is // done afterwards. If there are no inserted events, 0 is returned. -func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err error) { +func (s *outputRoomEventsStatements) selectMaxID( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { var nullableID sql.NullInt64 - err = common.TxStmt(txn, s.selectMaxIDStmt).QueryRow().Scan(&nullableID) + stmt := common.TxStmt(txn, s.selectMaxIDStmt) + err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 } @@ -178,18 +184,29 @@ func (s *outputRoomEventsStatements) selectMaxID(txn *sql.Tx) (id int64, err err // InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position // of the inserted event. -func (s *outputRoomEventsStatements) insertEvent(txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string) (streamPos int64, err error) { - err = common.TxStmt(txn, s.insertEventStmt).QueryRow( - event.RoomID(), event.EventID(), event.JSON(), pq.StringArray(addState), pq.StringArray(removeState), +func (s *outputRoomEventsStatements) insertEvent( + ctx context.Context, txn *sql.Tx, + event *gomatrixserverlib.Event, addState, removeState []string, +) (streamPos int64, err error) { + stmt := common.TxStmt(txn, s.insertEventStmt) + err = stmt.QueryRowContext( + ctx, + event.RoomID(), + event.EventID(), + event.JSON(), + pq.StringArray(addState), + pq.StringArray(removeState), ).Scan(&streamPos) return } // RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. func (s *outputRoomEventsStatements) selectRecentEvents( - _ *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int, + ctx context.Context, txn *sql.Tx, + roomID string, fromPos, toPos types.StreamPosition, limit int, ) ([]streamEvent, error) { - rows, err := s.selectRecentEventsStmt.Query(roomID, fromPos, toPos, limit) + stmt := common.TxStmt(txn, s.selectRecentEventsStmt) + rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) if err != nil { return nil, err } @@ -205,8 +222,11 @@ func (s *outputRoomEventsStatements) selectRecentEvents( // Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing // from the database. -func (s *outputRoomEventsStatements) selectEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { - rows, err := common.TxStmt(txn, s.selectEventsStmt).Query(pq.StringArray(eventIDs)) +func (s *outputRoomEventsStatements) selectEvents( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) ([]streamEvent, error) { + stmt := common.TxStmt(txn, s.selectEventsStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go index f7d5ebd7b..925a1233f 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go @@ -15,6 +15,7 @@ package storage import ( + "context" "database/sql" "fmt" // Import the postgres database driver. @@ -75,16 +76,16 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { } // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. -func (d *SyncServerDatabase) AllJoinedUsersInRooms() (map[string][]string, error) { - return d.roomstate.selectJoinedUsers() +func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { + return d.roomstate.selectJoinedUsers(ctx) } // Events lookups a list of event by their event ID. // Returns a list of events matching the requested IDs found in the database. // If an event is not found in the database then it will be omitted from the list. // Returns an error if there was a problem talking with the database -func (d *SyncServerDatabase) Events(eventIDs []string) ([]gomatrixserverlib.Event, error) { - streamEvents, err := d.events.selectEvents(nil, eventIDs) +func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { + streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) if err != nil { return nil, err } @@ -95,11 +96,14 @@ func (d *SyncServerDatabase) Events(eventIDs []string) ([]gomatrixserverlib.Even // when generating the stream position for this event. Returns the sync stream position for the inserted event. // Returns an error if there was a problem inserting this event. func (d *SyncServerDatabase) WriteEvent( - ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, + ctx context.Context, + ev *gomatrixserverlib.Event, + addStateEvents []gomatrixserverlib.Event, + addStateEventIDs, removeStateEventIDs []string, ) (streamPos types.StreamPosition, returnErr error) { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error - pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs) + pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs) if err != nil { return err } @@ -110,17 +114,20 @@ func (d *SyncServerDatabase) WriteEvent( return nil } - return d.updateRoomState(txn, removeStateEventIDs, addStateEvents, streamPos) + return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, streamPos) }) return } func (d *SyncServerDatabase) updateRoomState( - txn *sql.Tx, removedEventIDs []string, addedEvents []gomatrixserverlib.Event, streamPos types.StreamPosition, + ctx context.Context, txn *sql.Tx, + removedEventIDs []string, + addedEvents []gomatrixserverlib.Event, + streamPos types.StreamPosition, ) error { // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. for _, eventID := range removedEventIDs { - if err := d.roomstate.deleteRoomStateByEventID(txn, eventID); err != nil { + if err := d.roomstate.deleteRoomStateByEventID(ctx, txn, eventID); err != nil { return err } } @@ -138,7 +145,7 @@ func (d *SyncServerDatabase) updateRoomState( } membership = &value } - if err := d.roomstate.upsertRoomState(txn, event, membership, int64(streamPos)); err != nil { + if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, int64(streamPos)); err != nil { return err } } @@ -149,8 +156,10 @@ func (d *SyncServerDatabase) updateRoomState( // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error -func (d *SyncServerDatabase) GetStateEvent(evType string, roomID string, stateKey string) (*gomatrixserverlib.Event, error) { - return d.roomstate.selectStateEvent(evType, roomID, stateKey) +func (d *SyncServerDatabase) GetStateEvent( + ctx context.Context, evType, roomID, stateKey string, +) (*gomatrixserverlib.Event, error) { + return d.roomstate.selectStateEvent(ctx, evType, roomID, stateKey) } // PartitionOffsets implements common.PartitionStorer @@ -164,8 +173,8 @@ func (d *SyncServerDatabase) SetPartitionOffset(topic string, partition int32, o } // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. -func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error) { - id, err := d.events.selectMaxID(nil) +func (d *SyncServerDatabase) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { + id, err := d.events.selectMaxID(ctx, nil) if err != nil { return types.StreamPosition(0), err } @@ -173,13 +182,18 @@ func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error) } // IncrementalSync returns all the data needed in order to create an incremental sync response. -func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int) (res *types.Response, returnErr error) { +func (d *SyncServerDatabase) IncrementalSync( + ctx context.Context, + userID string, + fromPos, toPos types.StreamPosition, + numRecentEventsPerRoom int, +) (res *types.Response, returnErr error) { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { // Work out which rooms to return in the response. This is done by getting not only the currently // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. // This works out what the 'state' key should be for each room as well as which membership block // to put the room into. - deltas, err := d.getStateDeltas(txn, fromPos, toPos, userID) + deltas, err := d.getStateDeltas(ctx, txn, fromPos, toPos, userID) if err != nil { return err } @@ -196,7 +210,9 @@ func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types // This is all "okay" assuming history_visibility == "shared" which it is by default. endPos = delta.membershipPos } - recentStreamEvents, err := d.events.selectRecentEvents(txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom) + recentStreamEvents, err := d.events.selectRecentEvents( + ctx, txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom, + ) if err != nil { return err } @@ -224,27 +240,29 @@ func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types } // TODO: This should be done in getStateDeltas - return d.addInvitesToResponse(txn, userID, res) + return d.addInvitesToResponse(ctx, txn, userID, res) }) return } // CompleteSync a complete /sync API response for the given user. -func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom int) (res *types.Response, returnErr error) { +func (d *SyncServerDatabase) CompleteSync( + ctx context.Context, userID string, numRecentEventsPerRoom int, +) (res *types.Response, returnErr error) { // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have // a consistent view of the database throughout. This includes extracting the sync stream position. // This does have the unfortunate side-effect that all the matrixy logic resides in this function, // but it's better to not hide the fact that this is being done in a transaction. returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { // Get the current stream position which we will base the sync response on. - id, err := d.events.selectMaxID(txn) + id, err := d.events.selectMaxID(ctx, txn) if err != nil { return err } pos := types.StreamPosition(id) // Extract room state and recent events for all rooms the user is joined to. - roomIDs, err := d.roomstate.selectRoomIDsWithMembership(txn, userID, "join") + roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") if err != nil { return err } @@ -252,14 +270,14 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom // Build up a /sync response. Add joined rooms. res = types.NewResponse(pos) for _, roomID := range roomIDs { - stateEvents, err := d.roomstate.selectCurrentState(txn, roomID) + stateEvents, err := d.roomstate.selectCurrentState(ctx, txn, roomID) if err != nil { return err } // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 recentStreamEvents, err := d.events.selectRecentEvents( - txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom, + ctx, txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom, ) if err != nil { return err @@ -274,7 +292,7 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom res.Rooms.Join[roomID] = *jr } - return d.addInvitesToResponse(txn, userID, res) + return d.addInvitesToResponse(ctx, txn, userID, res) }) return } @@ -285,9 +303,9 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom // If no data is retrieved, returns an empty map // If there was an issue with the retrieval, returns an error func (d *SyncServerDatabase) GetAccountDataInRange( - userID string, oldPos types.StreamPosition, newPos types.StreamPosition, + ctx context.Context, userID string, oldPos, newPos types.StreamPosition, ) (map[string][]string, error) { - return d.accountData.selectAccountDataInRange(userID, oldPos, newPos) + return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos) } // UpsertAccountData keeps track of new or updated account data, by saving the type @@ -296,19 +314,22 @@ func (d *SyncServerDatabase) GetAccountDataInRange( // If no data with the given type, user ID and room ID exists in the database, // creates a new row, else update the existing one // Returns an error if there was an issue with the upsert -func (d *SyncServerDatabase) UpsertAccountData(userID string, roomID string, dataType string) (types.StreamPosition, error) { - pos, err := d.SyncStreamPosition() +func (d *SyncServerDatabase) UpsertAccountData( + ctx context.Context, userID, roomID, dataType string, +) (types.StreamPosition, error) { + pos, err := d.SyncStreamPosition(ctx) if err != nil { return pos, err } - err = d.accountData.insertAccountData(pos, userID, roomID, dataType) + err = d.accountData.insertAccountData(ctx, pos, userID, roomID, dataType) return pos, err } -func (d *SyncServerDatabase) addInvitesToResponse(txn *sql.Tx, userID string, res *types.Response) error { +func (d *SyncServerDatabase) addInvitesToResponse( + ctx context.Context, txn *sql.Tx, userID string, res *types.Response) error { // Add invites - TODO: This will break over federation as they won't be in the current state table according to Mark. - roomIDs, err := d.roomstate.selectRoomIDsWithMembership(txn, userID, "invite") + roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "invite") if err != nil { return err } @@ -322,7 +343,11 @@ func (d *SyncServerDatabase) addInvitesToResponse(txn *sql.Tx, userID string, re // fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. // Returns a map of room ID to list of events. -func (d *SyncServerDatabase) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]streamEvent) (map[string][]streamEvent, error) { +func (d *SyncServerDatabase) fetchStateEvents( + ctx context.Context, txn *sql.Tx, + roomIDToEventIDSet map[string]map[string]bool, + eventIDToEvent map[string]streamEvent, +) (map[string][]streamEvent, error) { stateBetween := make(map[string][]streamEvent) missingEvents := make(map[string][]string) for roomID, ids := range roomIDToEventIDSet { @@ -350,7 +375,7 @@ func (d *SyncServerDatabase) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet ma for _, missingEvIDs := range missingEvents { allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...) } - evs, err := d.fetchMissingStateEvents(txn, allMissingEventIDs) + evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs) if err != nil { return nil, err } @@ -363,10 +388,12 @@ func (d *SyncServerDatabase) fetchStateEvents(txn *sql.Tx, roomIDToEventIDSet ma return stateBetween, nil } -func (d *SyncServerDatabase) fetchMissingStateEvents(txn *sql.Tx, eventIDs []string) ([]streamEvent, error) { +func (d *SyncServerDatabase) fetchMissingStateEvents( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) ([]streamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the // event. - events, err := d.events.selectEvents(txn, eventIDs) + events, err := d.events.selectEvents(ctx, txn, eventIDs) if err != nil { return nil, err } @@ -388,7 +415,7 @@ func (d *SyncServerDatabase) fetchMissingStateEvents(txn *sql.Tx, eventIDs []str // If they are missing from the events table then they should be state // events that we received from outside the main event stream. // These should be in the room state table. - stateEvents, err := d.roomstate.selectEventsWithEventIDs(txn, missing) + stateEvents, err := d.roomstate.selectEventsWithEventIDs(ctx, txn, missing) if err != nil { return nil, err @@ -402,7 +429,10 @@ func (d *SyncServerDatabase) fetchMissingStateEvents(txn *sql.Tx, eventIDs []str return events, nil } -func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.StreamPosition, userID string) ([]stateDelta, error) { +func (d *SyncServerDatabase) getStateDeltas( + ctx context.Context, txn *sql.Tx, + fromPos, toPos types.StreamPosition, userID string, +) ([]stateDelta, error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response // - For each room which has membership list changes: @@ -414,11 +444,11 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St var deltas []stateDelta // get all the state events ever between these two positions - stateNeeded, eventMap, err := d.events.selectStateInRange(txn, fromPos, toPos) + stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos) if err != nil { return nil, err } - state, err := d.fetchStateEvents(txn, stateNeeded, eventMap) + state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) if err != nil { return nil, err } @@ -434,7 +464,7 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St if membership == "join" { // send full room state down instead of a delta var allState []gomatrixserverlib.Event - allState, err = d.roomstate.selectCurrentState(txn, roomID) + allState, err = d.roomstate.selectCurrentState(ctx, txn, roomID) if err != nil { return nil, err } @@ -458,7 +488,7 @@ func (d *SyncServerDatabase) getStateDeltas(txn *sql.Tx, fromPos, toPos types.St } // Add in currently joined rooms - joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(txn, userID, "join") + joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") if err != nil { return nil, err } diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go index 42007e158..9d0c8be5f 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go @@ -15,6 +15,7 @@ package sync import ( + "context" "sync" log "github.com/Sirupsen/logrus" @@ -131,8 +132,8 @@ func (n *Notifier) WaitForEvents(req syncRequest) types.StreamPosition { } // Load the membership states required to notify users correctly. -func (n *Notifier) Load(db *storage.SyncServerDatabase) error { - roomToUsers, err := db.AllJoinedUsersInRooms() +func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatabase) error { + roomToUsers, err := db.AllJoinedUsersInRooms(ctx) if err != nil { return err } diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go index 922ee5d88..c9b86a6f1 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go @@ -108,9 +108,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (*types.Response, error) { // TODO: handle ignored users if req.since == types.StreamPosition(0) { - return rp.db.CompleteSync(req.userID, req.limit) + return rp.db.CompleteSync(req.ctx, req.userID, req.limit) } - return rp.db.IncrementalSync(req.userID, req.since, currentPos, req.limit) + return rp.db.IncrementalSync(req.ctx, req.userID, req.since, currentPos, req.limit) } func (rp *RequestPool) appendAccountData( @@ -145,7 +145,7 @@ func (rp *RequestPool) appendAccountData( } // Sync is not initial, get all account data since the latest sync - dataTypes, err := rp.db.GetAccountDataInRange(userID, req.since, currentPos) + dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, req.since, currentPos) if err != nil { return nil, err }