0
0
Fork 0
mirror of https://github.com/matrix-org/dendrite synced 2024-12-14 19:33:50 +01:00

Add contexts to the roomserver storage layer (#229)

* Add contexts to the roomserver storage layer

* Fix rooms_table
This commit is contained in:
Mark Haines 2017-09-13 16:30:19 +01:00 committed by GitHub
parent 3133bef797
commit bfcce5bd21
21 changed files with 744 additions and 379 deletions

View file

@ -32,16 +32,16 @@ import (
type RoomserverAliasAPIDatabase interface { type RoomserverAliasAPIDatabase interface {
// Save a given room alias with the room ID it refers to. // Save a given room alias with the room ID it refers to.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
SetRoomAlias(alias string, roomID string) error SetRoomAlias(ctx context.Context, alias string, roomID string) error
// Look up the room ID a given alias refers to. // Look up the room ID a given alias refers to.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetRoomIDFromAlias(alias string) (string, error) GetRoomIDFromAlias(ctx context.Context, alias string) (string, error)
// Look up all aliases referring to a given room ID. // Look up all aliases referring to a given room ID.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetAliasesFromRoomID(roomID string) ([]string, error) GetAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error)
// Remove a given room alias. // Remove a given room alias.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
RemoveRoomAlias(alias string) error RemoveRoomAlias(ctx context.Context, alias string) error
} }
// RoomserverAliasAPI is an implementation of api.RoomserverAliasAPI // RoomserverAliasAPI is an implementation of api.RoomserverAliasAPI
@ -59,7 +59,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias(
response *api.SetRoomAliasResponse, response *api.SetRoomAliasResponse,
) error { ) error {
// Check if the alias isn't already referring to a room // Check if the alias isn't already referring to a room
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias) roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
if err != nil { if err != nil {
return err return err
} }
@ -71,7 +71,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias(
response.AliasExists = false response.AliasExists = false
// Save the new alias // Save the new alias
if err := r.DB.SetRoomAlias(request.Alias, request.RoomID); err != nil { if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID); err != nil {
return err return err
} }
@ -93,7 +93,7 @@ func (r *RoomserverAliasAPI) GetAliasRoomID(
response *api.GetAliasRoomIDResponse, response *api.GetAliasRoomIDResponse,
) error { ) error {
// Look up the room ID in the database // Look up the room ID in the database
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias) roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
if err != nil { if err != nil {
return err return err
} }
@ -109,18 +109,21 @@ func (r *RoomserverAliasAPI) RemoveRoomAlias(
response *api.RemoveRoomAliasResponse, response *api.RemoveRoomAliasResponse,
) error { ) error {
// Look up the room ID in the database // Look up the room ID in the database
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias) roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
if err != nil { if err != nil {
return err return err
} }
// Remove the dalias from the database // Remove the dalias from the database
if err := r.DB.RemoveRoomAlias(request.Alias); err != nil { if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil {
return err return err
} }
// Send an updated m.room.aliases event // Send an updated m.room.aliases event
if err := r.sendUpdatedAliasesEvent(ctx, request.UserID, roomID); err != nil { // At this point we've already committed the alias to the database so we
// shouldn't cancel this request.
// TODO: Ensure that we send unsent events when if server restarts.
if err := r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, roomID); err != nil {
return err return err
} }
@ -147,7 +150,7 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent(
// Retrieve the updated list of aliases, marhal it and set it as the // Retrieve the updated list of aliases, marhal it and set it as the
// event's content // event's content
aliases, err := r.DB.GetAliasesFromRoomID(roomID) aliases, err := r.DB.GetAliasesFromRoomID(ctx, roomID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -15,16 +15,23 @@
package input package input
import ( import (
"context"
"sort"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"sort"
) )
// checkAuthEvents checks that the event passes authentication checks // checkAuthEvents checks that the event passes authentication checks
// Returns the numeric IDs for the auth events. // Returns the numeric IDs for the auth events.
func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]types.EventNID, error) { func checkAuthEvents(
ctx context.Context,
db RoomEventDatabase,
event gomatrixserverlib.Event,
authEventIDs []string,
) ([]types.EventNID, error) {
// Grab the numeric IDs for the supplied auth state events from the database. // Grab the numeric IDs for the supplied auth state events from the database.
authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs) authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -34,7 +41,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event})
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(db, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -84,7 +91,10 @@ func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Even
} }
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event { func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID}) eventNID, ok := ae.state.lookup(types.StateKeyTuple{
EventTypeNID: typeNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
if !ok { if !ok {
return nil return nil
} }
@ -100,7 +110,10 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
if !ok { if !ok {
return nil return nil
} }
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, stateKeyNID}) eventNID, ok := ae.state.lookup(types.StateKeyTuple{
EventTypeNID: typeNID,
EventStateKeyNID: stateKeyNID,
})
if !ok { if !ok {
return nil return nil
} }
@ -113,6 +126,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
// loadAuthEvents loads the events needed for authentication from the supplied room state. // loadAuthEvents loads the events needed for authentication from the supplied room state.
func loadAuthEvents( func loadAuthEvents(
ctx context.Context,
db RoomEventDatabase, db RoomEventDatabase,
needed gomatrixserverlib.StateNeeded, needed gomatrixserverlib.StateNeeded,
state []types.StateEntry, state []types.StateEntry,
@ -121,7 +135,7 @@ func loadAuthEvents(
var neededStateKeys []string var neededStateKeys []string
neededStateKeys = append(neededStateKeys, needed.Member...) neededStateKeys = append(neededStateKeys, needed.Member...)
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(neededStateKeys); err != nil { if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil {
return return
} }
@ -135,34 +149,52 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID) eventNIDs = append(eventNIDs, eventNID)
} }
} }
if result.events, err = db.Events(eventNIDs); err != nil { if result.events, err = db.Events(ctx, eventNIDs); err != nil {
return return
} }
return return
} }
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. // stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { func stateKeyTuplesNeeded(
stateKeyNIDMap map[string]types.EventStateKeyNID,
stateNeeded gomatrixserverlib.StateNeeded,
) []types.StateKeyTuple {
var keyTuples []types.StateKeyTuple var keyTuples []types.StateKeyTuple
if stateNeeded.Create { if stateNeeded.Create {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomCreateNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
} }
if stateNeeded.PowerLevels { if stateNeeded.PowerLevels {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomPowerLevelsNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
} }
if stateNeeded.JoinRules { if stateNeeded.JoinRules {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomJoinRulesNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
} }
for _, member := range stateNeeded.Member { for _, member := range stateNeeded.Member {
stateKeyNID, ok := stateKeyNIDMap[member] stateKeyNID, ok := stateKeyNIDMap[member]
if ok { if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomMemberNID,
EventStateKeyNID: stateKeyNID,
})
} }
} }
for _, token := range stateNeeded.ThirdPartyInvite { for _, token := range stateNeeded.ThirdPartyInvite {
stateKeyNID, ok := stateKeyNIDMap[token] stateKeyNID, ok := stateKeyNIDMap[token]
if ok { if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomThirdPartyInviteNID,
EventStateKeyNID: stateKeyNID,
})
} }
} }
return keyTuples return keyTuples

View file

@ -15,6 +15,7 @@
package input package input
import ( import (
"context"
"fmt" "fmt"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -28,22 +29,38 @@ import (
type RoomEventDatabase interface { type RoomEventDatabase interface {
state.RoomStateDatabase state.RoomStateDatabase
// Stores a matrix room event in the database // Stores a matrix room event in the database
StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) StoreEvent(
ctx context.Context,
event gomatrixserverlib.Event,
authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error)
// Look up the state entries for a list of string event IDs // Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database // Returns an error if the there is an error talking to the database
// Returns a types.MissingEventError if the event IDs aren't in the database. // Returns a types.MissingEventError if the event IDs aren't in the database.
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error)
// Set the state at an event. // Set the state at an event.
SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error SetState(
ctx context.Context,
eventNID types.EventNID,
stateNID types.StateSnapshotNID,
) error
// Look up the latest events in a room in preparation for an update. // Look up the latest events in a room in preparation for an update.
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
// Returns the latest events in the room and the last eventID sent to the log along with an updater. // Returns the latest events in the room and the last eventID sent to the log along with an updater.
// If this returns an error then no further action is required. // If this returns an error then no further action is required.
GetLatestEventsForUpdate(roomNID types.RoomNID) (updater types.RoomRecentEventsUpdater, err error) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID,
) (updater types.RoomRecentEventsUpdater, err error)
// Look up the string event IDs for a list of numeric event IDs // Look up the string event IDs for a list of numeric event IDs
EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error)
// Build a membership updater for the target user in a room. // Build a membership updater for the target user in a room.
MembershipUpdater(roomID, targerUserID string) (types.MembershipUpdater, error) MembershipUpdater(
ctx context.Context, roomID, targerUserID string,
) (types.MembershipUpdater, error)
} }
// OutputRoomEventWriter has the APIs needed to write an event to the output logs. // OutputRoomEventWriter has the APIs needed to write an event to the output logs.
@ -52,18 +69,23 @@ type OutputRoomEventWriter interface {
WriteOutputEvents(roomID string, updates []api.OutputEvent) error WriteOutputEvents(roomID string, updates []api.OutputEvent) error
} }
func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputRoomEvent) error { func processRoomEvent(
ctx context.Context,
db RoomEventDatabase,
ow OutputRoomEventWriter,
input api.InputRoomEvent,
) error {
// Parse and validate the event JSON // Parse and validate the event JSON
event := input.Event event := input.Event
// Check that the event passes authentication checks and work out the numeric IDs for the auth events. // Check that the event passes authentication checks and work out the numeric IDs for the auth events.
authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs) authEventNIDs, err := checkAuthEvents(ctx, db, event, input.AuthEventIDs)
if err != nil { if err != nil {
return err return err
} }
// Store the event // Store the event
roomNID, stateAtEvent, err := db.StoreEvent(event, authEventNIDs) roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, authEventNIDs)
if err != nil { if err != nil {
return err return err
} }
@ -82,20 +104,20 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.
// We've been told what the state at the event is so we don't need to calculate it. // We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = db.StateEntriesForEventIDs(input.StateEventIDs); err != nil { if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return err return err
} }
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(roomNID, nil, entries); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil {
return nil return nil
} }
} else { } else {
// We haven't been told what the state at the event is so we need to calculate it from the prev_events // We haven't been told what the state at the event is so we need to calculate it from the prev_events
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(db, event, roomNID); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil {
return err return err
} }
} }
db.SetState(stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
} }
if input.Kind == api.KindBackfill { if input.Kind == api.KindBackfill {
@ -104,14 +126,19 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.
} }
// Update the extremities of the event graph for the room // Update the extremities of the event graph for the room
if err := updateLatestEvents(db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil { if err := updateLatestEvents(ctx, db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil {
return err return err
} }
return nil return nil
} }
func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputInviteEvent) (err error) { func processInviteEvent(
ctx context.Context,
db RoomEventDatabase,
ow OutputRoomEventWriter,
input api.InputInviteEvent,
) (err error) {
if input.Event.StateKey() == nil { if input.Event.StateKey() == nil {
return fmt.Errorf("invite must be a state event") return fmt.Errorf("invite must be a state event")
} }
@ -119,7 +146,7 @@ func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input ap
roomID := input.Event.RoomID() roomID := input.Event.RoomID()
targetUserID := *input.Event.StateKey() targetUserID := *input.Event.StateKey()
updater, err := db.MembershipUpdater(roomID, targetUserID) updater, err := db.MembershipUpdater(ctx, roomID, targetUserID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -59,12 +59,12 @@ func (r *RoomserverInputAPI) InputRoomEvents(
response *api.InputRoomEventsResponse, response *api.InputRoomEventsResponse,
) error { ) error {
for i := range request.InputRoomEvents { for i := range request.InputRoomEvents {
if err := processRoomEvent(r.DB, r, request.InputRoomEvents[i]); err != nil { if err := processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil {
return err return err
} }
} }
for i := range request.InputInviteEvents { for i := range request.InputInviteEvents {
if err := processInviteEvent(r.DB, r, request.InputInviteEvents[i]); err != nil { if err := processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil {
return err return err
} }
} }

View file

@ -16,6 +16,7 @@ package input
import ( import (
"bytes" "bytes"
"context"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -42,6 +43,7 @@ import (
// 7 <----- latest // 7 <----- latest
// //
func updateLatestEvents( func updateLatestEvents(
ctx context.Context,
db RoomEventDatabase, db RoomEventDatabase,
ow OutputRoomEventWriter, ow OutputRoomEventWriter,
roomNID types.RoomNID, roomNID types.RoomNID,
@ -49,7 +51,7 @@ func updateLatestEvents(
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
sendAsServer string, sendAsServer string,
) (err error) { ) (err error) {
updater, err := db.GetLatestEventsForUpdate(roomNID) updater, err := db.GetLatestEventsForUpdate(ctx, roomNID)
if err != nil { if err != nil {
return return
} }
@ -57,7 +59,7 @@ func updateLatestEvents(
defer common.EndTransaction(updater, &succeeded) defer common.EndTransaction(updater, &succeeded)
u := latestEventsUpdater{ u := latestEventsUpdater{
db: db, updater: updater, ow: ow, roomNID: roomNID, ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,
stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer, stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer,
} }
if err = u.doUpdateLatestEvents(); err != nil { if err = u.doUpdateLatestEvents(); err != nil {
@ -73,6 +75,7 @@ func updateLatestEvents(
// The state could be passed using function arguments, but it becomes impractical // The state could be passed using function arguments, but it becomes impractical
// when there are so many variables to pass around. // when there are so many variables to pass around.
type latestEventsUpdater struct { type latestEventsUpdater struct {
ctx context.Context
db RoomEventDatabase db RoomEventDatabase
updater types.RoomRecentEventsUpdater updater types.RoomRecentEventsUpdater
ow OutputRoomEventWriter ow OutputRoomEventWriter
@ -133,7 +136,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
return err return err
} }
updates, err := updateMemberships(u.db, u.updater, u.removed, u.added) updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added)
if err != nil { if err != nil {
return err return err
} }
@ -174,18 +177,22 @@ func (u *latestEventsUpdater) latestState() error {
for i := range u.latest { for i := range u.latest {
latestStateAtEvents[i] = u.latest[i].StateAtEvent latestStateAtEvents[i] = u.latest[i].StateAtEvent
} }
u.newStateNID, err = state.CalculateAndStoreStateAfterEvents(u.db, u.roomNID, latestStateAtEvents) u.newStateNID, err = state.CalculateAndStoreStateAfterEvents(
u.ctx, u.db, u.roomNID, latestStateAtEvents,
)
if err != nil { if err != nil {
return err return err
} }
u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots(u.db, u.oldStateNID, u.newStateNID) u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots(
u.ctx, u.db, u.oldStateNID, u.newStateNID,
)
if err != nil { if err != nil {
return err return err
} }
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots( u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots(
u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, u.ctx, u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
) )
if err != nil { if err != nil {
return err return err
@ -193,7 +200,12 @@ func (u *latestEventsUpdater) latestState() error {
return nil return nil
} }
func calculateLatest(oldLatest []types.StateAtEventAndReference, alreadyReferenced bool, prevEvents []gomatrixserverlib.EventReference, newEvent types.StateAtEventAndReference) []types.StateAtEventAndReference { func calculateLatest(
oldLatest []types.StateAtEventAndReference,
alreadyReferenced bool,
prevEvents []gomatrixserverlib.EventReference,
newEvent types.StateAtEventAndReference,
) []types.StateAtEventAndReference {
var alreadyInLatest bool var alreadyInLatest bool
var newLatest []types.StateAtEventAndReference var newLatest []types.StateAtEventAndReference
for _, l := range oldLatest { for _, l := range oldLatest {
@ -253,7 +265,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
stateEventNIDs = append(stateEventNIDs, entry.EventNID) stateEventNIDs = append(stateEventNIDs, entry.EventNID)
} }
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
eventIDMap, err := u.db.EventIDs(stateEventNIDs) eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -15,6 +15,7 @@
package input package input
import ( import (
"context"
"fmt" "fmt"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -27,7 +28,10 @@ import (
// Returns a list of output events to write to the kafka log to inform the // Returns a list of output events to write to the kafka log to inform the
// consumers about the invites added or retired by the change in current state. // consumers about the invites added or retired by the change in current state.
func updateMemberships( func updateMemberships(
db RoomEventDatabase, updater types.RoomRecentEventsUpdater, removed, added []types.StateEntry, ctx context.Context,
db RoomEventDatabase,
updater types.RoomRecentEventsUpdater,
removed, added []types.StateEntry,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
changes := membershipChanges(removed, added) changes := membershipChanges(removed, added)
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
@ -43,7 +47,7 @@ func updateMemberships(
// Load the event JSON so we can look up the "membership" key. // Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that // TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON? // key without having to load the entire event JSON?
events, err := db.Events(eventNIDs) events, err := db.Events(ctx, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -33,35 +33,47 @@ type RoomserverQueryAPIDatabase interface {
// Look up the numeric ID for the room. // Look up the numeric ID for the room.
// Returns 0 if the room doesn't exists. // Returns 0 if the room doesn't exists.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
RoomNID(roomID string) (types.RoomNID, error) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error)
// Look up event references for the latest events in the room and the current state snapshot. // Look up event references for the latest events in the room and the current state snapshot.
// Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns the latest events, the current state and the maximum depth of the latest events plus 1.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID,
) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
// Look up the numeric IDs for a list of events. // Look up the numeric IDs for a list of events.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
EventNIDs(eventIDs []string) (map[string]types.EventNID, error) EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
// Lookup the event IDs for a batch of event numeric IDs. // Lookup the event IDs for a batch of event numeric IDs.
// Returns an error if the retrieval went wrong. // Returns an error if the retrieval went wrong.
EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// Lookup the membership of a given user in a given room. // Lookup the membership of a given user in a given room.
// Returns the numeric ID of the latest membership event sent from this user // Returns the numeric ID of the latest membership event sent from this user
// in this room, along a boolean set to true if the user is still in this room, // in this room, along a boolean set to true if the user is still in this room,
// false if not. // false if not.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetMembership(roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) GetMembership(
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
) (membershipEventNID types.EventNID, stillInRoom bool, err error)
// Lookup the membership event numeric IDs for all user that are or have // Lookup the membership event numeric IDs for all user that are or have
// been members of a given room. Only lookup events of "join" membership if // been members of a given room. Only lookup events of "join" membership if
// joinOnly is set to true. // joinOnly is set to true.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetMembershipEventNIDsForRoom(roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
) ([]types.EventNID, error)
// Look up the active invites targeting a user in a room and return the // Look up the active invites targeting a user in a room and return the
// numeric state key IDs for the user IDs who sent them. // numeric state key IDs for the user IDs who sent them.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
GetInvitesForUser(roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserNIDs []types.EventStateKeyNID, err error) GetInvitesForUser(
ctx context.Context,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserNIDs []types.EventStateKeyNID, err error)
// Look up the string event state keys for a list of numeric event state keys // Look up the string event state keys for a list of numeric event state keys
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
EventStateKeys([]types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) EventStateKeys(
context.Context, []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error)
} }
// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI // RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
@ -76,7 +88,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
response *api.QueryLatestEventsAndStateResponse, response *api.QueryLatestEventsAndStateResponse,
) error { ) error {
response.QueryLatestEventsAndStateRequest = *request response.QueryLatestEventsAndStateRequest = *request
roomNID, err := r.DB.RoomNID(request.RoomID) roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil { if err != nil {
return err return err
} }
@ -85,18 +97,21 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
} }
response.RoomExists = true response.RoomExists = true
var currentStateSnapshotNID types.StateSnapshotNID var currentStateSnapshotNID types.StateSnapshotNID
response.LatestEvents, currentStateSnapshotNID, response.Depth, err = r.DB.LatestEventIDs(roomNID) response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
r.DB.LatestEventIDs(ctx, roomNID)
if err != nil { if err != nil {
return err return err
} }
// Look up the currrent state for the requested tuples. // Look up the currrent state for the requested tuples.
stateEntries, err := state.LoadStateAtSnapshotForStringTuples(r.DB, currentStateSnapshotNID, request.StateToFetch) stateEntries, err := state.LoadStateAtSnapshotForStringTuples(
ctx, r.DB, currentStateSnapshotNID, request.StateToFetch,
)
if err != nil { if err != nil {
return err return err
} }
stateEvents, err := r.loadStateEvents(stateEntries) stateEvents, err := r.loadStateEvents(ctx, stateEntries)
if err != nil { if err != nil {
return err return err
} }
@ -112,7 +127,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
response *api.QueryStateAfterEventsResponse, response *api.QueryStateAfterEventsResponse,
) error { ) error {
response.QueryStateAfterEventsRequest = *request response.QueryStateAfterEventsRequest = *request
roomNID, err := r.DB.RoomNID(request.RoomID) roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil { if err != nil {
return err return err
} }
@ -121,7 +136,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
} }
response.RoomExists = true response.RoomExists = true
prevStates, err := r.DB.StateAtEventIDs(request.PrevEventIDs) prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case types.MissingEventError: case types.MissingEventError:
@ -133,12 +148,14 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
response.PrevEventsExist = true response.PrevEventsExist = true
// Look up the currrent state for the requested tuples. // Look up the currrent state for the requested tuples.
stateEntries, err := state.LoadStateAfterEventsForStringTuples(r.DB, prevStates, request.StateToFetch) stateEntries, err := state.LoadStateAfterEventsForStringTuples(
ctx, r.DB, prevStates, request.StateToFetch,
)
if err != nil { if err != nil {
return err return err
} }
stateEvents, err := r.loadStateEvents(stateEntries) stateEvents, err := r.loadStateEvents(ctx, stateEntries)
if err != nil { if err != nil {
return err return err
} }
@ -155,7 +172,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
) error { ) error {
response.QueryEventsByIDRequest = *request response.QueryEventsByIDRequest = *request
eventNIDMap, err := r.DB.EventNIDs(request.EventIDs) eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs)
if err != nil { if err != nil {
return err return err
} }
@ -165,7 +182,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
eventNIDs = append(eventNIDs, nid) eventNIDs = append(eventNIDs, nid)
} }
events, err := r.loadEvents(eventNIDs) events, err := r.loadEvents(ctx, eventNIDs)
if err != nil { if err != nil {
return err return err
} }
@ -174,16 +191,20 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
return nil return nil
} }
func (r *RoomserverQueryAPI) loadStateEvents(stateEntries []types.StateEntry) ([]gomatrixserverlib.Event, error) { func (r *RoomserverQueryAPI) loadStateEvents(
ctx context.Context, stateEntries []types.StateEntry,
) ([]gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries)) eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries { for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID eventNIDs[i] = stateEntries[i].EventNID
} }
return r.loadEvents(eventNIDs) return r.loadEvents(ctx, eventNIDs)
} }
func (r *RoomserverQueryAPI) loadEvents(eventNIDs []types.EventNID) ([]gomatrixserverlib.Event, error) { func (r *RoomserverQueryAPI) loadEvents(
stateEvents, err := r.DB.Events(eventNIDs) ctx context.Context, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.Event, error) {
stateEvents, err := r.DB.Events(ctx, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -201,12 +222,12 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
request *api.QueryMembershipsForRoomRequest, request *api.QueryMembershipsForRoomRequest,
response *api.QueryMembershipsForRoomResponse, response *api.QueryMembershipsForRoomResponse,
) error { ) error {
roomNID, err := r.DB.RoomNID(request.RoomID) roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil { if err != nil {
return err return err
} }
membershipEventNID, stillInRoom, err := r.DB.GetMembership(roomNID, request.Sender) membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.Sender)
if err != nil { if err != nil {
return nil return nil
} }
@ -223,14 +244,14 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
var events []types.Event var events []types.Event
if stillInRoom { if stillInRoom {
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(roomNID, request.JoinedOnly) eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly)
if err != nil { if err != nil {
return err return err
} }
events, err = r.DB.Events(eventNIDs) events, err = r.DB.Events(ctx, eventNIDs)
} else { } else {
events, err = r.getMembershipsBeforeEventNID(membershipEventNID, request.JoinedOnly) events, err = r.getMembershipsBeforeEventNID(ctx, membershipEventNID, request.JoinedOnly)
} }
if err != nil { if err != nil {
@ -249,22 +270,24 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
// of the event's room as it was when this event was fired, then filters the state events to // of the event's room as it was when this event was fired, then filters the state events to
// only keep the "m.room.member" events with a "join" membership. These events are returned. // only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events. // Returns an error if there was an issue fetching the events.
func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNID, joinedOnly bool) ([]types.Event, error) { func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(
ctx context.Context, eventNID types.EventNID, joinedOnly bool,
) ([]types.Event, error) {
events := []types.Event{} events := []types.Event{}
// Lookup the event NID // Lookup the event NID
eIDs, err := r.DB.EventIDs([]types.EventNID{eventNID}) eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
eventIDs := []string{eIDs[eventNID]} eventIDs := []string{eIDs[eventNID]}
prevState, err := r.DB.StateAtEventIDs(eventIDs) prevState, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Fetch the state as it was when this event was fired // Fetch the state as it was when this event was fired
stateEntries, err := state.LoadCombinedStateAfterEvents(r.DB, prevState) stateEntries, err := state.LoadCombinedStateAfterEvents(ctx, r.DB, prevState)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -278,7 +301,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNI
} }
// Get all of the events in this state // Get all of the events in this state
stateEvents, err := r.DB.Events(eventNIDs) stateEvents, err := r.DB.Events(ctx, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -304,27 +327,27 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNI
// QueryInvitesForUser implements api.RoomserverQueryAPI // QueryInvitesForUser implements api.RoomserverQueryAPI
func (r *RoomserverQueryAPI) QueryInvitesForUser( func (r *RoomserverQueryAPI) QueryInvitesForUser(
_ context.Context, ctx context.Context,
request *api.QueryInvitesForUserRequest, request *api.QueryInvitesForUserRequest,
response *api.QueryInvitesForUserResponse, response *api.QueryInvitesForUserResponse,
) error { ) error {
roomNID, err := r.DB.RoomNID(request.RoomID) roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil { if err != nil {
return err return err
} }
targetUserNIDs, err := r.DB.EventStateKeyNIDs([]string{request.TargetUserID}) targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.TargetUserID})
if err != nil { if err != nil {
return err return err
} }
targetUserNID := targetUserNIDs[request.TargetUserID] targetUserNID := targetUserNIDs[request.TargetUserID]
senderUserNIDs, err := r.DB.GetInvitesForUser(roomNID, targetUserNID) senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID)
if err != nil { if err != nil {
return err return err
} }
senderUserIDs, err := r.DB.EventStateKeys(senderUserNIDs) senderUserIDs, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
if err != nil { if err != nil {
return err return err
} }
@ -342,14 +365,14 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent(
request *api.QueryServerAllowedToSeeEventRequest, request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse, response *api.QueryServerAllowedToSeeEventResponse,
) error { ) error {
stateEntries, err := state.LoadStateAtEvent(r.DB, request.EventID) stateEntries, err := state.LoadStateAtEvent(ctx, r.DB, request.EventID)
if err != nil { if err != nil {
return err return err
} }
// TODO: We probably want to make it so that we don't have to pull // TODO: We probably want to make it so that we don't have to pull
// out all the state if possible. // out all the state if possible.
stateAtEvent, err := r.loadStateEvents(stateEntries) stateAtEvent, err := r.loadStateEvents(ctx, stateEntries)
if err != nil { if err != nil {
return err return err
} }

View file

@ -17,6 +17,7 @@
package state package state
import ( import (
"context"
"fmt" "fmt"
"sort" "sort"
"time" "time"
@ -30,49 +31,58 @@ import (
// A RoomStateDatabase has the storage APIs needed to load state from the database // A RoomStateDatabase has the storage APIs needed to load state from the database
type RoomStateDatabase interface { type RoomStateDatabase interface {
// Store the room state at an event in the database // Store the room state at an event in the database
AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (types.StateSnapshotNID, error)
// Look up the state of a room at each event for a list of string event IDs. // Look up the state of a room at each event for a list of string event IDs.
// Returns an error if there is an error talking to the database // Returns an error if there is an error talking to the database
// Returns a types.MissingEventError if the room state for the event IDs aren't in the database // Returns a types.MissingEventError if the room state for the event IDs aren't in the database
StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
// Look up the numeric IDs for a list of string event types. // Look up the numeric IDs for a list of string event types.
// Returns a map from string event type to numeric ID for the event type. // Returns a map from string event type to numeric ID for the event type.
EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error) EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
// Look up the numeric IDs for a list of string event state keys. // Look up the numeric IDs for a list of string event state keys.
// Returns a map from string state key to numeric ID for the state key. // Returns a map from string state key to numeric ID for the state key.
EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
// Look up the numeric state data IDs for each numeric state snapshot ID // Look up the numeric state data IDs for each numeric state snapshot ID
// The returned slice is sorted by numeric state snapshot ID. // The returned slice is sorted by numeric state snapshot ID.
StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
// Look up the state data for each numeric state data ID // Look up the state data for each numeric state data ID
// The returned slice is sorted by numeric state data ID. // The returned slice is sorted by numeric state data ID.
StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
// Look up the state data for the state key tuples for each numeric state block ID // Look up the state data for the state key tuples for each numeric state block ID
// This is used to fetch a subset of the room state at a snapshot. // This is used to fetch a subset of the room state at a snapshot.
// If a block doesn't contain any of the requested tuples then it can be discarded from the result. // If a block doesn't contain any of the requested tuples then it can be discarded from the result.
// The returned slice is sorted by numeric state block ID. // The returned slice is sorted by numeric state block ID.
StateEntriesForTuples(stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ( StateEntriesForTuples(
[]types.StateEntryList, error, ctx context.Context,
) stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs. // Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events. // Returns a sorted list of events.
Events(eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string // Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(eventID string) (types.StateSnapshotNID, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
} }
// LoadStateAtSnapshot loads the full state of a room at a particular snapshot. // LoadStateAtSnapshot loads the full state of a room at a particular snapshot.
// This is typically the state before an event or the current state of a room. // This is typically the state before an event or the current state of a room.
// Returns a sorted list of state entries or an error if there was a problem talking to the database. // Returns a sorted list of state entries or an error if there was a problem talking to the database.
func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID) ([]types.StateEntry, error) { func LoadStateAtSnapshot(
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID}) ctx context.Context, db RoomStateDatabase, stateNID types.StateSnapshotNID,
) ([]types.StateEntry, error) {
stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
stateBlockNIDList := stateBlockNIDLists[0] stateBlockNIDList := stateBlockNIDLists[0]
stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs) stateEntryLists, err := db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -100,13 +110,15 @@ func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID)
} }
// LoadStateAtEvent loads the full state of a room at a particular event. // LoadStateAtEvent loads the full state of a room at a particular event.
func LoadStateAtEvent(db RoomStateDatabase, eventID string) ([]types.StateEntry, error) { func LoadStateAtEvent(
snapshotNID, err := db.SnapshotNIDFromEventID(eventID) ctx context.Context, db RoomStateDatabase, eventID string,
) ([]types.StateEntry, error) {
snapshotNID, err := db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
stateEntries, err := LoadStateAtSnapshot(db, snapshotNID) stateEntries, err := LoadStateAtSnapshot(ctx, db, snapshotNID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -116,7 +128,9 @@ func LoadStateAtEvent(db RoomStateDatabase, eventID string) ([]types.StateEntry,
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events // LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
// and combines those snapshots together into a single list. // and combines those snapshots together into a single list.
func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) { func LoadCombinedStateAfterEvents(
ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent,
) ([]types.StateEntry, error) {
stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
for i, state := range prevStates { for i, state := range prevStates {
stateNIDs[i] = state.BeforeStateSnapshotNID stateNIDs[i] = state.BeforeStateSnapshotNID
@ -125,7 +139,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
// Deduplicate the IDs before passing them to the database. // Deduplicate the IDs before passing them to the database.
// There could be duplicates because the events could be state events where // There could be duplicates because the events could be state events where
// the snapshot of the room state before them was the same. // the snapshot of the room state before them was the same.
stateBlockNIDLists, err := db.StateBlockNIDs(uniqueStateSnapshotNIDs(stateNIDs)) stateBlockNIDLists, err := db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -138,7 +152,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
// Deduplicate the IDs before passing them to the database. // Deduplicate the IDs before passing them to the database.
// There could be duplicates because a block of state entries could be reused by // There could be duplicates because a block of state entries could be reused by
// multiple snapshots. // multiple snapshots.
stateEntryLists, err := db.StateEntries(uniqueStateBlockNIDs(stateBlockNIDs)) stateEntryLists, err := db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -186,9 +200,9 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
} }
// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots. // DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots.
func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID) ( func DifferenceBetweeenStateSnapshots(
removed, added []types.StateEntry, err error, ctx context.Context, db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID,
) { ) (removed, added []types.StateEntry, err error) {
if oldStateNID == newStateNID { if oldStateNID == newStateNID {
// If the snapshot NIDs are the same then nothing has changed // If the snapshot NIDs are the same then nothing has changed
return nil, nil, nil return nil, nil, nil
@ -197,13 +211,13 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat
var oldEntries []types.StateEntry var oldEntries []types.StateEntry
var newEntries []types.StateEntry var newEntries []types.StateEntry
if oldStateNID != 0 { if oldStateNID != 0 {
oldEntries, err = LoadStateAtSnapshot(db, oldStateNID) oldEntries, err = LoadStateAtSnapshot(ctx, db, oldStateNID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
if newStateNID != 0 { if newStateNID != 0 {
newEntries, err = LoadStateAtSnapshot(db, newStateNID) newEntries, err = LoadStateAtSnapshot(ctx, db, newStateNID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -246,19 +260,26 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat
// This is typically the state before an event or the current state of a room. // This is typically the state before an event or the current state of a room.
// Returns a sorted list of state entries or an error if there was a problem talking to the database. // Returns a sorted list of state entries or an error if there was a problem talking to the database.
func LoadStateAtSnapshotForStringTuples( func LoadStateAtSnapshotForStringTuples(
db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ctx context.Context,
db RoomStateDatabase,
stateNID types.StateSnapshotNID,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples) numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return loadStateAtSnapshotForNumericTuples(db, stateNID, numericTuples) return loadStateAtSnapshotForNumericTuples(ctx, db, stateNID, numericTuples)
} }
// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs // stringTuplesToNumericTuples converts the string state key tuples into numeric IDs
// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded. // If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixserverlib.StateKeyTuple) ([]types.StateKeyTuple, error) { func stringTuplesToNumericTuples(
ctx context.Context,
db RoomStateDatabase,
stringTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateKeyTuple, error) {
eventTypes := make([]string, len(stringTuples)) eventTypes := make([]string, len(stringTuples))
stateKeys := make([]string, len(stringTuples)) stateKeys := make([]string, len(stringTuples))
for i := range stringTuples { for i := range stringTuples {
@ -266,12 +287,12 @@ func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixse
stateKeys[i] = stringTuples[i].StateKey stateKeys[i] = stringTuples[i].StateKey
} }
eventTypes = util.UniqueStrings(eventTypes) eventTypes = util.UniqueStrings(eventTypes)
eventTypeMap, err := db.EventTypeNIDs(eventTypes) eventTypeMap, err := db.EventTypeNIDs(ctx, eventTypes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
stateKeys = util.UniqueStrings(stateKeys) stateKeys = util.UniqueStrings(stateKeys)
stateKeyMap, err := db.EventStateKeyNIDs(stateKeys) stateKeyMap, err := db.EventStateKeyNIDs(ctx, stateKeys)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -297,16 +318,21 @@ func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixse
// This is typically the state before an event or the current state of a room. // This is typically the state before an event or the current state of a room.
// Returns a sorted list of state entries or an error if there was a problem talking to the database. // Returns a sorted list of state entries or an error if there was a problem talking to the database.
func loadStateAtSnapshotForNumericTuples( func loadStateAtSnapshotForNumericTuples(
db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple, ctx context.Context,
db RoomStateDatabase,
stateNID types.StateSnapshotNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID}) stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
stateBlockNIDList := stateBlockNIDLists[0] stateBlockNIDList := stateBlockNIDLists[0]
stateEntryLists, err := db.StateEntriesForTuples(stateBlockNIDList.StateBlockNIDs, stateKeyTuples) stateEntryLists, err := db.StateEntriesForTuples(
ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -341,23 +367,29 @@ func loadStateAtSnapshotForNumericTuples(
// This is typically the state before an event. // This is typically the state before an event.
// Returns a sorted list of state entries or an error if there was a problem talking to the database. // Returns a sorted list of state entries or an error if there was a problem talking to the database.
func LoadStateAfterEventsForStringTuples( func LoadStateAfterEventsForStringTuples(
db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ctx context.Context,
db RoomStateDatabase,
prevStates []types.StateAtEvent,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples) numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return loadStateAfterEventsForNumericTuples(db, prevStates, numericTuples) return loadStateAfterEventsForNumericTuples(ctx, db, prevStates, numericTuples)
} }
func loadStateAfterEventsForNumericTuples( func loadStateAfterEventsForNumericTuples(
db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, ctx context.Context,
db RoomStateDatabase,
prevStates []types.StateAtEvent,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
if len(prevStates) == 1 { if len(prevStates) == 1 {
// Fast path for a single event. // Fast path for a single event.
prevState := prevStates[0] prevState := prevStates[0]
result, err := loadStateAtSnapshotForNumericTuples( result, err := loadStateAtSnapshotForNumericTuples(
db, prevState.BeforeStateSnapshotNID, stateKeyTuples, ctx, db, prevState.BeforeStateSnapshotNID, stateKeyTuples,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -390,7 +422,7 @@ func loadStateAfterEventsForNumericTuples(
// TODO: Add metrics for this as it could take a long time for big rooms // TODO: Add metrics for this as it could take a long time for big rooms
// with large conflicts. // with large conflicts.
fullState, _, _, err := calculateStateAfterManyEvents(db, prevStates) fullState, _, _, err := calculateStateAfterManyEvents(ctx, db, prevStates)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -403,7 +435,10 @@ func loadStateAfterEventsForNumericTuples(
for _, tuple := range stateKeyTuples { for _, tuple := range stateKeyTuples {
eventNID, ok := stateEntryMap(fullState).lookup(tuple) eventNID, ok := stateEntryMap(fullState).lookup(tuple)
if ok { if ok {
result = append(result, types.StateEntry{tuple, eventNID}) result = append(result, types.StateEntry{
StateKeyTuple: tuple,
EventNID: eventNID,
})
} }
} }
sort.Sort(stateEntrySorter(result)) sort.Sort(stateEntrySorter(result))
@ -509,7 +544,10 @@ func init() {
// Stores the snapshot of the state in the database. // Stores the snapshot of the state in the database.
// Returns a numeric ID for the snapshot of the state before the event. // Returns a numeric ID for the snapshot of the state before the event.
func CalculateAndStoreStateBeforeEvent( func CalculateAndStoreStateBeforeEvent(
db RoomStateDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID, ctx context.Context,
db RoomStateDatabase,
event gomatrixserverlib.Event,
roomNID types.RoomNID,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
// Load the state at the prev events. // Load the state at the prev events.
prevEventRefs := event.PrevEvents() prevEventRefs := event.PrevEvents()
@ -518,25 +556,30 @@ func CalculateAndStoreStateBeforeEvent(
prevEventIDs[i] = prevEventRefs[i].EventID prevEventIDs[i] = prevEventRefs[i].EventID
} }
prevStates, err := db.StateAtEventIDs(prevEventIDs) prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs)
if err != nil { if err != nil {
return 0, err return 0, err
} }
// The state before this event will be the state after the events that came before it. // The state before this event will be the state after the events that came before it.
return CalculateAndStoreStateAfterEvents(db, roomNID, prevStates) return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates)
} }
// CalculateAndStoreStateAfterEvents finds the room state after the given events. // CalculateAndStoreStateAfterEvents finds the room state after the given events.
// Stores the resulting state in the database and returns a numeric ID for that snapshot. // Stores the resulting state in the database and returns a numeric ID for that snapshot.
func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) { func CalculateAndStoreStateAfterEvents(
ctx context.Context,
db RoomStateDatabase,
roomNID types.RoomNID,
prevStates []types.StateAtEvent,
) (types.StateSnapshotNID, error) {
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
if len(prevStates) == 0 { if len(prevStates) == 0 {
// 2) There weren't any prev_events for this event so the state is // 2) There weren't any prev_events for this event so the state is
// empty. // empty.
metrics.algorithm = "empty_state" metrics.algorithm = "empty_state"
return metrics.stop(db.AddState(roomNID, nil, nil)) return metrics.stop(db.AddState(ctx, roomNID, nil, nil))
} }
if len(prevStates) == 1 { if len(prevStates) == 1 {
@ -551,7 +594,9 @@ func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomN
} }
// The previous event was a state event so we need to store a copy // The previous event was a state event so we need to store a copy
// of the previous state updated with that event. // of the previous state updated with that event.
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}) stateBlockNIDLists, err := db.StateBlockNIDs(
ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID},
)
if err != nil { if err != nil {
metrics.algorithm = "_load_state_blocks" metrics.algorithm = "_load_state_blocks"
return metrics.stop(0, err) return metrics.stop(0, err)
@ -562,14 +607,14 @@ func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomN
// add the state event as a block of size one to the end of the blocks. // add the state event as a block of size one to the end of the blocks.
metrics.algorithm = "single_delta" metrics.algorithm = "single_delta"
return metrics.stop(db.AddState( return metrics.stop(db.AddState(
roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
)) ))
} }
// If there are too many deltas then we need to calculate the full state // If there are too many deltas then we need to calculate the full state
// So fall through to calculateAndStoreStateAfterManyEvents // So fall through to calculateAndStoreStateAfterManyEvents
} }
return calculateAndStoreStateAfterManyEvents(db, roomNID, prevStates, metrics) return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics)
} }
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. // maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
@ -583,10 +628,15 @@ const maxStateBlockNIDs = 64
// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event. // This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event.
// Stores the resulting state and returns a numeric ID for the snapshot. // Stores the resulting state and returns a numeric ID for the snapshot.
func calculateAndStoreStateAfterManyEvents( func calculateAndStoreStateAfterManyEvents(
db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent, metrics calculateStateMetrics, ctx context.Context,
db RoomStateDatabase,
roomNID types.RoomNID,
prevStates []types.StateAtEvent,
metrics calculateStateMetrics,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
state, algorithm, conflictLength, err := calculateStateAfterManyEvents(db, prevStates) state, algorithm, conflictLength, err :=
calculateStateAfterManyEvents(ctx, db, prevStates)
metrics.algorithm = algorithm metrics.algorithm = algorithm
if err != nil { if err != nil {
return metrics.stop(0, err) return metrics.stop(0, err)
@ -596,16 +646,16 @@ func calculateAndStoreStateAfterManyEvents(
// previous state. // previous state.
metrics.conflictLength = conflictLength metrics.conflictLength = conflictLength
metrics.fullStateLength = len(state) metrics.fullStateLength = len(state)
return metrics.stop(db.AddState(roomNID, nil, state)) return metrics.stop(db.AddState(ctx, roomNID, nil, state))
} }
func calculateStateAfterManyEvents( func calculateStateAfterManyEvents(
db RoomStateDatabase, prevStates []types.StateAtEvent, ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent,
) (state []types.StateEntry, algorithm string, conflictLength int, err error) { ) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
var combined []types.StateEntry var combined []types.StateEntry
// Conflict resolution. // Conflict resolution.
// First stage: load the state after each of the prev events. // First stage: load the state after each of the prev events.
combined, err = LoadCombinedStateAfterEvents(db, prevStates) combined, err = LoadCombinedStateAfterEvents(ctx, db, prevStates)
if err != nil { if err != nil {
algorithm = "_load_combined_state" algorithm = "_load_combined_state"
return return
@ -635,7 +685,7 @@ func calculateStateAfterManyEvents(
} }
var resolved []types.StateEntry var resolved []types.StateEntry
resolved, err = resolveConflicts(db, notConflicted, conflicts) resolved, err = resolveConflicts(ctx, db, notConflicted, conflicts)
if err != nil { if err != nil {
algorithm = "_resolve_conflicts" algorithm = "_resolve_conflicts"
return return
@ -657,10 +707,14 @@ func calculateStateAfterManyEvents(
// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. // Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts.
// The returned list is sorted by state key tuple. // The returned list is sorted by state key tuple.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.StateEntry) ([]types.StateEntry, error) { func resolveConflicts(
ctx context.Context,
db RoomStateDatabase,
notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) {
// Load the conflicted events // Load the conflicted events
conflictedEvents, eventIDMap, err := loadStateEvents(db, conflicted) conflictedEvents, eventIDMap, err := loadStateEvents(ctx, db, conflicted)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -672,7 +726,7 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
var neededStateKeys []string var neededStateKeys []string
neededStateKeys = append(neededStateKeys, needed.Member...) neededStateKeys = append(neededStateKeys, needed.Member...)
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
stateKeyNIDMap, err := db.EventStateKeyNIDs(neededStateKeys) stateKeyNIDMap, err := db.EventStateKeyNIDs(ctx, neededStateKeys)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -682,10 +736,13 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
var authEntries []types.StateEntry var authEntries []types.StateEntry
for _, tuple := range tuplesNeeded { for _, tuple := range tuplesNeeded {
if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok {
authEntries = append(authEntries, types.StateEntry{tuple, eventNID}) authEntries = append(authEntries, types.StateEntry{
StateKeyTuple: tuple,
EventNID: eventNID,
})
} }
} }
authEvents, _, err := loadStateEvents(db, authEntries) authEvents, _, err := loadStateEvents(ctx, db, authEntries)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -711,24 +768,39 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
var keyTuples []types.StateKeyTuple var keyTuples []types.StateKeyTuple
if stateNeeded.Create { if stateNeeded.Create {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomCreateNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
} }
if stateNeeded.PowerLevels { if stateNeeded.PowerLevels {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomPowerLevelsNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
} }
if stateNeeded.JoinRules { if stateNeeded.JoinRules {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomJoinRulesNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
} }
for _, member := range stateNeeded.Member { for _, member := range stateNeeded.Member {
stateKeyNID, ok := stateKeyNIDMap[member] stateKeyNID, ok := stateKeyNIDMap[member]
if ok { if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomMemberNID,
EventStateKeyNID: stateKeyNID,
})
} }
} }
for _, token := range stateNeeded.ThirdPartyInvite { for _, token := range stateNeeded.ThirdPartyInvite {
stateKeyNID, ok := stateKeyNIDMap[token] stateKeyNID, ok := stateKeyNIDMap[token]
if ok { if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomThirdPartyInviteNID,
EventStateKeyNID: stateKeyNID,
})
} }
} }
return keyTuples return keyTuples
@ -738,12 +810,14 @@ func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stat
// Returns a list of state events in no particular order and a map from string event ID back to state entry. // Returns a list of state events in no particular order and a map from string event ID back to state entry.
// The map can be used to recover which numeric state entry a given event is for. // The map can be used to recover which numeric state entry a given event is for.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
func loadStateEvents(db RoomStateDatabase, entries []types.StateEntry) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) { func loadStateEvents(
ctx context.Context, db RoomStateDatabase, entries []types.StateEntry,
) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) {
eventNIDs := make([]types.EventNID, len(entries)) eventNIDs := make([]types.EventNID, len(entries))
for i := range entries { for i := range entries {
eventNIDs[i] = entries[i].EventNID eventNIDs[i] = entries[i].EventNID
} }
events, err := db.Events(eventNIDs) events, err := db.Events(ctx, eventNIDs)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
@ -65,8 +66,10 @@ func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *eventJSONStatements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error { func (s *eventJSONStatements) insertEventJSON(
_, err := s.insertEventJSONStmt.Exec(int64(eventNID), eventJSON) ctx context.Context, eventNID types.EventNID, eventJSON []byte,
) error {
_, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON)
return err return err
} }
@ -75,8 +78,10 @@ type eventJSONPair struct {
EventJSON []byte EventJSON []byte
} }
func (s *eventJSONStatements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) { func (s *eventJSONStatements) bulkSelectEventJSON(
rows, err := s.bulkSelectEventJSONStmt.Query(eventNIDsAsArray(eventNIDs)) ctx context.Context, eventNIDs []types.EventNID,
) ([]eventJSONPair, error) {
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"github.com/lib/pq" "github.com/lib/pq"
@ -91,20 +92,30 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { func (s *eventStateKeyStatements) insertEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64 var eventStateKeyNID int64
err := common.TxStmt(txn, s.insertEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID) stmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt)
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err return types.EventStateKeyNID(eventStateKeyNID), err
} }
func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { func (s *eventStateKeyStatements) selectEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64 var eventStateKeyNID int64
err := common.TxStmt(txn, s.selectEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID) stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt)
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err return types.EventStateKeyNID(eventStateKeyNID), err
} }
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) { func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys)) ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
ctx, pq.StringArray(eventStateKeys),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -122,18 +133,23 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st
return result, nil return result, nil
} }
func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) { func (s *eventStateKeyStatements) selectEventStateKey(
ctx context.Context, txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID,
) (string, error) {
var eventStateKey string var eventStateKey string
err := common.TxStmt(txn, s.selectEventStateKeyStmt).QueryRow(eventStateKeyNID).Scan(&eventStateKey) stmt := common.TxStmt(txn, s.selectEventStateKeyStmt)
err := stmt.QueryRowContext(ctx, eventStateKeyNID).Scan(&eventStateKey)
return eventStateKey, err return eventStateKey, err
} }
func (s *eventStateKeyStatements) bulkSelectEventStateKey(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) { func (s *eventStateKeyStatements) bulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
var nIDs pq.Int64Array var nIDs pq.Int64Array
for i := range eventStateKeyNIDs { for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i]) nIDs[i] = int64(eventStateKeyNIDs[i])
} }
rows, err := s.bulkSelectEventStateKeyStmt.Query(nIDs) rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"github.com/lib/pq" "github.com/lib/pq"
@ -107,20 +108,26 @@ func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *eventTypeStatements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) { func (s *eventTypeStatements) insertEventTypeNID(
ctx context.Context, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64 var eventTypeNID int64
err := s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
return types.EventTypeNID(eventTypeNID), err return types.EventTypeNID(eventTypeNID), err
} }
func (s *eventTypeStatements) selectEventTypeNID(eventType string) (types.EventTypeNID, error) { func (s *eventTypeStatements) selectEventTypeNID(
ctx context.Context, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64 var eventTypeNID int64
err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
return types.EventTypeNID(eventTypeNID), err return types.EventTypeNID(eventTypeNID), err
} }
func (s *eventTypeStatements) bulkSelectEventTypeNID(eventTypes []string) (map[string]types.EventTypeNID, error) { func (s *eventTypeStatements) bulkSelectEventTypeNID(
rows, err := s.bulkSelectEventTypeNIDStmt.Query(pq.StringArray(eventTypes)) ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
@ -154,7 +155,10 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
} }
func (s *eventStatements) insertEvent( func (s *eventStatements) insertEvent(
roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, ctx context.Context,
roomNID types.RoomNID,
eventTypeNID types.EventTypeNID,
eventStateKeyNID types.EventStateKeyNID,
eventID string, eventID string,
referenceSHA256 []byte, referenceSHA256 []byte,
authEventNIDs []types.EventNID, authEventNIDs []types.EventNID,
@ -162,24 +166,28 @@ func (s *eventStatements) insertEvent(
) (types.EventNID, types.StateSnapshotNID, error) { ) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64 var eventNID int64
var stateNID int64 var stateNID int64
err := s.insertEventStmt.QueryRow( err := s.insertEventStmt.QueryRowContext(
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventNIDsAsArray(authEventNIDs), depth, eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
).Scan(&eventNID, &stateNID) ).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
} }
func (s *eventStatements) selectEvent(eventID string) (types.EventNID, types.StateSnapshotNID, error) { func (s *eventStatements) selectEvent(
ctx context.Context, eventID string,
) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64 var eventNID int64
var stateNID int64 var stateNID int64
err := s.selectEventStmt.QueryRow(eventID).Scan(&eventNID, &stateNID) err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
} }
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) { func (s *eventStatements) bulkSelectStateEventByID(
rows, err := s.bulkSelectStateEventByIDStmt.Query(pq.StringArray(eventIDs)) ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -216,8 +224,10 @@ func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.S
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types.StateAtEvent, error) { func (s *eventStatements) bulkSelectStateAtEventByID(
rows, err := s.bulkSelectStateAtEventByIDStmt.Query(pq.StringArray(eventIDs)) ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -248,28 +258,40 @@ func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types
return results, err return results, err
} }
func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error { func (s *eventStatements) updateEventState(
_, err := s.updateEventStateStmt.Exec(int64(eventNID), int64(stateNID)) ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID))
return err return err
} }
func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) { func (s *eventStatements) selectEventSentToOutput(
err = common.TxStmt(txn, s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput) ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
) (sentToOutput bool, err error) {
stmt := common.TxStmt(txn, s.selectEventSentToOutputStmt)
stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
return return
} }
func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error { func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
_, err := common.TxStmt(txn, s.updateEventSentToOutputStmt).Exec(int64(eventNID)) stmt := common.TxStmt(txn, s.updateEventSentToOutputStmt)
_, err := stmt.ExecContext(ctx, int64(eventNID))
return err return err
} }
func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) { func (s *eventStatements) selectEventID(
err = common.TxStmt(txn, s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID) ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
) (eventID string, err error) {
stmt := common.TxStmt(txn, s.selectEventIDStmt)
err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID)
return return
} }
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) { func (s *eventStatements) bulkSelectStateAtEventAndReference(
rows, err := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs)) ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.StateAtEventAndReference, error) {
stmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -304,8 +326,10 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventN
return results, nil return results, nil
} }
func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) { func (s *eventStatements) bulkSelectEventReference(
rows, err := s.bulkSelectEventReferenceStmt.Query(eventNIDsAsArray(eventNIDs)) ctx context.Context, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.EventReference, error) {
rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -325,8 +349,8 @@ func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) (
} }
// bulkSelectEventID returns a map from numeric event ID to string event ID. // bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[types.EventNID]string, error) { func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
rows, err := s.bulkSelectEventIDStmt.Query(eventNIDsAsArray(eventNIDs)) rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -349,8 +373,8 @@ func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map. // If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]types.EventNID, error) { func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
rows, err := s.bulkSelectEventNIDStmt.Query(pq.StringArray(eventIDs)) rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -367,9 +391,10 @@ func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]type
return results, nil return results, nil
} }
func (s *eventStatements) selectMaxEventDepth(eventNIDs []types.EventNID) (int64, error) { func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {
var result int64 var result int64
err := s.selectMaxEventDepthStmt.QueryRow(eventNIDsAsArray(eventNIDs)).Scan(&result) stmt := s.selectMaxEventDepthStmt
err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -91,12 +92,13 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) {
} }
func (s *inviteStatements) insertInviteEvent( func (s *inviteStatements) insertInviteEvent(
ctx context.Context,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID, targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte, inviteEventJSON []byte,
) (bool, error) { ) (bool, error) {
result, err := common.TxStmt(txn, s.insertInviteEventStmt).Exec( result, err := common.TxStmt(txn, s.insertInviteEventStmt).ExecContext(
inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
) )
if err != nil { if err != nil {
return false, err return false, err
@ -109,9 +111,11 @@ func (s *inviteStatements) insertInviteEvent(
} }
func (s *inviteStatements) updateInviteRetired( func (s *inviteStatements) updateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) { ) ([]string, error) {
rows, err := common.TxStmt(txn, s.updateInviteRetiredStmt).Query(roomNID, targetUserNID) stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -129,10 +133,11 @@ func (s *inviteStatements) updateInviteRetired(
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs // selectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) selectInviteActiveForUserInRoom( func (s *inviteStatements) selectInviteActiveForUserInRoom(
ctx context.Context,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, error) { ) ([]types.EventStateKeyNID, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.Query( rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
targetUserNID, roomNID, ctx, targetUserNID, roomNID,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -114,34 +115,38 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
} }
func (s *membershipStatements) insertMembership( func (s *membershipStatements) insertMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) error { ) error {
_, err := common.TxStmt(txn, s.insertMembershipStmt).Exec(roomNID, targetUserNID) stmt := common.TxStmt(txn, s.insertMembershipStmt)
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID)
return err return err
} }
func (s *membershipStatements) selectMembershipForUpdate( func (s *membershipStatements) selectMembershipForUpdate(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership membershipState, err error) { ) (membership membershipState, err error) {
err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRow( err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership) ).Scan(&membership)
return return
} }
func (s *membershipStatements) selectMembershipFromRoomAndTarget( func (s *membershipStatements) selectMembershipFromRoomAndTarget(
ctx context.Context,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership membershipState, err error) { ) (eventNID types.EventNID, membership membershipState, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRow( err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
roomNID, targetUserNID, ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID) ).Scan(&membership, &eventNID)
return return
} }
func (s *membershipStatements) selectMembershipsFromRoom( func (s *membershipStatements) selectMembershipsFromRoom(
roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
rows, err := s.selectMembershipsFromRoomStmt.Query(roomNID) rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
return return
} }
@ -156,9 +161,11 @@ func (s *membershipStatements) selectMembershipsFromRoom(
return return
} }
func (s *membershipStatements) selectMembershipsFromRoomAndMembership( func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context,
roomNID types.RoomNID, membership membershipState, roomNID types.RoomNID, membership membershipState,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
rows, err := s.selectMembershipsFromRoomAndMembershipStmt.Query(roomNID, membership) stmt := s.selectMembershipsFromRoomAndMembershipStmt
rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return
} }
@ -174,12 +181,13 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
} }
func (s *membershipStatements) updateMembership( func (s *membershipStatements) updateMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
senderUserNID types.EventStateKeyNID, membership membershipState, senderUserNID types.EventStateKeyNID, membership membershipState,
eventNID types.EventNID, eventNID types.EventNID,
) error { ) error {
_, err := common.TxStmt(txn, s.updateMembershipStmt).Exec( _, err := common.TxStmt(txn, s.updateMembershipStmt).ExecContext(
roomNID, targetUserNID, senderUserNID, membership, eventNID, ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID,
) )
return err return err
} }

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
@ -73,14 +74,26 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error { func (s *previousEventStatements) insertPreviousEvent(
_, err := common.TxStmt(txn, s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID)) ctx context.Context,
txn *sql.Tx,
previousEventID string,
previousEventReferenceSHA256 []byte,
eventNID types.EventNID,
) error {
stmt := common.TxStmt(txn, s.insertPreviousEventStmt)
_, err := stmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
)
return err return err
} }
// Check if the event reference exists // Check if the event reference exists
// Returns sql.ErrNoRows if the event reference doesn't exist. // Returns sql.ErrNoRows if the event reference doesn't exist.
func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error { func (s *previousEventStatements) selectPreviousEventExists(
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
) error {
var ok int64 var ok int64
return common.TxStmt(txn, s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok) stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt)
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
} }

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
) )
@ -62,22 +63,28 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *roomAliasesStatements) insertRoomAlias(alias string, roomID string) (err error) { func (s *roomAliasesStatements) insertRoomAlias(
_, err = s.insertRoomAliasStmt.Exec(alias, roomID) ctx context.Context, alias string, roomID string,
) (err error) {
_, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID)
return return
} }
func (s *roomAliasesStatements) selectRoomIDFromAlias(alias string) (roomID string, err error) { func (s *roomAliasesStatements) selectRoomIDFromAlias(
err = s.selectRoomIDFromAliasStmt.QueryRow(alias).Scan(&roomID) ctx context.Context, alias string,
) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", nil
} }
return return
} }
func (s *roomAliasesStatements) selectAliasesFromRoomID(roomID string) (aliases []string, err error) { func (s *roomAliasesStatements) selectAliasesFromRoomID(
ctx context.Context, roomID string,
) (aliases []string, err error) {
aliases = []string{} aliases = []string{}
rows, err := s.selectAliasesFromRoomIDStmt.Query(roomID) rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return
} }
@ -94,7 +101,9 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID(roomID string) (aliases
return return
} }
func (s *roomAliasesStatements) deleteRoomAlias(alias string) (err error) { func (s *roomAliasesStatements) deleteRoomAlias(
_, err = s.deleteRoomAliasStmt.Exec(alias) ctx context.Context, alias string,
) (err error) {
_, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)
return return
} }

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"github.com/lib/pq" "github.com/lib/pq"
@ -81,22 +82,31 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { func (s *roomStatements) insertRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
var roomNID int64 var roomNID int64
err := common.TxStmt(txn, s.insertRoomNIDStmt).QueryRow(roomID).Scan(&roomNID) stmt := common.TxStmt(txn, s.insertRoomNIDStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err return types.RoomNID(roomNID), err
} }
func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { func (s *roomStatements) selectRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
var roomNID int64 var roomNID int64
err := common.TxStmt(txn, s.selectRoomNIDStmt).QueryRow(roomID).Scan(&roomNID) stmt := common.TxStmt(txn, s.selectRoomNIDStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err return types.RoomNID(roomNID), err
} }
func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) { func (s *roomStatements) selectLatestEventNIDs(
ctx context.Context, roomNID types.RoomNID,
) ([]types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array var nids pq.Int64Array
var stateSnapshotNID int64 var stateSnapshotNID int64
err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids, &stateSnapshotNID) stmt := s.selectLatestEventNIDsStmt
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -107,13 +117,14 @@ func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.E
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
} }
func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) ( func (s *roomStatements) selectLatestEventsNIDsForUpdate(
[]types.EventNID, types.EventNID, types.StateSnapshotNID, error, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
) { ) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array var nids pq.Int64Array
var lastEventSentNID int64 var lastEventSentNID int64
var stateSnapshotNID int64 var stateSnapshotNID int64
err := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -125,11 +136,20 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty
} }
func (s *roomStatements) updateLatestEventNIDs( func (s *roomStatements) updateLatestEventNIDs(
txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, ctx context.Context,
txn *sql.Tx,
roomNID types.RoomNID,
eventNIDs []types.EventNID,
lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID, stateSnapshotNID types.StateSnapshotNID,
) error { ) error {
_, err := common.TxStmt(txn, s.updateLatestEventNIDsStmt).Exec( stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt)
roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID), _, err := stmt.ExecContext(
ctx,
roomNID,
eventNIDsAsArray(eventNIDs),
int64(lastEventSentNID),
int64(stateSnapshotNID),
) )
return err return err
} }

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"sort" "sort"
@ -97,9 +98,14 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBlockNID, entries []types.StateEntry) error { func (s *stateBlockStatements) bulkInsertStateData(
ctx context.Context,
stateBlockNID types.StateBlockNID,
entries []types.StateEntry,
) error {
for _, entry := range entries { for _, entry := range entries {
_, err := s.insertStateDataStmt.Exec( _, err := s.insertStateDataStmt.ExecContext(
ctx,
int64(stateBlockNID), int64(stateBlockNID),
int64(entry.EventTypeNID), int64(entry.EventTypeNID),
int64(entry.EventStateKeyNID), int64(entry.EventStateKeyNID),
@ -112,18 +118,22 @@ func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBloc
return nil return nil
} }
func (s *stateBlockStatements) selectNextStateBlockNID() (types.StateBlockNID, error) { func (s *stateBlockStatements) selectNextStateBlockNID(
ctx context.Context,
) (types.StateBlockNID, error) {
var stateBlockNID int64 var stateBlockNID int64
err := s.selectNextStateBlockNIDStmt.QueryRow().Scan(&stateBlockNID) err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID)
return types.StateBlockNID(stateBlockNID), err return types.StateBlockNID(stateBlockNID), err
} }
func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) { func (s *stateBlockStatements) bulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
nids := make([]int64, len(stateBlockNIDs)) nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs { for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i]) nids[i] = int64(stateBlockNIDs[i])
} }
rows, err := s.bulkSelectStateBlockEntriesStmt.Query(pq.Int64Array(nids)) rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -165,15 +175,20 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []type
} }
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
tuples := stateKeyTupleSorter(stateKeyTuples) tuples := stateKeyTupleSorter(stateKeyTuples)
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
sort.Sort(tuples) sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.Query( rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext(
stateBlockNIDsAsArray(stateBlockNIDs), eventTypeNIDArray, eventStateKeyNIDArray, ctx,
stateBlockNIDsAsArray(stateBlockNIDs),
eventTypeNIDArray,
eventStateKeyNIDArray,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -15,29 +15,30 @@
package storage package storage
import ( import (
"github.com/matrix-org/dendrite/roomserver/types"
"sort" "sort"
"testing" "testing"
"github.com/matrix-org/dendrite/roomserver/types"
) )
func TestStateKeyTupleSorter(t *testing.T) { func TestStateKeyTupleSorter(t *testing.T) {
input := stateKeyTupleSorter{ input := stateKeyTupleSorter{
{1, 2}, {EventTypeNID: 1, EventStateKeyNID: 2},
{1, 4}, {EventTypeNID: 1, EventStateKeyNID: 4},
{2, 2}, {EventTypeNID: 2, EventStateKeyNID: 2},
{1, 1}, {EventTypeNID: 1, EventStateKeyNID: 1},
} }
want := []types.StateKeyTuple{ want := []types.StateKeyTuple{
{1, 1}, {EventTypeNID: 1, EventStateKeyNID: 1},
{1, 2}, {EventTypeNID: 1, EventStateKeyNID: 2},
{1, 4}, {EventTypeNID: 1, EventStateKeyNID: 4},
{2, 2}, {EventTypeNID: 2, EventStateKeyNID: 2},
} }
doNotWant := []types.StateKeyTuple{ doNotWant := []types.StateKeyTuple{
{0, 0}, {EventTypeNID: 0, EventStateKeyNID: 0},
{1, 3}, {EventTypeNID: 1, EventStateKeyNID: 3},
{2, 1}, {EventTypeNID: 2, EventStateKeyNID: 1},
{3, 1}, {EventTypeNID: 3, EventStateKeyNID: 1},
} }
wantTypeNIDs := []int64{1, 2} wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4} wantStateKeyNIDs := []int64{1, 2, 4}

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
@ -74,21 +75,25 @@ func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
}.prepare(db) }.prepare(db)
} }
func (s *stateSnapshotStatements) insertState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) { func (s *stateSnapshotStatements) insertState(
ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
) (stateNID types.StateSnapshotNID, err error) {
nids := make([]int64, len(stateBlockNIDs)) nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs { for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i]) nids[i] = int64(stateBlockNIDs[i])
} }
err = s.insertStateStmt.QueryRow(int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
return return
} }
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) { func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs)) nids := make([]int64, len(stateNIDs))
for i := range stateNIDs { for i := range stateNIDs {
nids[i] = int64(stateNIDs[i]) nids[i] = int64(stateNIDs[i])
} }
rows, err := s.bulkSelectStateBlockNIDsStmt.Query(pq.Int64Array(nids)) rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -15,6 +15,7 @@
package storage package storage
import ( import (
"context"
"database/sql" "database/sql"
// Import the postgres database driver. // Import the postgres database driver.
@ -43,7 +44,9 @@ func Open(dataSourceName string) (*Database, error) {
} }
// StoreEvent implements input.EventDatabase // StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) { func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) {
var ( var (
roomNID types.RoomNID roomNID types.RoomNID
eventTypeNID types.EventTypeNID eventTypeNID types.EventTypeNID
@ -53,11 +56,11 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
err error err error
) )
if roomNID, err = d.assignRoomNID(nil, event.RoomID()); err != nil { if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil {
return 0, types.StateAtEvent{}, err return 0, types.StateAtEvent{}, err
} }
if eventTypeNID, err = d.assignEventTypeNID(event.Type()); err != nil { if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil {
return 0, types.StateAtEvent{}, err return 0, types.StateAtEvent{}, err
} }
@ -65,12 +68,13 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
// Assigned a numeric ID for the state_key if there is one present. // Assigned a numeric ID for the state_key if there is one present.
// Otherwise set the numeric ID for the state_key to 0. // Otherwise set the numeric ID for the state_key to 0.
if eventStateKey != nil { if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(nil, *eventStateKey); err != nil { if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil {
return 0, types.StateAtEvent{}, err return 0, types.StateAtEvent{}, err
} }
} }
if eventNID, stateNID, err = d.statements.insertEvent( if eventNID, stateNID, err = d.statements.insertEvent(
ctx,
roomNID, roomNID,
eventTypeNID, eventTypeNID,
eventStateKeyNID, eventStateKeyNID,
@ -81,14 +85,14 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
); err != nil { ); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID // We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.statements.selectEvent(event.EventID()) eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID())
} }
if err != nil { if err != nil {
return 0, types.StateAtEvent{}, err return 0, types.StateAtEvent{}, err
} }
} }
if err = d.statements.insertEventJSON(eventNID, event.JSON()); err != nil { if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil {
return 0, types.StateAtEvent{}, err return 0, types.StateAtEvent{}, err
} }
@ -104,76 +108,94 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
}, nil }, nil
} }
func (d *Database) assignRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) { func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx, roomID string,
) (types.RoomNID, error) {
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
roomNID, err := d.statements.selectRoomNID(txn, roomID) roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database. // We don't have a numeric ID so insert one into the database.
roomNID, err = d.statements.insertRoomNID(txn, roomID) roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We raced with another insert so run the select again. // We raced with another insert so run the select again.
roomNID, err = d.statements.selectRoomNID(txn, roomID) roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
} }
} }
return roomNID, err return roomNID, err
} }
func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, error) { func (d *Database) assignEventTypeNID(
ctx context.Context, eventType string,
) (types.EventTypeNID, error) {
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
eventTypeNID, err := d.statements.selectEventTypeNID(eventType) eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database. // We don't have a numeric ID so insert one into the database.
eventTypeNID, err = d.statements.insertEventTypeNID(eventType) eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We raced with another insert so run the select again. // We raced with another insert so run the select again.
eventTypeNID, err = d.statements.selectEventTypeNID(eventType) eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType)
} }
} }
return eventTypeNID, err return eventTypeNID, err
} }
func (d *Database) assignStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) { func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
eventStateKeyNID, err := d.statements.selectEventStateKeyNID(txn, eventStateKey) eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We don't have a numeric ID so insert one into the database. // We don't have a numeric ID so insert one into the database.
eventStateKeyNID, err = d.statements.insertEventStateKeyNID(txn, eventStateKey) eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We raced with another insert so run the select again. // We raced with another insert so run the select again.
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(txn, eventStateKey) eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
} }
} }
return eventStateKeyNID, err return eventStateKeyNID, err
} }
// StateEntriesForEventIDs implements input.EventDatabase // StateEntriesForEventIDs implements input.EventDatabase
func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) { func (d *Database) StateEntriesForEventIDs(
return d.statements.bulkSelectStateEventByID(eventIDs) ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return d.statements.bulkSelectStateEventByID(ctx, eventIDs)
} }
// EventTypeNIDs implements state.RoomStateDatabase // EventTypeNIDs implements state.RoomStateDatabase
func (d *Database) EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error) { func (d *Database) EventTypeNIDs(
return d.statements.bulkSelectEventTypeNID(eventTypes) ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return d.statements.bulkSelectEventTypeNID(ctx, eventTypes)
} }
// EventStateKeyNIDs implements state.RoomStateDatabase // EventStateKeyNIDs implements state.RoomStateDatabase
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) { func (d *Database) EventStateKeyNIDs(
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys) ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys)
} }
// EventStateKeys implements query.RoomserverQueryAPIDatabase // EventStateKeys implements query.RoomserverQueryAPIDatabase
func (d *Database) EventStateKeys(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) { func (d *Database) EventStateKeys(
return d.statements.bulkSelectEventStateKey(eventStateKeyNIDs) ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs)
} }
// EventNIDs implements query.RoomserverQueryAPIDatabase // EventNIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) EventNIDs(eventIDs []string) (map[string]types.EventNID, error) { func (d *Database) EventNIDs(
return d.statements.bulkSelectEventNID(eventIDs) ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.statements.bulkSelectEventNID(ctx, eventIDs)
} }
// Events implements input.EventDatabase // Events implements input.EventDatabase
func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) { func (d *Database) Events(
eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs) ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -191,78 +213,98 @@ func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) {
} }
// AddState implements input.EventDatabase // AddState implements input.EventDatabase
func (d *Database) AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) { func (d *Database) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (types.StateSnapshotNID, error) {
if len(state) > 0 { if len(state) > 0 {
stateBlockNID, err := d.statements.selectNextStateBlockNID() stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if err = d.statements.bulkInsertStateData(stateBlockNID, state); err != nil { if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil {
return 0, err return 0, err
} }
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
} }
return d.statements.insertState(roomNID, stateBlockNIDs) return d.statements.insertState(ctx, roomNID, stateBlockNIDs)
} }
// SetState implements input.EventDatabase // SetState implements input.EventDatabase
func (d *Database) SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error { func (d *Database) SetState(
return d.statements.updateEventState(eventNID, stateNID) ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return d.statements.updateEventState(ctx, eventNID, stateNID)
} }
// StateAtEventIDs implements input.EventDatabase // StateAtEventIDs implements input.EventDatabase
func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) { func (d *Database) StateAtEventIDs(
return d.statements.bulkSelectStateAtEventByID(eventIDs) ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs)
} }
// StateBlockNIDs implements state.RoomStateDatabase // StateBlockNIDs implements state.RoomStateDatabase
func (d *Database) StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) { func (d *Database) StateBlockNIDs(
return d.statements.bulkSelectStateBlockNIDs(stateNIDs) ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs)
} }
// StateEntries implements state.RoomStateDatabase // StateEntries implements state.RoomStateDatabase
func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) { func (d *Database) StateEntries(
return d.statements.bulkSelectStateBlockEntries(stateBlockNIDs) ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs)
} }
// SnapshotNIDFromEventID implements state.RoomStateDatabase // SnapshotNIDFromEventID implements state.RoomStateDatabase
func (d *Database) SnapshotNIDFromEventID(eventID string) (types.StateSnapshotNID, error) { func (d *Database) SnapshotNIDFromEventID(
_, stateNID, err := d.statements.selectEvent(eventID) ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.statements.selectEvent(ctx, eventID)
return stateNID, err return stateNID, err
} }
// EventIDs implements input.RoomEventDatabase // EventIDs implements input.RoomEventDatabase
func (d *Database) EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) { func (d *Database) EventIDs(
return d.statements.bulkSelectEventID(eventNIDs) ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return d.statements.bulkSelectEventID(ctx, eventNIDs)
} }
// GetLatestEventsForUpdate implements input.EventDatabase // GetLatestEventsForUpdate implements input.EventDatabase
func (d *Database) GetLatestEventsForUpdate(roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) { func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID,
) (types.RoomRecentEventsUpdater, error) {
txn, err := d.db.Begin() txn, err := d.db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := d.statements.selectLatestEventsNIDsForUpdate(txn, roomNID) eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
if err != nil { if err != nil {
txn.Rollback() txn.Rollback()
return nil, err return nil, err
} }
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(txn, eventNIDs) stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil { if err != nil {
txn.Rollback() txn.Rollback()
return nil, err return nil, err
} }
var lastEventIDSent string var lastEventIDSent string
if lastEventNIDSent != 0 { if lastEventNIDSent != 0 {
lastEventIDSent, err = d.statements.selectEventID(txn, lastEventNIDSent) lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent)
if err != nil { if err != nil {
txn.Rollback() txn.Rollback()
return nil, err return nil, err
} }
} }
return &roomRecentEventsUpdater{ return &roomRecentEventsUpdater{
transaction{txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil }, nil
} }
@ -293,7 +335,7 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN
// StorePreviousEvents implements types.RoomRecentEventsUpdater // StorePreviousEvents implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences { for _, ref := range previousEventReferences {
if err := u.d.statements.insertPreviousEvent(u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return err return err
} }
} }
@ -302,7 +344,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p
// IsReferenced implements types.RoomRecentEventsUpdater // IsReferenced implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.statements.selectPreviousEventExists(u.txn, eventReference.EventID, eventReference.EventSHA256) err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil { if err == nil {
return true, nil return true, nil
} }
@ -321,26 +363,26 @@ func (u *roomRecentEventsUpdater) SetLatestEvents(
for i := range latest { for i := range latest {
eventNIDs[i] = latest[i].EventNID eventNIDs[i] = latest[i].EventNID
} }
return u.d.statements.updateLatestEventNIDs(u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
} }
// HasEventBeenSent implements types.RoomRecentEventsUpdater // HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.statements.selectEventSentToOutput(u.txn, eventNID) return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID)
} }
// MarkEventAsSent implements types.RoomRecentEventsUpdater // MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.statements.updateEventSentToOutput(u.txn, eventNID) return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID)
} }
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.txn, u.roomNID, targetUserNID) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID)
} }
// RoomNID implements query.RoomserverQueryAPIDB // RoomNID implements query.RoomserverQueryAPIDB
func (d *Database) RoomNID(roomID string) (types.RoomNID, error) { func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) {
roomNID, err := d.statements.selectRoomNID(nil, roomID) roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return 0, nil return 0, nil
} }
@ -348,16 +390,18 @@ func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
} }
// LatestEventIDs implements query.RoomserverQueryAPIDatabase // LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { func (d *Database) LatestEventIDs(
eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(roomNID) ctx context.Context, roomNID types.RoomNID,
) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) {
eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
references, err := d.statements.bulkSelectEventReference(eventNIDs) references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
depth, err := d.statements.selectMaxEventDepth(eventNIDs) depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs)
if err != nil { if err != nil {
return nil, 0, 0, err return nil, 0, 0, err
} }
@ -366,40 +410,48 @@ func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.Ev
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase // GetInvitesForUser implements query.RoomserverQueryAPIDatabase
func (d *Database) GetInvitesForUser( func (d *Database) GetInvitesForUser(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ctx context.Context,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, err error) { ) (senderUserIDs []types.EventStateKeyNID, err error) {
return d.statements.selectInviteActiveForUserInRoom(targetUserNID, roomNID) return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
} }
// SetRoomAlias implements alias.RoomserverAliasAPIDB // SetRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) SetRoomAlias(alias string, roomID string) error { func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string) error {
return d.statements.insertRoomAlias(alias, roomID) return d.statements.insertRoomAlias(ctx, alias, roomID)
} }
// GetRoomIDFromAlias implements alias.RoomserverAliasAPIDB // GetRoomIDFromAlias implements alias.RoomserverAliasAPIDB
func (d *Database) GetRoomIDFromAlias(alias string) (string, error) { func (d *Database) GetRoomIDFromAlias(ctx context.Context, alias string) (string, error) {
return d.statements.selectRoomIDFromAlias(alias) return d.statements.selectRoomIDFromAlias(ctx, alias)
} }
// GetAliasesFromRoomID implements alias.RoomserverAliasAPIDB // GetAliasesFromRoomID implements alias.RoomserverAliasAPIDB
func (d *Database) GetAliasesFromRoomID(roomID string) ([]string, error) { func (d *Database) GetAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.statements.selectAliasesFromRoomID(roomID) return d.statements.selectAliasesFromRoomID(ctx, roomID)
} }
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB // RemoveRoomAlias implements alias.RoomserverAliasAPIDB
func (d *Database) RemoveRoomAlias(alias string) error { func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
return d.statements.deleteRoomAlias(alias) return d.statements.deleteRoomAlias(ctx, alias)
} }
// StateEntriesForTuples implements state.RoomStateDatabase // StateEntriesForTuples implements state.RoomStateDatabase
func (d *Database) StateEntriesForTuples( func (d *Database) StateEntriesForTuples(
stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) { ) ([]types.StateEntryList, error) {
return d.statements.bulkSelectFilteredStateBlockEntries(stateBlockNIDs, stateKeyTuples) return d.statements.bulkSelectFilteredStateBlockEntries(
ctx, stateBlockNIDs, stateKeyTuples,
)
} }
// MembershipUpdater implements input.RoomEventDatabase // MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.MembershipUpdater, error) { func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string,
) (types.MembershipUpdater, error) {
txn, err := d.db.Begin() txn, err := d.db.Begin()
if err != nil { if err != nil {
return nil, err return nil, err
@ -411,17 +463,17 @@ func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.Members
} }
}() }()
roomNID, err := d.assignRoomNID(txn, roomID) roomNID, err := d.assignRoomNID(ctx, txn, roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUserNID, err := d.assignStateKeyNID(txn, targetUserID) targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
updater, err := d.membershipUpdaterTxn(txn, roomNID, targetUserNID) updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -439,20 +491,23 @@ type membershipUpdater struct {
} }
func (d *Database) membershipUpdaterTxn( func (d *Database) membershipUpdaterTxn(
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ctx context.Context,
txn *sql.Tx,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (types.MembershipUpdater, error) { ) (types.MembershipUpdater, error) {
if err := d.statements.insertMembership(txn, roomNID, targetUserNID); err != nil { if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
return nil, err return nil, err
} }
membership, err := d.statements.selectMembershipForUpdate(txn, roomNID, targetUserNID) membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &membershipUpdater{ return &membershipUpdater{
transaction{txn}, d, roomNID, targetUserNID, membership, transaction{ctx, txn}, d, roomNID, targetUserNID, membership,
}, nil }, nil
} }
@ -473,19 +528,19 @@ func (u *membershipUpdater) IsLeave() bool {
// SetToInvite implements types.MembershipUpdater // SetToInvite implements types.MembershipUpdater
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.txn, event.Sender()) senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
if err != nil { if err != nil {
return false, err return false, err
} }
inserted, err := u.d.statements.insertInviteEvent( inserted, err := u.d.statements.insertInviteEvent(
u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
) )
if err != nil { if err != nil {
return false, err return false, err
} }
if u.membership != membershipStateInvite { if u.membership != membershipStateInvite {
if err = u.d.statements.updateMembership( if err = u.d.statements.updateMembership(
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
); err != nil { ); err != nil {
return false, err return false, err
} }
@ -497,7 +552,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
var inviteEventIDs []string var inviteEventIDs []string
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID) senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -505,7 +560,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
// If this is a join event update, there is no invite to update // If this is a join event update, there is no invite to update
if !isUpdate { if !isUpdate {
inviteEventIDs, err = u.d.statements.updateInviteRetired( inviteEventIDs, err = u.d.statements.updateInviteRetired(
u.txn, u.roomNID, u.targetUserNID, u.ctx, u.txn, u.roomNID, u.targetUserNID,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -513,14 +568,15 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
} }
// Look up the NID of the new join event // Look up the NID of the new join event
nIDs, err := u.d.EventNIDs([]string{eventID}) nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if u.membership != membershipStateJoin || isUpdate { if u.membership != membershipStateJoin || isUpdate {
if err = u.d.statements.updateMembership( if err = u.d.statements.updateMembership(
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateJoin, nIDs[eventID], u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateJoin, nIDs[eventID],
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -531,26 +587,27 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
// SetToLeave implements types.MembershipUpdater // SetToLeave implements types.MembershipUpdater
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID) senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inviteEventIDs, err := u.d.statements.updateInviteRetired( inviteEventIDs, err := u.d.statements.updateInviteRetired(
u.txn, u.roomNID, u.targetUserNID, u.ctx, u.txn, u.roomNID, u.targetUserNID,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Look up the NID of the new leave event // Look up the NID of the new leave event
nIDs, err := u.d.EventNIDs([]string{eventID}) nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if u.membership != membershipStateLeaveOrBan { if u.membership != membershipStateLeaveOrBan {
if err = u.d.statements.updateMembership( if err = u.d.statements.updateMembership(
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateLeaveOrBan, nIDs[eventID], u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
membershipStateLeaveOrBan, nIDs[eventID],
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -559,19 +616,18 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
} }
// GetMembership implements query.RoomserverQueryAPIDB // GetMembership implements query.RoomserverQueryAPIDB
func (d *Database) GetMembership(roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) { func (d *Database) GetMembership(
txn, err := d.db.Begin() ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
if err != nil { ) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
return requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID)
}
defer txn.Commit()
requestSenderUserNID, err := d.assignStateKeyNID(txn, requestSenderUserID)
if err != nil { if err != nil {
return return
} }
senderMembershipEventNID, senderMembership, err := d.statements.selectMembershipFromRoomAndTarget(roomNID, requestSenderUserNID) senderMembershipEventNID, senderMembership, err :=
d.statements.selectMembershipFromRoomAndTarget(
ctx, roomNID, requestSenderUserNID,
)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// The user has never been a member of that room // The user has never been a member of that room
return 0, false, nil return 0, false, nil
@ -583,15 +639,20 @@ func (d *Database) GetMembership(roomNID types.RoomNID, requestSenderUserID stri
} }
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom(roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) { func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
) ([]types.EventNID, error) {
if joinOnly { if joinOnly {
return d.statements.selectMembershipsFromRoomAndMembership(roomNID, membershipStateJoin) return d.statements.selectMembershipsFromRoomAndMembership(
ctx, roomNID, membershipStateJoin,
)
} }
return d.statements.selectMembershipsFromRoom(roomNID) return d.statements.selectMembershipsFromRoom(ctx, roomNID)
} }
type transaction struct { type transaction struct {
ctx context.Context
txn *sql.Tx txn *sql.Tx
} }