mirror of
https://github.com/matrix-org/dendrite
synced 2024-12-14 17:03:49 +01:00
Add contexts to the roomserver storage layer (#229)
* Add contexts to the roomserver storage layer * Fix rooms_table
This commit is contained in:
parent
3133bef797
commit
bfcce5bd21
21 changed files with 744 additions and 379 deletions
|
@ -32,16 +32,16 @@ import (
|
||||||
type RoomserverAliasAPIDatabase interface {
|
type RoomserverAliasAPIDatabase interface {
|
||||||
// Save a given room alias with the room ID it refers to.
|
// Save a given room alias with the room ID it refers to.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
SetRoomAlias(alias string, roomID string) error
|
SetRoomAlias(ctx context.Context, alias string, roomID string) error
|
||||||
// Look up the room ID a given alias refers to.
|
// Look up the room ID a given alias refers to.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
GetRoomIDFromAlias(alias string) (string, error)
|
GetRoomIDFromAlias(ctx context.Context, alias string) (string, error)
|
||||||
// Look up all aliases referring to a given room ID.
|
// Look up all aliases referring to a given room ID.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
GetAliasesFromRoomID(roomID string) ([]string, error)
|
GetAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error)
|
||||||
// Remove a given room alias.
|
// Remove a given room alias.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
RemoveRoomAlias(alias string) error
|
RemoveRoomAlias(ctx context.Context, alias string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoomserverAliasAPI is an implementation of api.RoomserverAliasAPI
|
// RoomserverAliasAPI is an implementation of api.RoomserverAliasAPI
|
||||||
|
@ -59,7 +59,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias(
|
||||||
response *api.SetRoomAliasResponse,
|
response *api.SetRoomAliasResponse,
|
||||||
) error {
|
) error {
|
||||||
// Check if the alias isn't already referring to a room
|
// Check if the alias isn't already referring to a room
|
||||||
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias)
|
roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -71,7 +71,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias(
|
||||||
response.AliasExists = false
|
response.AliasExists = false
|
||||||
|
|
||||||
// Save the new alias
|
// Save the new alias
|
||||||
if err := r.DB.SetRoomAlias(request.Alias, request.RoomID); err != nil {
|
if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ func (r *RoomserverAliasAPI) GetAliasRoomID(
|
||||||
response *api.GetAliasRoomIDResponse,
|
response *api.GetAliasRoomIDResponse,
|
||||||
) error {
|
) error {
|
||||||
// Look up the room ID in the database
|
// Look up the room ID in the database
|
||||||
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias)
|
roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -109,18 +109,21 @@ func (r *RoomserverAliasAPI) RemoveRoomAlias(
|
||||||
response *api.RemoveRoomAliasResponse,
|
response *api.RemoveRoomAliasResponse,
|
||||||
) error {
|
) error {
|
||||||
// Look up the room ID in the database
|
// Look up the room ID in the database
|
||||||
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias)
|
roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the dalias from the database
|
// Remove the dalias from the database
|
||||||
if err := r.DB.RemoveRoomAlias(request.Alias); err != nil {
|
if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send an updated m.room.aliases event
|
// Send an updated m.room.aliases event
|
||||||
if err := r.sendUpdatedAliasesEvent(ctx, request.UserID, roomID); err != nil {
|
// At this point we've already committed the alias to the database so we
|
||||||
|
// shouldn't cancel this request.
|
||||||
|
// TODO: Ensure that we send unsent events when if server restarts.
|
||||||
|
if err := r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, roomID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,7 +150,7 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent(
|
||||||
|
|
||||||
// Retrieve the updated list of aliases, marhal it and set it as the
|
// Retrieve the updated list of aliases, marhal it and set it as the
|
||||||
// event's content
|
// event's content
|
||||||
aliases, err := r.DB.GetAliasesFromRoomID(roomID)
|
aliases, err := r.DB.GetAliasesFromRoomID(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,16 +15,23 @@
|
||||||
package input
|
package input
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"sort"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"sort"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// checkAuthEvents checks that the event passes authentication checks
|
// checkAuthEvents checks that the event passes authentication checks
|
||||||
// Returns the numeric IDs for the auth events.
|
// Returns the numeric IDs for the auth events.
|
||||||
func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]types.EventNID, error) {
|
func checkAuthEvents(
|
||||||
|
ctx context.Context,
|
||||||
|
db RoomEventDatabase,
|
||||||
|
event gomatrixserverlib.Event,
|
||||||
|
authEventIDs []string,
|
||||||
|
) ([]types.EventNID, error) {
|
||||||
// Grab the numeric IDs for the supplied auth state events from the database.
|
// Grab the numeric IDs for the supplied auth state events from the database.
|
||||||
authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs)
|
authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -34,7 +41,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv
|
||||||
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event})
|
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event})
|
||||||
|
|
||||||
// Load the actual auth events from the database.
|
// Load the actual auth events from the database.
|
||||||
authEvents, err := loadAuthEvents(db, stateNeeded, authStateEntries)
|
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -84,7 +91,10 @@ func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Even
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
|
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
|
||||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID})
|
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
|
||||||
|
EventTypeNID: typeNID,
|
||||||
|
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||||
|
})
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -100,7 +110,10 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, stateKeyNID})
|
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
|
||||||
|
EventTypeNID: typeNID,
|
||||||
|
EventStateKeyNID: stateKeyNID,
|
||||||
|
})
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -113,6 +126,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
|
||||||
|
|
||||||
// loadAuthEvents loads the events needed for authentication from the supplied room state.
|
// loadAuthEvents loads the events needed for authentication from the supplied room state.
|
||||||
func loadAuthEvents(
|
func loadAuthEvents(
|
||||||
|
ctx context.Context,
|
||||||
db RoomEventDatabase,
|
db RoomEventDatabase,
|
||||||
needed gomatrixserverlib.StateNeeded,
|
needed gomatrixserverlib.StateNeeded,
|
||||||
state []types.StateEntry,
|
state []types.StateEntry,
|
||||||
|
@ -121,7 +135,7 @@ func loadAuthEvents(
|
||||||
var neededStateKeys []string
|
var neededStateKeys []string
|
||||||
neededStateKeys = append(neededStateKeys, needed.Member...)
|
neededStateKeys = append(neededStateKeys, needed.Member...)
|
||||||
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
|
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
|
||||||
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(neededStateKeys); err != nil {
|
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,34 +149,52 @@ func loadAuthEvents(
|
||||||
eventNIDs = append(eventNIDs, eventNID)
|
eventNIDs = append(eventNIDs, eventNID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if result.events, err = db.Events(eventNIDs); err != nil {
|
if result.events, err = db.Events(ctx, eventNIDs); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
|
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
|
||||||
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
|
func stateKeyTuplesNeeded(
|
||||||
|
stateKeyNIDMap map[string]types.EventStateKeyNID,
|
||||||
|
stateNeeded gomatrixserverlib.StateNeeded,
|
||||||
|
) []types.StateKeyTuple {
|
||||||
var keyTuples []types.StateKeyTuple
|
var keyTuples []types.StateKeyTuple
|
||||||
if stateNeeded.Create {
|
if stateNeeded.Create {
|
||||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
|
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||||
|
EventTypeNID: types.MRoomCreateNID,
|
||||||
|
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if stateNeeded.PowerLevels {
|
if stateNeeded.PowerLevels {
|
||||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID})
|
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||||
|
EventTypeNID: types.MRoomPowerLevelsNID,
|
||||||
|
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if stateNeeded.JoinRules {
|
if stateNeeded.JoinRules {
|
||||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID})
|
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||||
|
EventTypeNID: types.MRoomJoinRulesNID,
|
||||||
|
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
for _, member := range stateNeeded.Member {
|
for _, member := range stateNeeded.Member {
|
||||||
stateKeyNID, ok := stateKeyNIDMap[member]
|
stateKeyNID, ok := stateKeyNIDMap[member]
|
||||||
if ok {
|
if ok {
|
||||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID})
|
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||||
|
EventTypeNID: types.MRoomMemberNID,
|
||||||
|
EventStateKeyNID: stateKeyNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, token := range stateNeeded.ThirdPartyInvite {
|
for _, token := range stateNeeded.ThirdPartyInvite {
|
||||||
stateKeyNID, ok := stateKeyNIDMap[token]
|
stateKeyNID, ok := stateKeyNIDMap[token]
|
||||||
if ok {
|
if ok {
|
||||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID})
|
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||||
|
EventTypeNID: types.MRoomThirdPartyInviteNID,
|
||||||
|
EventStateKeyNID: stateKeyNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return keyTuples
|
return keyTuples
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package input
|
package input
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
@ -28,22 +29,38 @@ import (
|
||||||
type RoomEventDatabase interface {
|
type RoomEventDatabase interface {
|
||||||
state.RoomStateDatabase
|
state.RoomStateDatabase
|
||||||
// Stores a matrix room event in the database
|
// Stores a matrix room event in the database
|
||||||
StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error)
|
StoreEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
event gomatrixserverlib.Event,
|
||||||
|
authEventNIDs []types.EventNID,
|
||||||
|
) (types.RoomNID, types.StateAtEvent, error)
|
||||||
// Look up the state entries for a list of string event IDs
|
// Look up the state entries for a list of string event IDs
|
||||||
// Returns an error if the there is an error talking to the database
|
// Returns an error if the there is an error talking to the database
|
||||||
// Returns a types.MissingEventError if the event IDs aren't in the database.
|
// Returns a types.MissingEventError if the event IDs aren't in the database.
|
||||||
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error)
|
StateEntriesForEventIDs(
|
||||||
|
ctx context.Context, eventIDs []string,
|
||||||
|
) ([]types.StateEntry, error)
|
||||||
// Set the state at an event.
|
// Set the state at an event.
|
||||||
SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error
|
SetState(
|
||||||
|
ctx context.Context,
|
||||||
|
eventNID types.EventNID,
|
||||||
|
stateNID types.StateSnapshotNID,
|
||||||
|
) error
|
||||||
// Look up the latest events in a room in preparation for an update.
|
// Look up the latest events in a room in preparation for an update.
|
||||||
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
||||||
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
||||||
// If this returns an error then no further action is required.
|
// If this returns an error then no further action is required.
|
||||||
GetLatestEventsForUpdate(roomNID types.RoomNID) (updater types.RoomRecentEventsUpdater, err error)
|
GetLatestEventsForUpdate(
|
||||||
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
|
) (updater types.RoomRecentEventsUpdater, err error)
|
||||||
// Look up the string event IDs for a list of numeric event IDs
|
// Look up the string event IDs for a list of numeric event IDs
|
||||||
EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
EventIDs(
|
||||||
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) (map[types.EventNID]string, error)
|
||||||
// Build a membership updater for the target user in a room.
|
// Build a membership updater for the target user in a room.
|
||||||
MembershipUpdater(roomID, targerUserID string) (types.MembershipUpdater, error)
|
MembershipUpdater(
|
||||||
|
ctx context.Context, roomID, targerUserID string,
|
||||||
|
) (types.MembershipUpdater, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OutputRoomEventWriter has the APIs needed to write an event to the output logs.
|
// OutputRoomEventWriter has the APIs needed to write an event to the output logs.
|
||||||
|
@ -52,18 +69,23 @@ type OutputRoomEventWriter interface {
|
||||||
WriteOutputEvents(roomID string, updates []api.OutputEvent) error
|
WriteOutputEvents(roomID string, updates []api.OutputEvent) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputRoomEvent) error {
|
func processRoomEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
db RoomEventDatabase,
|
||||||
|
ow OutputRoomEventWriter,
|
||||||
|
input api.InputRoomEvent,
|
||||||
|
) error {
|
||||||
// Parse and validate the event JSON
|
// Parse and validate the event JSON
|
||||||
event := input.Event
|
event := input.Event
|
||||||
|
|
||||||
// Check that the event passes authentication checks and work out the numeric IDs for the auth events.
|
// Check that the event passes authentication checks and work out the numeric IDs for the auth events.
|
||||||
authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs)
|
authEventNIDs, err := checkAuthEvents(ctx, db, event, input.AuthEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the event
|
// Store the event
|
||||||
roomNID, stateAtEvent, err := db.StoreEvent(event, authEventNIDs)
|
roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, authEventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -82,20 +104,20 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.
|
||||||
// We've been told what the state at the event is so we don't need to calculate it.
|
// We've been told what the state at the event is so we don't need to calculate it.
|
||||||
// Check that those state events are in the database and store the state.
|
// Check that those state events are in the database and store the state.
|
||||||
var entries []types.StateEntry
|
var entries []types.StateEntry
|
||||||
if entries, err = db.StateEntriesForEventIDs(input.StateEventIDs); err != nil {
|
if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(roomNID, nil, entries); err != nil {
|
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
|
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
|
||||||
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(db, event, roomNID); err != nil {
|
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
db.SetState(stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
|
db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Kind == api.KindBackfill {
|
if input.Kind == api.KindBackfill {
|
||||||
|
@ -104,14 +126,19 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the extremities of the event graph for the room
|
// Update the extremities of the event graph for the room
|
||||||
if err := updateLatestEvents(db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil {
|
if err := updateLatestEvents(ctx, db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputInviteEvent) (err error) {
|
func processInviteEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
db RoomEventDatabase,
|
||||||
|
ow OutputRoomEventWriter,
|
||||||
|
input api.InputInviteEvent,
|
||||||
|
) (err error) {
|
||||||
if input.Event.StateKey() == nil {
|
if input.Event.StateKey() == nil {
|
||||||
return fmt.Errorf("invite must be a state event")
|
return fmt.Errorf("invite must be a state event")
|
||||||
}
|
}
|
||||||
|
@ -119,7 +146,7 @@ func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input ap
|
||||||
roomID := input.Event.RoomID()
|
roomID := input.Event.RoomID()
|
||||||
targetUserID := *input.Event.StateKey()
|
targetUserID := *input.Event.StateKey()
|
||||||
|
|
||||||
updater, err := db.MembershipUpdater(roomID, targetUserID)
|
updater, err := db.MembershipUpdater(ctx, roomID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,12 +59,12 @@ func (r *RoomserverInputAPI) InputRoomEvents(
|
||||||
response *api.InputRoomEventsResponse,
|
response *api.InputRoomEventsResponse,
|
||||||
) error {
|
) error {
|
||||||
for i := range request.InputRoomEvents {
|
for i := range request.InputRoomEvents {
|
||||||
if err := processRoomEvent(r.DB, r, request.InputRoomEvents[i]); err != nil {
|
if err := processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i := range request.InputInviteEvents {
|
for i := range request.InputInviteEvents {
|
||||||
if err := processInviteEvent(r.DB, r, request.InputInviteEvents[i]); err != nil {
|
if err := processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ package input
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -42,6 +43,7 @@ import (
|
||||||
// 7 <----- latest
|
// 7 <----- latest
|
||||||
//
|
//
|
||||||
func updateLatestEvents(
|
func updateLatestEvents(
|
||||||
|
ctx context.Context,
|
||||||
db RoomEventDatabase,
|
db RoomEventDatabase,
|
||||||
ow OutputRoomEventWriter,
|
ow OutputRoomEventWriter,
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
|
@ -49,7 +51,7 @@ func updateLatestEvents(
|
||||||
event gomatrixserverlib.Event,
|
event gomatrixserverlib.Event,
|
||||||
sendAsServer string,
|
sendAsServer string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
updater, err := db.GetLatestEventsForUpdate(roomNID)
|
updater, err := db.GetLatestEventsForUpdate(ctx, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -57,7 +59,7 @@ func updateLatestEvents(
|
||||||
defer common.EndTransaction(updater, &succeeded)
|
defer common.EndTransaction(updater, &succeeded)
|
||||||
|
|
||||||
u := latestEventsUpdater{
|
u := latestEventsUpdater{
|
||||||
db: db, updater: updater, ow: ow, roomNID: roomNID,
|
ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,
|
||||||
stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer,
|
stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer,
|
||||||
}
|
}
|
||||||
if err = u.doUpdateLatestEvents(); err != nil {
|
if err = u.doUpdateLatestEvents(); err != nil {
|
||||||
|
@ -73,6 +75,7 @@ func updateLatestEvents(
|
||||||
// The state could be passed using function arguments, but it becomes impractical
|
// The state could be passed using function arguments, but it becomes impractical
|
||||||
// when there are so many variables to pass around.
|
// when there are so many variables to pass around.
|
||||||
type latestEventsUpdater struct {
|
type latestEventsUpdater struct {
|
||||||
|
ctx context.Context
|
||||||
db RoomEventDatabase
|
db RoomEventDatabase
|
||||||
updater types.RoomRecentEventsUpdater
|
updater types.RoomRecentEventsUpdater
|
||||||
ow OutputRoomEventWriter
|
ow OutputRoomEventWriter
|
||||||
|
@ -133,7 +136,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
updates, err := updateMemberships(u.db, u.updater, u.removed, u.added)
|
updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -174,18 +177,22 @@ func (u *latestEventsUpdater) latestState() error {
|
||||||
for i := range u.latest {
|
for i := range u.latest {
|
||||||
latestStateAtEvents[i] = u.latest[i].StateAtEvent
|
latestStateAtEvents[i] = u.latest[i].StateAtEvent
|
||||||
}
|
}
|
||||||
u.newStateNID, err = state.CalculateAndStoreStateAfterEvents(u.db, u.roomNID, latestStateAtEvents)
|
u.newStateNID, err = state.CalculateAndStoreStateAfterEvents(
|
||||||
|
u.ctx, u.db, u.roomNID, latestStateAtEvents,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots(u.db, u.oldStateNID, u.newStateNID)
|
u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots(
|
||||||
|
u.ctx, u.db, u.oldStateNID, u.newStateNID,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots(
|
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots(
|
||||||
u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
|
u.ctx, u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -193,7 +200,12 @@ func (u *latestEventsUpdater) latestState() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateLatest(oldLatest []types.StateAtEventAndReference, alreadyReferenced bool, prevEvents []gomatrixserverlib.EventReference, newEvent types.StateAtEventAndReference) []types.StateAtEventAndReference {
|
func calculateLatest(
|
||||||
|
oldLatest []types.StateAtEventAndReference,
|
||||||
|
alreadyReferenced bool,
|
||||||
|
prevEvents []gomatrixserverlib.EventReference,
|
||||||
|
newEvent types.StateAtEventAndReference,
|
||||||
|
) []types.StateAtEventAndReference {
|
||||||
var alreadyInLatest bool
|
var alreadyInLatest bool
|
||||||
var newLatest []types.StateAtEventAndReference
|
var newLatest []types.StateAtEventAndReference
|
||||||
for _, l := range oldLatest {
|
for _, l := range oldLatest {
|
||||||
|
@ -253,7 +265,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
|
||||||
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
||||||
}
|
}
|
||||||
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
|
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
|
||||||
eventIDMap, err := u.db.EventIDs(stateEventNIDs)
|
eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package input
|
package input
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -27,7 +28,10 @@ import (
|
||||||
// Returns a list of output events to write to the kafka log to inform the
|
// Returns a list of output events to write to the kafka log to inform the
|
||||||
// consumers about the invites added or retired by the change in current state.
|
// consumers about the invites added or retired by the change in current state.
|
||||||
func updateMemberships(
|
func updateMemberships(
|
||||||
db RoomEventDatabase, updater types.RoomRecentEventsUpdater, removed, added []types.StateEntry,
|
ctx context.Context,
|
||||||
|
db RoomEventDatabase,
|
||||||
|
updater types.RoomRecentEventsUpdater,
|
||||||
|
removed, added []types.StateEntry,
|
||||||
) ([]api.OutputEvent, error) {
|
) ([]api.OutputEvent, error) {
|
||||||
changes := membershipChanges(removed, added)
|
changes := membershipChanges(removed, added)
|
||||||
var eventNIDs []types.EventNID
|
var eventNIDs []types.EventNID
|
||||||
|
@ -43,7 +47,7 @@ func updateMemberships(
|
||||||
// Load the event JSON so we can look up the "membership" key.
|
// Load the event JSON so we can look up the "membership" key.
|
||||||
// TODO: Maybe add a membership key to the events table so we can load that
|
// TODO: Maybe add a membership key to the events table so we can load that
|
||||||
// key without having to load the entire event JSON?
|
// key without having to load the entire event JSON?
|
||||||
events, err := db.Events(eventNIDs)
|
events, err := db.Events(ctx, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,35 +33,47 @@ type RoomserverQueryAPIDatabase interface {
|
||||||
// Look up the numeric ID for the room.
|
// Look up the numeric ID for the room.
|
||||||
// Returns 0 if the room doesn't exists.
|
// Returns 0 if the room doesn't exists.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
RoomNID(roomID string) (types.RoomNID, error)
|
RoomNID(ctx context.Context, roomID string) (types.RoomNID, error)
|
||||||
// Look up event references for the latest events in the room and the current state snapshot.
|
// Look up event references for the latest events in the room and the current state snapshot.
|
||||||
// Returns the latest events, the current state and the maximum depth of the latest events plus 1.
|
// Returns the latest events, the current state and the maximum depth of the latest events plus 1.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
|
LatestEventIDs(
|
||||||
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
|
) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
|
||||||
// Look up the numeric IDs for a list of events.
|
// Look up the numeric IDs for a list of events.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
EventNIDs(eventIDs []string) (map[string]types.EventNID, error)
|
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
|
||||||
// Lookup the event IDs for a batch of event numeric IDs.
|
// Lookup the event IDs for a batch of event numeric IDs.
|
||||||
// Returns an error if the retrieval went wrong.
|
// Returns an error if the retrieval went wrong.
|
||||||
EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||||
// Lookup the membership of a given user in a given room.
|
// Lookup the membership of a given user in a given room.
|
||||||
// Returns the numeric ID of the latest membership event sent from this user
|
// Returns the numeric ID of the latest membership event sent from this user
|
||||||
// in this room, along a boolean set to true if the user is still in this room,
|
// in this room, along a boolean set to true if the user is still in this room,
|
||||||
// false if not.
|
// false if not.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
GetMembership(roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error)
|
GetMembership(
|
||||||
|
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
|
||||||
|
) (membershipEventNID types.EventNID, stillInRoom bool, err error)
|
||||||
// Lookup the membership event numeric IDs for all user that are or have
|
// Lookup the membership event numeric IDs for all user that are or have
|
||||||
// been members of a given room. Only lookup events of "join" membership if
|
// been members of a given room. Only lookup events of "join" membership if
|
||||||
// joinOnly is set to true.
|
// joinOnly is set to true.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
GetMembershipEventNIDsForRoom(roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error)
|
GetMembershipEventNIDsForRoom(
|
||||||
|
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
|
||||||
|
) ([]types.EventNID, error)
|
||||||
// Look up the active invites targeting a user in a room and return the
|
// Look up the active invites targeting a user in a room and return the
|
||||||
// numeric state key IDs for the user IDs who sent them.
|
// numeric state key IDs for the user IDs who sent them.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
GetInvitesForUser(roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserNIDs []types.EventStateKeyNID, err error)
|
GetInvitesForUser(
|
||||||
|
ctx context.Context,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
targetUserNID types.EventStateKeyNID,
|
||||||
|
) (senderUserNIDs []types.EventStateKeyNID, err error)
|
||||||
// Look up the string event state keys for a list of numeric event state keys
|
// Look up the string event state keys for a list of numeric event state keys
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
EventStateKeys([]types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
EventStateKeys(
|
||||||
|
context.Context, []types.EventStateKeyNID,
|
||||||
|
) (map[types.EventStateKeyNID]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
|
// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
|
||||||
|
@ -76,7 +88,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
|
||||||
response *api.QueryLatestEventsAndStateResponse,
|
response *api.QueryLatestEventsAndStateResponse,
|
||||||
) error {
|
) error {
|
||||||
response.QueryLatestEventsAndStateRequest = *request
|
response.QueryLatestEventsAndStateRequest = *request
|
||||||
roomNID, err := r.DB.RoomNID(request.RoomID)
|
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -85,18 +97,21 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
|
||||||
}
|
}
|
||||||
response.RoomExists = true
|
response.RoomExists = true
|
||||||
var currentStateSnapshotNID types.StateSnapshotNID
|
var currentStateSnapshotNID types.StateSnapshotNID
|
||||||
response.LatestEvents, currentStateSnapshotNID, response.Depth, err = r.DB.LatestEventIDs(roomNID)
|
response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
|
||||||
|
r.DB.LatestEventIDs(ctx, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up the currrent state for the requested tuples.
|
// Look up the currrent state for the requested tuples.
|
||||||
stateEntries, err := state.LoadStateAtSnapshotForStringTuples(r.DB, currentStateSnapshotNID, request.StateToFetch)
|
stateEntries, err := state.LoadStateAtSnapshotForStringTuples(
|
||||||
|
ctx, r.DB, currentStateSnapshotNID, request.StateToFetch,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEvents, err := r.loadStateEvents(stateEntries)
|
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -112,7 +127,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
|
||||||
response *api.QueryStateAfterEventsResponse,
|
response *api.QueryStateAfterEventsResponse,
|
||||||
) error {
|
) error {
|
||||||
response.QueryStateAfterEventsRequest = *request
|
response.QueryStateAfterEventsRequest = *request
|
||||||
roomNID, err := r.DB.RoomNID(request.RoomID)
|
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -121,7 +136,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
|
||||||
}
|
}
|
||||||
response.RoomExists = true
|
response.RoomExists = true
|
||||||
|
|
||||||
prevStates, err := r.DB.StateAtEventIDs(request.PrevEventIDs)
|
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case types.MissingEventError:
|
case types.MissingEventError:
|
||||||
|
@ -133,12 +148,14 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
|
||||||
response.PrevEventsExist = true
|
response.PrevEventsExist = true
|
||||||
|
|
||||||
// Look up the currrent state for the requested tuples.
|
// Look up the currrent state for the requested tuples.
|
||||||
stateEntries, err := state.LoadStateAfterEventsForStringTuples(r.DB, prevStates, request.StateToFetch)
|
stateEntries, err := state.LoadStateAfterEventsForStringTuples(
|
||||||
|
ctx, r.DB, prevStates, request.StateToFetch,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEvents, err := r.loadStateEvents(stateEntries)
|
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -155,7 +172,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
|
||||||
) error {
|
) error {
|
||||||
response.QueryEventsByIDRequest = *request
|
response.QueryEventsByIDRequest = *request
|
||||||
|
|
||||||
eventNIDMap, err := r.DB.EventNIDs(request.EventIDs)
|
eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -165,7 +182,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
|
||||||
eventNIDs = append(eventNIDs, nid)
|
eventNIDs = append(eventNIDs, nid)
|
||||||
}
|
}
|
||||||
|
|
||||||
events, err := r.loadEvents(eventNIDs)
|
events, err := r.loadEvents(ctx, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -174,16 +191,20 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverQueryAPI) loadStateEvents(stateEntries []types.StateEntry) ([]gomatrixserverlib.Event, error) {
|
func (r *RoomserverQueryAPI) loadStateEvents(
|
||||||
|
ctx context.Context, stateEntries []types.StateEntry,
|
||||||
|
) ([]gomatrixserverlib.Event, error) {
|
||||||
eventNIDs := make([]types.EventNID, len(stateEntries))
|
eventNIDs := make([]types.EventNID, len(stateEntries))
|
||||||
for i := range stateEntries {
|
for i := range stateEntries {
|
||||||
eventNIDs[i] = stateEntries[i].EventNID
|
eventNIDs[i] = stateEntries[i].EventNID
|
||||||
}
|
}
|
||||||
return r.loadEvents(eventNIDs)
|
return r.loadEvents(ctx, eventNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverQueryAPI) loadEvents(eventNIDs []types.EventNID) ([]gomatrixserverlib.Event, error) {
|
func (r *RoomserverQueryAPI) loadEvents(
|
||||||
stateEvents, err := r.DB.Events(eventNIDs)
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) ([]gomatrixserverlib.Event, error) {
|
||||||
|
stateEvents, err := r.DB.Events(ctx, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -201,12 +222,12 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
||||||
request *api.QueryMembershipsForRoomRequest,
|
request *api.QueryMembershipsForRoomRequest,
|
||||||
response *api.QueryMembershipsForRoomResponse,
|
response *api.QueryMembershipsForRoomResponse,
|
||||||
) error {
|
) error {
|
||||||
roomNID, err := r.DB.RoomNID(request.RoomID)
|
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
membershipEventNID, stillInRoom, err := r.DB.GetMembership(roomNID, request.Sender)
|
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.Sender)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -223,14 +244,14 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
||||||
var events []types.Event
|
var events []types.Event
|
||||||
if stillInRoom {
|
if stillInRoom {
|
||||||
var eventNIDs []types.EventNID
|
var eventNIDs []types.EventNID
|
||||||
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(roomNID, request.JoinedOnly)
|
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
events, err = r.DB.Events(eventNIDs)
|
events, err = r.DB.Events(ctx, eventNIDs)
|
||||||
} else {
|
} else {
|
||||||
events, err = r.getMembershipsBeforeEventNID(membershipEventNID, request.JoinedOnly)
|
events, err = r.getMembershipsBeforeEventNID(ctx, membershipEventNID, request.JoinedOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -249,22 +270,24 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
||||||
// of the event's room as it was when this event was fired, then filters the state events to
|
// of the event's room as it was when this event was fired, then filters the state events to
|
||||||
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
||||||
// Returns an error if there was an issue fetching the events.
|
// Returns an error if there was an issue fetching the events.
|
||||||
func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNID, joinedOnly bool) ([]types.Event, error) {
|
func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(
|
||||||
|
ctx context.Context, eventNID types.EventNID, joinedOnly bool,
|
||||||
|
) ([]types.Event, error) {
|
||||||
events := []types.Event{}
|
events := []types.Event{}
|
||||||
// Lookup the event NID
|
// Lookup the event NID
|
||||||
eIDs, err := r.DB.EventIDs([]types.EventNID{eventNID})
|
eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventIDs := []string{eIDs[eventNID]}
|
eventIDs := []string{eIDs[eventNID]}
|
||||||
|
|
||||||
prevState, err := r.DB.StateAtEventIDs(eventIDs)
|
prevState, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the state as it was when this event was fired
|
// Fetch the state as it was when this event was fired
|
||||||
stateEntries, err := state.LoadCombinedStateAfterEvents(r.DB, prevState)
|
stateEntries, err := state.LoadCombinedStateAfterEvents(ctx, r.DB, prevState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -278,7 +301,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNI
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all of the events in this state
|
// Get all of the events in this state
|
||||||
stateEvents, err := r.DB.Events(eventNIDs)
|
stateEvents, err := r.DB.Events(ctx, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -304,27 +327,27 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNI
|
||||||
|
|
||||||
// QueryInvitesForUser implements api.RoomserverQueryAPI
|
// QueryInvitesForUser implements api.RoomserverQueryAPI
|
||||||
func (r *RoomserverQueryAPI) QueryInvitesForUser(
|
func (r *RoomserverQueryAPI) QueryInvitesForUser(
|
||||||
_ context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryInvitesForUserRequest,
|
request *api.QueryInvitesForUserRequest,
|
||||||
response *api.QueryInvitesForUserResponse,
|
response *api.QueryInvitesForUserResponse,
|
||||||
) error {
|
) error {
|
||||||
roomNID, err := r.DB.RoomNID(request.RoomID)
|
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUserNIDs, err := r.DB.EventStateKeyNIDs([]string{request.TargetUserID})
|
targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.TargetUserID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
targetUserNID := targetUserNIDs[request.TargetUserID]
|
targetUserNID := targetUserNIDs[request.TargetUserID]
|
||||||
|
|
||||||
senderUserNIDs, err := r.DB.GetInvitesForUser(roomNID, targetUserNID)
|
senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
senderUserIDs, err := r.DB.EventStateKeys(senderUserNIDs)
|
senderUserIDs, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -342,14 +365,14 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent(
|
||||||
request *api.QueryServerAllowedToSeeEventRequest,
|
request *api.QueryServerAllowedToSeeEventRequest,
|
||||||
response *api.QueryServerAllowedToSeeEventResponse,
|
response *api.QueryServerAllowedToSeeEventResponse,
|
||||||
) error {
|
) error {
|
||||||
stateEntries, err := state.LoadStateAtEvent(r.DB, request.EventID)
|
stateEntries, err := state.LoadStateAtEvent(ctx, r.DB, request.EventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: We probably want to make it so that we don't have to pull
|
// TODO: We probably want to make it so that we don't have to pull
|
||||||
// out all the state if possible.
|
// out all the state if possible.
|
||||||
stateAtEvent, err := r.loadStateEvents(stateEntries)
|
stateAtEvent, err := r.loadStateEvents(ctx, stateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package state
|
package state
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
@ -30,49 +31,58 @@ import (
|
||||||
// A RoomStateDatabase has the storage APIs needed to load state from the database
|
// A RoomStateDatabase has the storage APIs needed to load state from the database
|
||||||
type RoomStateDatabase interface {
|
type RoomStateDatabase interface {
|
||||||
// Store the room state at an event in the database
|
// Store the room state at an event in the database
|
||||||
AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
AddState(
|
||||||
|
ctx context.Context,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
|
state []types.StateEntry,
|
||||||
|
) (types.StateSnapshotNID, error)
|
||||||
// Look up the state of a room at each event for a list of string event IDs.
|
// Look up the state of a room at each event for a list of string event IDs.
|
||||||
// Returns an error if there is an error talking to the database
|
// Returns an error if there is an error talking to the database
|
||||||
// Returns a types.MissingEventError if the room state for the event IDs aren't in the database
|
// Returns a types.MissingEventError if the room state for the event IDs aren't in the database
|
||||||
StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error)
|
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
||||||
// Look up the numeric IDs for a list of string event types.
|
// Look up the numeric IDs for a list of string event types.
|
||||||
// Returns a map from string event type to numeric ID for the event type.
|
// Returns a map from string event type to numeric ID for the event type.
|
||||||
EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error)
|
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||||
// Look up the numeric IDs for a list of string event state keys.
|
// Look up the numeric IDs for a list of string event state keys.
|
||||||
// Returns a map from string state key to numeric ID for the state key.
|
// Returns a map from string state key to numeric ID for the state key.
|
||||||
EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||||
// Look up the numeric state data IDs for each numeric state snapshot ID
|
// Look up the numeric state data IDs for each numeric state snapshot ID
|
||||||
// The returned slice is sorted by numeric state snapshot ID.
|
// The returned slice is sorted by numeric state snapshot ID.
|
||||||
StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||||
// Look up the state data for each numeric state data ID
|
// Look up the state data for each numeric state data ID
|
||||||
// The returned slice is sorted by numeric state data ID.
|
// The returned slice is sorted by numeric state data ID.
|
||||||
StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
||||||
// Look up the state data for the state key tuples for each numeric state block ID
|
// Look up the state data for the state key tuples for each numeric state block ID
|
||||||
// This is used to fetch a subset of the room state at a snapshot.
|
// This is used to fetch a subset of the room state at a snapshot.
|
||||||
// If a block doesn't contain any of the requested tuples then it can be discarded from the result.
|
// If a block doesn't contain any of the requested tuples then it can be discarded from the result.
|
||||||
// The returned slice is sorted by numeric state block ID.
|
// The returned slice is sorted by numeric state block ID.
|
||||||
StateEntriesForTuples(stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) (
|
StateEntriesForTuples(
|
||||||
[]types.StateEntryList, error,
|
ctx context.Context,
|
||||||
)
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
|
) ([]types.StateEntryList, error)
|
||||||
// Look up the Events for a list of numeric event IDs.
|
// Look up the Events for a list of numeric event IDs.
|
||||||
// Returns a sorted list of events.
|
// Returns a sorted list of events.
|
||||||
Events(eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
// Look up snapshot NID for an event ID string
|
// Look up snapshot NID for an event ID string
|
||||||
SnapshotNIDFromEventID(eventID string) (types.StateSnapshotNID, error)
|
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadStateAtSnapshot loads the full state of a room at a particular snapshot.
|
// LoadStateAtSnapshot loads the full state of a room at a particular snapshot.
|
||||||
// This is typically the state before an event or the current state of a room.
|
// This is typically the state before an event or the current state of a room.
|
||||||
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||||
func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID) ([]types.StateEntry, error) {
|
func LoadStateAtSnapshot(
|
||||||
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID})
|
ctx context.Context, db RoomStateDatabase, stateNID types.StateSnapshotNID,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
||||||
stateBlockNIDList := stateBlockNIDLists[0]
|
stateBlockNIDList := stateBlockNIDLists[0]
|
||||||
|
|
||||||
stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs)
|
stateEntryLists, err := db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -100,13 +110,15 @@ func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadStateAtEvent loads the full state of a room at a particular event.
|
// LoadStateAtEvent loads the full state of a room at a particular event.
|
||||||
func LoadStateAtEvent(db RoomStateDatabase, eventID string) ([]types.StateEntry, error) {
|
func LoadStateAtEvent(
|
||||||
snapshotNID, err := db.SnapshotNIDFromEventID(eventID)
|
ctx context.Context, db RoomStateDatabase, eventID string,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
snapshotNID, err := db.SnapshotNIDFromEventID(ctx, eventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEntries, err := LoadStateAtSnapshot(db, snapshotNID)
|
stateEntries, err := LoadStateAtSnapshot(ctx, db, snapshotNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -116,7 +128,9 @@ func LoadStateAtEvent(db RoomStateDatabase, eventID string) ([]types.StateEntry,
|
||||||
|
|
||||||
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
|
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
|
||||||
// and combines those snapshots together into a single list.
|
// and combines those snapshots together into a single list.
|
||||||
func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) {
|
func LoadCombinedStateAfterEvents(
|
||||||
|
ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
|
stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
|
||||||
for i, state := range prevStates {
|
for i, state := range prevStates {
|
||||||
stateNIDs[i] = state.BeforeStateSnapshotNID
|
stateNIDs[i] = state.BeforeStateSnapshotNID
|
||||||
|
@ -125,7 +139,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
|
||||||
// Deduplicate the IDs before passing them to the database.
|
// Deduplicate the IDs before passing them to the database.
|
||||||
// There could be duplicates because the events could be state events where
|
// There could be duplicates because the events could be state events where
|
||||||
// the snapshot of the room state before them was the same.
|
// the snapshot of the room state before them was the same.
|
||||||
stateBlockNIDLists, err := db.StateBlockNIDs(uniqueStateSnapshotNIDs(stateNIDs))
|
stateBlockNIDLists, err := db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -138,7 +152,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
|
||||||
// Deduplicate the IDs before passing them to the database.
|
// Deduplicate the IDs before passing them to the database.
|
||||||
// There could be duplicates because a block of state entries could be reused by
|
// There could be duplicates because a block of state entries could be reused by
|
||||||
// multiple snapshots.
|
// multiple snapshots.
|
||||||
stateEntryLists, err := db.StateEntries(uniqueStateBlockNIDs(stateBlockNIDs))
|
stateEntryLists, err := db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -186,9 +200,9 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
|
||||||
}
|
}
|
||||||
|
|
||||||
// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots.
|
// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots.
|
||||||
func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID) (
|
func DifferenceBetweeenStateSnapshots(
|
||||||
removed, added []types.StateEntry, err error,
|
ctx context.Context, db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID,
|
||||||
) {
|
) (removed, added []types.StateEntry, err error) {
|
||||||
if oldStateNID == newStateNID {
|
if oldStateNID == newStateNID {
|
||||||
// If the snapshot NIDs are the same then nothing has changed
|
// If the snapshot NIDs are the same then nothing has changed
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
|
@ -197,13 +211,13 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat
|
||||||
var oldEntries []types.StateEntry
|
var oldEntries []types.StateEntry
|
||||||
var newEntries []types.StateEntry
|
var newEntries []types.StateEntry
|
||||||
if oldStateNID != 0 {
|
if oldStateNID != 0 {
|
||||||
oldEntries, err = LoadStateAtSnapshot(db, oldStateNID)
|
oldEntries, err = LoadStateAtSnapshot(ctx, db, oldStateNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if newStateNID != 0 {
|
if newStateNID != 0 {
|
||||||
newEntries, err = LoadStateAtSnapshot(db, newStateNID)
|
newEntries, err = LoadStateAtSnapshot(ctx, db, newStateNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -246,19 +260,26 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat
|
||||||
// This is typically the state before an event or the current state of a room.
|
// This is typically the state before an event or the current state of a room.
|
||||||
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||||
func LoadStateAtSnapshotForStringTuples(
|
func LoadStateAtSnapshotForStringTuples(
|
||||||
db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
|
ctx context.Context,
|
||||||
|
db RoomStateDatabase,
|
||||||
|
stateNID types.StateSnapshotNID,
|
||||||
|
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples)
|
numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return loadStateAtSnapshotForNumericTuples(db, stateNID, numericTuples)
|
return loadStateAtSnapshotForNumericTuples(ctx, db, stateNID, numericTuples)
|
||||||
}
|
}
|
||||||
|
|
||||||
// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs
|
// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs
|
||||||
// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded.
|
// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixserverlib.StateKeyTuple) ([]types.StateKeyTuple, error) {
|
func stringTuplesToNumericTuples(
|
||||||
|
ctx context.Context,
|
||||||
|
db RoomStateDatabase,
|
||||||
|
stringTuples []gomatrixserverlib.StateKeyTuple,
|
||||||
|
) ([]types.StateKeyTuple, error) {
|
||||||
eventTypes := make([]string, len(stringTuples))
|
eventTypes := make([]string, len(stringTuples))
|
||||||
stateKeys := make([]string, len(stringTuples))
|
stateKeys := make([]string, len(stringTuples))
|
||||||
for i := range stringTuples {
|
for i := range stringTuples {
|
||||||
|
@ -266,12 +287,12 @@ func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixse
|
||||||
stateKeys[i] = stringTuples[i].StateKey
|
stateKeys[i] = stringTuples[i].StateKey
|
||||||
}
|
}
|
||||||
eventTypes = util.UniqueStrings(eventTypes)
|
eventTypes = util.UniqueStrings(eventTypes)
|
||||||
eventTypeMap, err := db.EventTypeNIDs(eventTypes)
|
eventTypeMap, err := db.EventTypeNIDs(ctx, eventTypes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stateKeys = util.UniqueStrings(stateKeys)
|
stateKeys = util.UniqueStrings(stateKeys)
|
||||||
stateKeyMap, err := db.EventStateKeyNIDs(stateKeys)
|
stateKeyMap, err := db.EventStateKeyNIDs(ctx, stateKeys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -297,16 +318,21 @@ func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixse
|
||||||
// This is typically the state before an event or the current state of a room.
|
// This is typically the state before an event or the current state of a room.
|
||||||
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||||
func loadStateAtSnapshotForNumericTuples(
|
func loadStateAtSnapshotForNumericTuples(
|
||||||
db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple,
|
ctx context.Context,
|
||||||
|
db RoomStateDatabase,
|
||||||
|
stateNID types.StateSnapshotNID,
|
||||||
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID})
|
stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
||||||
stateBlockNIDList := stateBlockNIDLists[0]
|
stateBlockNIDList := stateBlockNIDLists[0]
|
||||||
|
|
||||||
stateEntryLists, err := db.StateEntriesForTuples(stateBlockNIDList.StateBlockNIDs, stateKeyTuples)
|
stateEntryLists, err := db.StateEntriesForTuples(
|
||||||
|
ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -341,23 +367,29 @@ func loadStateAtSnapshotForNumericTuples(
|
||||||
// This is typically the state before an event.
|
// This is typically the state before an event.
|
||||||
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||||
func LoadStateAfterEventsForStringTuples(
|
func LoadStateAfterEventsForStringTuples(
|
||||||
db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
|
ctx context.Context,
|
||||||
|
db RoomStateDatabase,
|
||||||
|
prevStates []types.StateAtEvent,
|
||||||
|
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples)
|
numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return loadStateAfterEventsForNumericTuples(db, prevStates, numericTuples)
|
return loadStateAfterEventsForNumericTuples(ctx, db, prevStates, numericTuples)
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadStateAfterEventsForNumericTuples(
|
func loadStateAfterEventsForNumericTuples(
|
||||||
db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple,
|
ctx context.Context,
|
||||||
|
db RoomStateDatabase,
|
||||||
|
prevStates []types.StateAtEvent,
|
||||||
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
if len(prevStates) == 1 {
|
if len(prevStates) == 1 {
|
||||||
// Fast path for a single event.
|
// Fast path for a single event.
|
||||||
prevState := prevStates[0]
|
prevState := prevStates[0]
|
||||||
result, err := loadStateAtSnapshotForNumericTuples(
|
result, err := loadStateAtSnapshotForNumericTuples(
|
||||||
db, prevState.BeforeStateSnapshotNID, stateKeyTuples,
|
ctx, db, prevState.BeforeStateSnapshotNID, stateKeyTuples,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -390,7 +422,7 @@ func loadStateAfterEventsForNumericTuples(
|
||||||
|
|
||||||
// TODO: Add metrics for this as it could take a long time for big rooms
|
// TODO: Add metrics for this as it could take a long time for big rooms
|
||||||
// with large conflicts.
|
// with large conflicts.
|
||||||
fullState, _, _, err := calculateStateAfterManyEvents(db, prevStates)
|
fullState, _, _, err := calculateStateAfterManyEvents(ctx, db, prevStates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -403,7 +435,10 @@ func loadStateAfterEventsForNumericTuples(
|
||||||
for _, tuple := range stateKeyTuples {
|
for _, tuple := range stateKeyTuples {
|
||||||
eventNID, ok := stateEntryMap(fullState).lookup(tuple)
|
eventNID, ok := stateEntryMap(fullState).lookup(tuple)
|
||||||
if ok {
|
if ok {
|
||||||
result = append(result, types.StateEntry{tuple, eventNID})
|
result = append(result, types.StateEntry{
|
||||||
|
StateKeyTuple: tuple,
|
||||||
|
EventNID: eventNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sort.Sort(stateEntrySorter(result))
|
sort.Sort(stateEntrySorter(result))
|
||||||
|
@ -509,7 +544,10 @@ func init() {
|
||||||
// Stores the snapshot of the state in the database.
|
// Stores the snapshot of the state in the database.
|
||||||
// Returns a numeric ID for the snapshot of the state before the event.
|
// Returns a numeric ID for the snapshot of the state before the event.
|
||||||
func CalculateAndStoreStateBeforeEvent(
|
func CalculateAndStoreStateBeforeEvent(
|
||||||
db RoomStateDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID,
|
ctx context.Context,
|
||||||
|
db RoomStateDatabase,
|
||||||
|
event gomatrixserverlib.Event,
|
||||||
|
roomNID types.RoomNID,
|
||||||
) (types.StateSnapshotNID, error) {
|
) (types.StateSnapshotNID, error) {
|
||||||
// Load the state at the prev events.
|
// Load the state at the prev events.
|
||||||
prevEventRefs := event.PrevEvents()
|
prevEventRefs := event.PrevEvents()
|
||||||
|
@ -518,25 +556,30 @@ func CalculateAndStoreStateBeforeEvent(
|
||||||
prevEventIDs[i] = prevEventRefs[i].EventID
|
prevEventIDs[i] = prevEventRefs[i].EventID
|
||||||
}
|
}
|
||||||
|
|
||||||
prevStates, err := db.StateAtEventIDs(prevEventIDs)
|
prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// The state before this event will be the state after the events that came before it.
|
// The state before this event will be the state after the events that came before it.
|
||||||
return CalculateAndStoreStateAfterEvents(db, roomNID, prevStates)
|
return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CalculateAndStoreStateAfterEvents finds the room state after the given events.
|
// CalculateAndStoreStateAfterEvents finds the room state after the given events.
|
||||||
// Stores the resulting state in the database and returns a numeric ID for that snapshot.
|
// Stores the resulting state in the database and returns a numeric ID for that snapshot.
|
||||||
func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) {
|
func CalculateAndStoreStateAfterEvents(
|
||||||
|
ctx context.Context,
|
||||||
|
db RoomStateDatabase,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
prevStates []types.StateAtEvent,
|
||||||
|
) (types.StateSnapshotNID, error) {
|
||||||
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
|
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
|
||||||
|
|
||||||
if len(prevStates) == 0 {
|
if len(prevStates) == 0 {
|
||||||
// 2) There weren't any prev_events for this event so the state is
|
// 2) There weren't any prev_events for this event so the state is
|
||||||
// empty.
|
// empty.
|
||||||
metrics.algorithm = "empty_state"
|
metrics.algorithm = "empty_state"
|
||||||
return metrics.stop(db.AddState(roomNID, nil, nil))
|
return metrics.stop(db.AddState(ctx, roomNID, nil, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(prevStates) == 1 {
|
if len(prevStates) == 1 {
|
||||||
|
@ -551,7 +594,9 @@ func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomN
|
||||||
}
|
}
|
||||||
// The previous event was a state event so we need to store a copy
|
// The previous event was a state event so we need to store a copy
|
||||||
// of the previous state updated with that event.
|
// of the previous state updated with that event.
|
||||||
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{prevState.BeforeStateSnapshotNID})
|
stateBlockNIDLists, err := db.StateBlockNIDs(
|
||||||
|
ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID},
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
metrics.algorithm = "_load_state_blocks"
|
metrics.algorithm = "_load_state_blocks"
|
||||||
return metrics.stop(0, err)
|
return metrics.stop(0, err)
|
||||||
|
@ -562,14 +607,14 @@ func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomN
|
||||||
// add the state event as a block of size one to the end of the blocks.
|
// add the state event as a block of size one to the end of the blocks.
|
||||||
metrics.algorithm = "single_delta"
|
metrics.algorithm = "single_delta"
|
||||||
return metrics.stop(db.AddState(
|
return metrics.stop(db.AddState(
|
||||||
roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
|
ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
// If there are too many deltas then we need to calculate the full state
|
// If there are too many deltas then we need to calculate the full state
|
||||||
// So fall through to calculateAndStoreStateAfterManyEvents
|
// So fall through to calculateAndStoreStateAfterManyEvents
|
||||||
}
|
}
|
||||||
|
|
||||||
return calculateAndStoreStateAfterManyEvents(db, roomNID, prevStates, metrics)
|
return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
|
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
|
||||||
|
@ -583,10 +628,15 @@ const maxStateBlockNIDs = 64
|
||||||
// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event.
|
// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event.
|
||||||
// Stores the resulting state and returns a numeric ID for the snapshot.
|
// Stores the resulting state and returns a numeric ID for the snapshot.
|
||||||
func calculateAndStoreStateAfterManyEvents(
|
func calculateAndStoreStateAfterManyEvents(
|
||||||
db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent, metrics calculateStateMetrics,
|
ctx context.Context,
|
||||||
|
db RoomStateDatabase,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
prevStates []types.StateAtEvent,
|
||||||
|
metrics calculateStateMetrics,
|
||||||
) (types.StateSnapshotNID, error) {
|
) (types.StateSnapshotNID, error) {
|
||||||
|
|
||||||
state, algorithm, conflictLength, err := calculateStateAfterManyEvents(db, prevStates)
|
state, algorithm, conflictLength, err :=
|
||||||
|
calculateStateAfterManyEvents(ctx, db, prevStates)
|
||||||
metrics.algorithm = algorithm
|
metrics.algorithm = algorithm
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return metrics.stop(0, err)
|
return metrics.stop(0, err)
|
||||||
|
@ -596,16 +646,16 @@ func calculateAndStoreStateAfterManyEvents(
|
||||||
// previous state.
|
// previous state.
|
||||||
metrics.conflictLength = conflictLength
|
metrics.conflictLength = conflictLength
|
||||||
metrics.fullStateLength = len(state)
|
metrics.fullStateLength = len(state)
|
||||||
return metrics.stop(db.AddState(roomNID, nil, state))
|
return metrics.stop(db.AddState(ctx, roomNID, nil, state))
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculateStateAfterManyEvents(
|
func calculateStateAfterManyEvents(
|
||||||
db RoomStateDatabase, prevStates []types.StateAtEvent,
|
ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent,
|
||||||
) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
|
) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
|
||||||
var combined []types.StateEntry
|
var combined []types.StateEntry
|
||||||
// Conflict resolution.
|
// Conflict resolution.
|
||||||
// First stage: load the state after each of the prev events.
|
// First stage: load the state after each of the prev events.
|
||||||
combined, err = LoadCombinedStateAfterEvents(db, prevStates)
|
combined, err = LoadCombinedStateAfterEvents(ctx, db, prevStates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
algorithm = "_load_combined_state"
|
algorithm = "_load_combined_state"
|
||||||
return
|
return
|
||||||
|
@ -635,7 +685,7 @@ func calculateStateAfterManyEvents(
|
||||||
}
|
}
|
||||||
|
|
||||||
var resolved []types.StateEntry
|
var resolved []types.StateEntry
|
||||||
resolved, err = resolveConflicts(db, notConflicted, conflicts)
|
resolved, err = resolveConflicts(ctx, db, notConflicted, conflicts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
algorithm = "_resolve_conflicts"
|
algorithm = "_resolve_conflicts"
|
||||||
return
|
return
|
||||||
|
@ -657,10 +707,14 @@ func calculateStateAfterManyEvents(
|
||||||
// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts.
|
// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts.
|
||||||
// The returned list is sorted by state key tuple.
|
// The returned list is sorted by state key tuple.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.StateEntry) ([]types.StateEntry, error) {
|
func resolveConflicts(
|
||||||
|
ctx context.Context,
|
||||||
|
db RoomStateDatabase,
|
||||||
|
notConflicted, conflicted []types.StateEntry,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
|
||||||
// Load the conflicted events
|
// Load the conflicted events
|
||||||
conflictedEvents, eventIDMap, err := loadStateEvents(db, conflicted)
|
conflictedEvents, eventIDMap, err := loadStateEvents(ctx, db, conflicted)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -672,7 +726,7 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
|
||||||
var neededStateKeys []string
|
var neededStateKeys []string
|
||||||
neededStateKeys = append(neededStateKeys, needed.Member...)
|
neededStateKeys = append(neededStateKeys, needed.Member...)
|
||||||
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
|
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
|
||||||
stateKeyNIDMap, err := db.EventStateKeyNIDs(neededStateKeys)
|
stateKeyNIDMap, err := db.EventStateKeyNIDs(ctx, neededStateKeys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -682,10 +736,13 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
|
||||||
var authEntries []types.StateEntry
|
var authEntries []types.StateEntry
|
||||||
for _, tuple := range tuplesNeeded {
|
for _, tuple := range tuplesNeeded {
|
||||||
if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok {
|
if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok {
|
||||||
authEntries = append(authEntries, types.StateEntry{tuple, eventNID})
|
authEntries = append(authEntries, types.StateEntry{
|
||||||
|
StateKeyTuple: tuple,
|
||||||
|
EventNID: eventNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
authEvents, _, err := loadStateEvents(db, authEntries)
|
authEvents, _, err := loadStateEvents(ctx, db, authEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -711,24 +768,39 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
|
||||||
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
|
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
|
||||||
var keyTuples []types.StateKeyTuple
|
var keyTuples []types.StateKeyTuple
|
||||||
if stateNeeded.Create {
|
if stateNeeded.Create {
|
||||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
|
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||||
|
EventTypeNID: types.MRoomCreateNID,
|
||||||
|
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if stateNeeded.PowerLevels {
|
if stateNeeded.PowerLevels {
|
||||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID})
|
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||||
|
EventTypeNID: types.MRoomPowerLevelsNID,
|
||||||
|
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if stateNeeded.JoinRules {
|
if stateNeeded.JoinRules {
|
||||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID})
|
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||||
|
EventTypeNID: types.MRoomJoinRulesNID,
|
||||||
|
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
for _, member := range stateNeeded.Member {
|
for _, member := range stateNeeded.Member {
|
||||||
stateKeyNID, ok := stateKeyNIDMap[member]
|
stateKeyNID, ok := stateKeyNIDMap[member]
|
||||||
if ok {
|
if ok {
|
||||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID})
|
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||||
|
EventTypeNID: types.MRoomMemberNID,
|
||||||
|
EventStateKeyNID: stateKeyNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, token := range stateNeeded.ThirdPartyInvite {
|
for _, token := range stateNeeded.ThirdPartyInvite {
|
||||||
stateKeyNID, ok := stateKeyNIDMap[token]
|
stateKeyNID, ok := stateKeyNIDMap[token]
|
||||||
if ok {
|
if ok {
|
||||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID})
|
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||||
|
EventTypeNID: types.MRoomThirdPartyInviteNID,
|
||||||
|
EventStateKeyNID: stateKeyNID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return keyTuples
|
return keyTuples
|
||||||
|
@ -738,12 +810,14 @@ func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stat
|
||||||
// Returns a list of state events in no particular order and a map from string event ID back to state entry.
|
// Returns a list of state events in no particular order and a map from string event ID back to state entry.
|
||||||
// The map can be used to recover which numeric state entry a given event is for.
|
// The map can be used to recover which numeric state entry a given event is for.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
func loadStateEvents(db RoomStateDatabase, entries []types.StateEntry) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) {
|
func loadStateEvents(
|
||||||
|
ctx context.Context, db RoomStateDatabase, entries []types.StateEntry,
|
||||||
|
) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) {
|
||||||
eventNIDs := make([]types.EventNID, len(entries))
|
eventNIDs := make([]types.EventNID, len(entries))
|
||||||
for i := range entries {
|
for i := range entries {
|
||||||
eventNIDs[i] = entries[i].EventNID
|
eventNIDs[i] = entries[i].EventNID
|
||||||
}
|
}
|
||||||
events, err := db.Events(eventNIDs)
|
events, err := db.Events(ctx, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
@ -65,8 +66,10 @@ func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventJSONStatements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error {
|
func (s *eventJSONStatements) insertEventJSON(
|
||||||
_, err := s.insertEventJSONStmt.Exec(int64(eventNID), eventJSON)
|
ctx context.Context, eventNID types.EventNID, eventJSON []byte,
|
||||||
|
) error {
|
||||||
|
_, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,8 +78,10 @@ type eventJSONPair struct {
|
||||||
EventJSON []byte
|
EventJSON []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventJSONStatements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) {
|
func (s *eventJSONStatements) bulkSelectEventJSON(
|
||||||
rows, err := s.bulkSelectEventJSONStmt.Query(eventNIDsAsArray(eventNIDs))
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) ([]eventJSONPair, error) {
|
||||||
|
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
@ -91,20 +92,30 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
|
func (s *eventStateKeyStatements) insertEventStateKeyNID(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||||
|
) (types.EventStateKeyNID, error) {
|
||||||
var eventStateKeyNID int64
|
var eventStateKeyNID int64
|
||||||
err := common.TxStmt(txn, s.insertEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
stmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
|
||||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
|
func (s *eventStateKeyStatements) selectEventStateKeyNID(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||||
|
) (types.EventStateKeyNID, error) {
|
||||||
var eventStateKeyNID int64
|
var eventStateKeyNID int64
|
||||||
err := common.TxStmt(txn, s.selectEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
|
||||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
|
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
||||||
rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys))
|
ctx context.Context, eventStateKeys []string,
|
||||||
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
|
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
|
||||||
|
ctx, pq.StringArray(eventStateKeys),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -122,18 +133,23 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) {
|
func (s *eventStateKeyStatements) selectEventStateKey(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID,
|
||||||
|
) (string, error) {
|
||||||
var eventStateKey string
|
var eventStateKey string
|
||||||
err := common.TxStmt(txn, s.selectEventStateKeyStmt).QueryRow(eventStateKeyNID).Scan(&eventStateKey)
|
stmt := common.TxStmt(txn, s.selectEventStateKeyStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, eventStateKeyNID).Scan(&eventStateKey)
|
||||||
return eventStateKey, err
|
return eventStateKey, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStateKeyStatements) bulkSelectEventStateKey(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) {
|
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
||||||
|
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||||
|
) (map[types.EventStateKeyNID]string, error) {
|
||||||
var nIDs pq.Int64Array
|
var nIDs pq.Int64Array
|
||||||
for i := range eventStateKeyNIDs {
|
for i := range eventStateKeyNIDs {
|
||||||
nIDs[i] = int64(eventStateKeyNIDs[i])
|
nIDs[i] = int64(eventStateKeyNIDs[i])
|
||||||
}
|
}
|
||||||
rows, err := s.bulkSelectEventStateKeyStmt.Query(nIDs)
|
rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
@ -107,20 +108,26 @@ func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventTypeStatements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
func (s *eventTypeStatements) insertEventTypeNID(
|
||||||
|
ctx context.Context, eventType string,
|
||||||
|
) (types.EventTypeNID, error) {
|
||||||
var eventTypeNID int64
|
var eventTypeNID int64
|
||||||
err := s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
|
err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
|
||||||
return types.EventTypeNID(eventTypeNID), err
|
return types.EventTypeNID(eventTypeNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventTypeStatements) selectEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
func (s *eventTypeStatements) selectEventTypeNID(
|
||||||
|
ctx context.Context, eventType string,
|
||||||
|
) (types.EventTypeNID, error) {
|
||||||
var eventTypeNID int64
|
var eventTypeNID int64
|
||||||
err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
|
err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
|
||||||
return types.EventTypeNID(eventTypeNID), err
|
return types.EventTypeNID(eventTypeNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventTypeStatements) bulkSelectEventTypeNID(eventTypes []string) (map[string]types.EventTypeNID, error) {
|
func (s *eventTypeStatements) bulkSelectEventTypeNID(
|
||||||
rows, err := s.bulkSelectEventTypeNIDStmt.Query(pq.StringArray(eventTypes))
|
ctx context.Context, eventTypes []string,
|
||||||
|
) (map[string]types.EventTypeNID, error) {
|
||||||
|
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
@ -154,7 +155,10 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) insertEvent(
|
func (s *eventStatements) insertEvent(
|
||||||
roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
|
ctx context.Context,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
eventTypeNID types.EventTypeNID,
|
||||||
|
eventStateKeyNID types.EventStateKeyNID,
|
||||||
eventID string,
|
eventID string,
|
||||||
referenceSHA256 []byte,
|
referenceSHA256 []byte,
|
||||||
authEventNIDs []types.EventNID,
|
authEventNIDs []types.EventNID,
|
||||||
|
@ -162,24 +166,28 @@ func (s *eventStatements) insertEvent(
|
||||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||||
var eventNID int64
|
var eventNID int64
|
||||||
var stateNID int64
|
var stateNID int64
|
||||||
err := s.insertEventStmt.QueryRow(
|
err := s.insertEventStmt.QueryRowContext(
|
||||||
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256,
|
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||||
eventNIDsAsArray(authEventNIDs), depth,
|
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||||
).Scan(&eventNID, &stateNID)
|
).Scan(&eventNID, &stateNID)
|
||||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) selectEvent(eventID string) (types.EventNID, types.StateSnapshotNID, error) {
|
func (s *eventStatements) selectEvent(
|
||||||
|
ctx context.Context, eventID string,
|
||||||
|
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||||
var eventNID int64
|
var eventNID int64
|
||||||
var stateNID int64
|
var stateNID int64
|
||||||
err := s.selectEventStmt.QueryRow(eventID).Scan(&eventNID, &stateNID)
|
err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID)
|
||||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) {
|
func (s *eventStatements) bulkSelectStateEventByID(
|
||||||
rows, err := s.bulkSelectStateEventByIDStmt.Query(pq.StringArray(eventIDs))
|
ctx context.Context, eventIDs []string,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -216,8 +224,10 @@ func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.S
|
||||||
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||||
func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types.StateAtEvent, error) {
|
func (s *eventStatements) bulkSelectStateAtEventByID(
|
||||||
rows, err := s.bulkSelectStateAtEventByIDStmt.Query(pq.StringArray(eventIDs))
|
ctx context.Context, eventIDs []string,
|
||||||
|
) ([]types.StateAtEvent, error) {
|
||||||
|
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -248,28 +258,40 @@ func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types
|
||||||
return results, err
|
return results, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error {
|
func (s *eventStatements) updateEventState(
|
||||||
_, err := s.updateEventStateStmt.Exec(int64(eventNID), int64(stateNID))
|
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||||
|
) error {
|
||||||
|
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) {
|
func (s *eventStatements) selectEventSentToOutput(
|
||||||
err = common.TxStmt(txn, s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput)
|
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
|
||||||
|
) (sentToOutput bool, err error) {
|
||||||
|
stmt := common.TxStmt(txn, s.selectEventSentToOutputStmt)
|
||||||
|
stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error {
|
func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
|
||||||
_, err := common.TxStmt(txn, s.updateEventSentToOutputStmt).Exec(int64(eventNID))
|
stmt := common.TxStmt(txn, s.updateEventSentToOutputStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, int64(eventNID))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) {
|
func (s *eventStatements) selectEventID(
|
||||||
err = common.TxStmt(txn, s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID)
|
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
|
||||||
|
) (eventID string, err error) {
|
||||||
|
stmt := common.TxStmt(txn, s.selectEventIDStmt)
|
||||||
|
err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
|
func (s *eventStatements) bulkSelectStateAtEventAndReference(
|
||||||
rows, err := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs))
|
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||||
|
) ([]types.StateAtEventAndReference, error) {
|
||||||
|
stmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -304,8 +326,10 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventN
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) {
|
func (s *eventStatements) bulkSelectEventReference(
|
||||||
rows, err := s.bulkSelectEventReferenceStmt.Query(eventNIDsAsArray(eventNIDs))
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) ([]gomatrixserverlib.EventReference, error) {
|
||||||
|
rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -325,8 +349,8 @@ func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) (
|
||||||
}
|
}
|
||||||
|
|
||||||
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||||
func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||||
rows, err := s.bulkSelectEventIDStmt.Query(eventNIDsAsArray(eventNIDs))
|
rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -349,8 +373,8 @@ func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[typ
|
||||||
|
|
||||||
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
||||||
// If an event ID is not in the database then it is omitted from the map.
|
// If an event ID is not in the database then it is omitted from the map.
|
||||||
func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]types.EventNID, error) {
|
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
|
||||||
rows, err := s.bulkSelectEventNIDStmt.Query(pq.StringArray(eventIDs))
|
rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -367,9 +391,10 @@ func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]type
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *eventStatements) selectMaxEventDepth(eventNIDs []types.EventNID) (int64, error) {
|
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {
|
||||||
var result int64
|
var result int64
|
||||||
err := s.selectMaxEventDepthStmt.QueryRow(eventNIDsAsArray(eventNIDs)).Scan(&result)
|
stmt := s.selectMaxEventDepthStmt
|
||||||
|
err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
@ -91,12 +92,13 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteStatements) insertInviteEvent(
|
func (s *inviteStatements) insertInviteEvent(
|
||||||
|
ctx context.Context,
|
||||||
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
|
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
|
||||||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||||
inviteEventJSON []byte,
|
inviteEventJSON []byte,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
result, err := common.TxStmt(txn, s.insertInviteEventStmt).Exec(
|
result, err := common.TxStmt(txn, s.insertInviteEventStmt).ExecContext(
|
||||||
inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
@ -109,9 +111,11 @@ func (s *inviteStatements) insertInviteEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inviteStatements) updateInviteRetired(
|
func (s *inviteStatements) updateInviteRetired(
|
||||||
|
ctx context.Context,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
rows, err := common.TxStmt(txn, s.updateInviteRetiredStmt).Query(roomNID, targetUserNID)
|
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -129,10 +133,11 @@ func (s *inviteStatements) updateInviteRetired(
|
||||||
|
|
||||||
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
||||||
func (s *inviteStatements) selectInviteActiveForUserInRoom(
|
func (s *inviteStatements) selectInviteActiveForUserInRoom(
|
||||||
|
ctx context.Context,
|
||||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||||
) ([]types.EventStateKeyNID, error) {
|
) ([]types.EventStateKeyNID, error) {
|
||||||
rows, err := s.selectInviteActiveForUserInRoomStmt.Query(
|
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
||||||
targetUserNID, roomNID,
|
ctx, targetUserNID, roomNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
@ -114,34 +115,38 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) insertMembership(
|
func (s *membershipStatements) insertMembership(
|
||||||
|
ctx context.Context,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) error {
|
) error {
|
||||||
_, err := common.TxStmt(txn, s.insertMembershipStmt).Exec(roomNID, targetUserNID)
|
stmt := common.TxStmt(txn, s.insertMembershipStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipForUpdate(
|
func (s *membershipStatements) selectMembershipForUpdate(
|
||||||
|
ctx context.Context,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (membership membershipState, err error) {
|
) (membership membershipState, err error) {
|
||||||
err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRow(
|
err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
|
||||||
roomNID, targetUserNID,
|
ctx, roomNID, targetUserNID,
|
||||||
).Scan(&membership)
|
).Scan(&membership)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipFromRoomAndTarget(
|
func (s *membershipStatements) selectMembershipFromRoomAndTarget(
|
||||||
|
ctx context.Context,
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
) (eventNID types.EventNID, membership membershipState, err error) {
|
) (eventNID types.EventNID, membership membershipState, err error) {
|
||||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRow(
|
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
||||||
roomNID, targetUserNID,
|
ctx, roomNID, targetUserNID,
|
||||||
).Scan(&membership, &eventNID)
|
).Scan(&membership, &eventNID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) selectMembershipsFromRoom(
|
func (s *membershipStatements) selectMembershipsFromRoom(
|
||||||
roomNID types.RoomNID,
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
rows, err := s.selectMembershipsFromRoomStmt.Query(roomNID)
|
rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -156,9 +161,11 @@ func (s *membershipStatements) selectMembershipsFromRoom(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
||||||
|
ctx context.Context,
|
||||||
roomNID types.RoomNID, membership membershipState,
|
roomNID types.RoomNID, membership membershipState,
|
||||||
) (eventNIDs []types.EventNID, err error) {
|
) (eventNIDs []types.EventNID, err error) {
|
||||||
rows, err := s.selectMembershipsFromRoomAndMembershipStmt.Query(roomNID, membership)
|
stmt := s.selectMembershipsFromRoomAndMembershipStmt
|
||||||
|
rows, err := stmt.QueryContext(ctx, roomNID, membership)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -174,12 +181,13 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) updateMembership(
|
func (s *membershipStatements) updateMembership(
|
||||||
|
ctx context.Context,
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
senderUserNID types.EventStateKeyNID, membership membershipState,
|
senderUserNID types.EventStateKeyNID, membership membershipState,
|
||||||
eventNID types.EventNID,
|
eventNID types.EventNID,
|
||||||
) error {
|
) error {
|
||||||
_, err := common.TxStmt(txn, s.updateMembershipStmt).Exec(
|
_, err := common.TxStmt(txn, s.updateMembershipStmt).ExecContext(
|
||||||
roomNID, targetUserNID, senderUserNID, membership, eventNID,
|
ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
@ -73,14 +74,26 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error {
|
func (s *previousEventStatements) insertPreviousEvent(
|
||||||
_, err := common.TxStmt(txn, s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID))
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
previousEventID string,
|
||||||
|
previousEventReferenceSHA256 []byte,
|
||||||
|
eventNID types.EventNID,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.insertPreviousEventStmt)
|
||||||
|
_, err := stmt.ExecContext(
|
||||||
|
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
|
||||||
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the event reference exists
|
// Check if the event reference exists
|
||||||
// Returns sql.ErrNoRows if the event reference doesn't exist.
|
// Returns sql.ErrNoRows if the event reference doesn't exist.
|
||||||
func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error {
|
func (s *previousEventStatements) selectPreviousEventExists(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
|
||||||
|
) error {
|
||||||
var ok int64
|
var ok int64
|
||||||
return common.TxStmt(txn, s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok)
|
stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt)
|
||||||
|
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -62,22 +63,28 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) insertRoomAlias(alias string, roomID string) (err error) {
|
func (s *roomAliasesStatements) insertRoomAlias(
|
||||||
_, err = s.insertRoomAliasStmt.Exec(alias, roomID)
|
ctx context.Context, alias string, roomID string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) selectRoomIDFromAlias(alias string) (roomID string, err error) {
|
func (s *roomAliasesStatements) selectRoomIDFromAlias(
|
||||||
err = s.selectRoomIDFromAliasStmt.QueryRow(alias).Scan(&roomID)
|
ctx context.Context, alias string,
|
||||||
|
) (roomID string, err error) {
|
||||||
|
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) selectAliasesFromRoomID(roomID string) (aliases []string, err error) {
|
func (s *roomAliasesStatements) selectAliasesFromRoomID(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) (aliases []string, err error) {
|
||||||
aliases = []string{}
|
aliases = []string{}
|
||||||
rows, err := s.selectAliasesFromRoomIDStmt.Query(roomID)
|
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -94,7 +101,9 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID(roomID string) (aliases
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomAliasesStatements) deleteRoomAlias(alias string) (err error) {
|
func (s *roomAliasesStatements) deleteRoomAlias(
|
||||||
_, err = s.deleteRoomAliasStmt.Exec(alias)
|
ctx context.Context, alias string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
@ -81,22 +82,31 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
|
func (s *roomStatements) insertRoomNID(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
|
) (types.RoomNID, error) {
|
||||||
var roomNID int64
|
var roomNID int64
|
||||||
err := common.TxStmt(txn, s.insertRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
|
stmt := common.TxStmt(txn, s.insertRoomNIDStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
|
||||||
return types.RoomNID(roomNID), err
|
return types.RoomNID(roomNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
|
func (s *roomStatements) selectRoomNID(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
|
) (types.RoomNID, error) {
|
||||||
var roomNID int64
|
var roomNID int64
|
||||||
err := common.TxStmt(txn, s.selectRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
|
stmt := common.TxStmt(txn, s.selectRoomNIDStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
|
||||||
return types.RoomNID(roomNID), err
|
return types.RoomNID(roomNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) {
|
func (s *roomStatements) selectLatestEventNIDs(
|
||||||
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
|
) ([]types.EventNID, types.StateSnapshotNID, error) {
|
||||||
var nids pq.Int64Array
|
var nids pq.Int64Array
|
||||||
var stateSnapshotNID int64
|
var stateSnapshotNID int64
|
||||||
err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids, &stateSnapshotNID)
|
stmt := s.selectLatestEventNIDsStmt
|
||||||
|
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
@ -107,13 +117,14 @@ func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.E
|
||||||
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
|
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) (
|
func (s *roomStatements) selectLatestEventsNIDsForUpdate(
|
||||||
[]types.EventNID, types.EventNID, types.StateSnapshotNID, error,
|
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
|
||||||
) {
|
) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
|
||||||
var nids pq.Int64Array
|
var nids pq.Int64Array
|
||||||
var lastEventSentNID int64
|
var lastEventSentNID int64
|
||||||
var stateSnapshotNID int64
|
var stateSnapshotNID int64
|
||||||
err := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
|
stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, 0, err
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
|
@ -125,11 +136,20 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) updateLatestEventNIDs(
|
func (s *roomStatements) updateLatestEventNIDs(
|
||||||
txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID,
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
eventNIDs []types.EventNID,
|
||||||
|
lastEventSentNID types.EventNID,
|
||||||
stateSnapshotNID types.StateSnapshotNID,
|
stateSnapshotNID types.StateSnapshotNID,
|
||||||
) error {
|
) error {
|
||||||
_, err := common.TxStmt(txn, s.updateLatestEventNIDsStmt).Exec(
|
stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt)
|
||||||
roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID),
|
_, err := stmt.ExecContext(
|
||||||
|
ctx,
|
||||||
|
roomNID,
|
||||||
|
eventNIDsAsArray(eventNIDs),
|
||||||
|
int64(lastEventSentNID),
|
||||||
|
int64(stateSnapshotNID),
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
|
@ -97,9 +98,14 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBlockNID, entries []types.StateEntry) error {
|
func (s *stateBlockStatements) bulkInsertStateData(
|
||||||
|
ctx context.Context,
|
||||||
|
stateBlockNID types.StateBlockNID,
|
||||||
|
entries []types.StateEntry,
|
||||||
|
) error {
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
_, err := s.insertStateDataStmt.Exec(
|
_, err := s.insertStateDataStmt.ExecContext(
|
||||||
|
ctx,
|
||||||
int64(stateBlockNID),
|
int64(stateBlockNID),
|
||||||
int64(entry.EventTypeNID),
|
int64(entry.EventTypeNID),
|
||||||
int64(entry.EventStateKeyNID),
|
int64(entry.EventStateKeyNID),
|
||||||
|
@ -112,18 +118,22 @@ func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBloc
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) selectNextStateBlockNID() (types.StateBlockNID, error) {
|
func (s *stateBlockStatements) selectNextStateBlockNID(
|
||||||
|
ctx context.Context,
|
||||||
|
) (types.StateBlockNID, error) {
|
||||||
var stateBlockNID int64
|
var stateBlockNID int64
|
||||||
err := s.selectNextStateBlockNIDStmt.QueryRow().Scan(&stateBlockNID)
|
err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||||
return types.StateBlockNID(stateBlockNID), err
|
return types.StateBlockNID(stateBlockNID), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
|
func (s *stateBlockStatements) bulkSelectStateBlockEntries(
|
||||||
|
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
||||||
|
) ([]types.StateEntryList, error) {
|
||||||
nids := make([]int64, len(stateBlockNIDs))
|
nids := make([]int64, len(stateBlockNIDs))
|
||||||
for i := range stateBlockNIDs {
|
for i := range stateBlockNIDs {
|
||||||
nids[i] = int64(stateBlockNIDs[i])
|
nids[i] = int64(stateBlockNIDs[i])
|
||||||
}
|
}
|
||||||
rows, err := s.bulkSelectStateBlockEntriesStmt.Query(pq.Int64Array(nids))
|
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -165,15 +175,20 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []type
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
|
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
|
||||||
stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple,
|
ctx context.Context,
|
||||||
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntryList, error) {
|
) ([]types.StateEntryList, error) {
|
||||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||||
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
|
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
|
||||||
sort.Sort(tuples)
|
sort.Sort(tuples)
|
||||||
|
|
||||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||||
rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.Query(
|
rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext(
|
||||||
stateBlockNIDsAsArray(stateBlockNIDs), eventTypeNIDArray, eventStateKeyNIDArray,
|
ctx,
|
||||||
|
stateBlockNIDsAsArray(stateBlockNIDs),
|
||||||
|
eventTypeNIDArray,
|
||||||
|
eventStateKeyNIDArray,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -15,29 +15,30 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStateKeyTupleSorter(t *testing.T) {
|
func TestStateKeyTupleSorter(t *testing.T) {
|
||||||
input := stateKeyTupleSorter{
|
input := stateKeyTupleSorter{
|
||||||
{1, 2},
|
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||||
{1, 4},
|
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||||
{2, 2},
|
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||||
{1, 1},
|
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||||
}
|
}
|
||||||
want := []types.StateKeyTuple{
|
want := []types.StateKeyTuple{
|
||||||
{1, 1},
|
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||||
{1, 2},
|
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||||
{1, 4},
|
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||||
{2, 2},
|
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||||
}
|
}
|
||||||
doNotWant := []types.StateKeyTuple{
|
doNotWant := []types.StateKeyTuple{
|
||||||
{0, 0},
|
{EventTypeNID: 0, EventStateKeyNID: 0},
|
||||||
{1, 3},
|
{EventTypeNID: 1, EventStateKeyNID: 3},
|
||||||
{2, 1},
|
{EventTypeNID: 2, EventStateKeyNID: 1},
|
||||||
{3, 1},
|
{EventTypeNID: 3, EventStateKeyNID: 1},
|
||||||
}
|
}
|
||||||
wantTypeNIDs := []int64{1, 2}
|
wantTypeNIDs := []int64{1, 2}
|
||||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
wantStateKeyNIDs := []int64{1, 2, 4}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
@ -74,21 +75,25 @@ func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
|
||||||
}.prepare(db)
|
}.prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) insertState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) {
|
func (s *stateSnapshotStatements) insertState(
|
||||||
|
ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
|
||||||
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
nids := make([]int64, len(stateBlockNIDs))
|
nids := make([]int64, len(stateBlockNIDs))
|
||||||
for i := range stateBlockNIDs {
|
for i := range stateBlockNIDs {
|
||||||
nids[i] = int64(stateBlockNIDs[i])
|
nids[i] = int64(stateBlockNIDs[i])
|
||||||
}
|
}
|
||||||
err = s.insertStateStmt.QueryRow(int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
|
err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) {
|
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
|
||||||
|
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||||
|
) ([]types.StateBlockNIDList, error) {
|
||||||
nids := make([]int64, len(stateNIDs))
|
nids := make([]int64, len(stateNIDs))
|
||||||
for i := range stateNIDs {
|
for i := range stateNIDs {
|
||||||
nids[i] = int64(stateNIDs[i])
|
nids[i] = int64(stateNIDs[i])
|
||||||
}
|
}
|
||||||
rows, err := s.bulkSelectStateBlockNIDsStmt.Query(pq.Int64Array(nids))
|
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
// Import the postgres database driver.
|
// Import the postgres database driver.
|
||||||
|
@ -43,7 +44,9 @@ func Open(dataSourceName string) (*Database, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// StoreEvent implements input.EventDatabase
|
// StoreEvent implements input.EventDatabase
|
||||||
func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) {
|
func (d *Database) StoreEvent(
|
||||||
|
ctx context.Context, event gomatrixserverlib.Event, authEventNIDs []types.EventNID,
|
||||||
|
) (types.RoomNID, types.StateAtEvent, error) {
|
||||||
var (
|
var (
|
||||||
roomNID types.RoomNID
|
roomNID types.RoomNID
|
||||||
eventTypeNID types.EventTypeNID
|
eventTypeNID types.EventTypeNID
|
||||||
|
@ -53,11 +56,11 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
if roomNID, err = d.assignRoomNID(nil, event.RoomID()); err != nil {
|
if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil {
|
||||||
return 0, types.StateAtEvent{}, err
|
return 0, types.StateAtEvent{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if eventTypeNID, err = d.assignEventTypeNID(event.Type()); err != nil {
|
if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil {
|
||||||
return 0, types.StateAtEvent{}, err
|
return 0, types.StateAtEvent{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,12 +68,13 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
|
||||||
// Assigned a numeric ID for the state_key if there is one present.
|
// Assigned a numeric ID for the state_key if there is one present.
|
||||||
// Otherwise set the numeric ID for the state_key to 0.
|
// Otherwise set the numeric ID for the state_key to 0.
|
||||||
if eventStateKey != nil {
|
if eventStateKey != nil {
|
||||||
if eventStateKeyNID, err = d.assignStateKeyNID(nil, *eventStateKey); err != nil {
|
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil {
|
||||||
return 0, types.StateAtEvent{}, err
|
return 0, types.StateAtEvent{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if eventNID, stateNID, err = d.statements.insertEvent(
|
if eventNID, stateNID, err = d.statements.insertEvent(
|
||||||
|
ctx,
|
||||||
roomNID,
|
roomNID,
|
||||||
eventTypeNID,
|
eventTypeNID,
|
||||||
eventStateKeyNID,
|
eventStateKeyNID,
|
||||||
|
@ -81,14 +85,14 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
|
||||||
); err != nil {
|
); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// We've already inserted the event so select the numeric event ID
|
// We've already inserted the event so select the numeric event ID
|
||||||
eventNID, stateNID, err = d.statements.selectEvent(event.EventID())
|
eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID())
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, types.StateAtEvent{}, err
|
return 0, types.StateAtEvent{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.statements.insertEventJSON(eventNID, event.JSON()); err != nil {
|
if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil {
|
||||||
return 0, types.StateAtEvent{}, err
|
return 0, types.StateAtEvent{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,76 +108,94 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) assignRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
|
func (d *Database) assignRoomNID(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
|
) (types.RoomNID, error) {
|
||||||
// Check if we already have a numeric ID in the database.
|
// Check if we already have a numeric ID in the database.
|
||||||
roomNID, err := d.statements.selectRoomNID(txn, roomID)
|
roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// We don't have a numeric ID so insert one into the database.
|
// We don't have a numeric ID so insert one into the database.
|
||||||
roomNID, err = d.statements.insertRoomNID(txn, roomID)
|
roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// We raced with another insert so run the select again.
|
// We raced with another insert so run the select again.
|
||||||
roomNID, err = d.statements.selectRoomNID(txn, roomID)
|
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return roomNID, err
|
return roomNID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
func (d *Database) assignEventTypeNID(
|
||||||
|
ctx context.Context, eventType string,
|
||||||
|
) (types.EventTypeNID, error) {
|
||||||
// Check if we already have a numeric ID in the database.
|
// Check if we already have a numeric ID in the database.
|
||||||
eventTypeNID, err := d.statements.selectEventTypeNID(eventType)
|
eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// We don't have a numeric ID so insert one into the database.
|
// We don't have a numeric ID so insert one into the database.
|
||||||
eventTypeNID, err = d.statements.insertEventTypeNID(eventType)
|
eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// We raced with another insert so run the select again.
|
// We raced with another insert so run the select again.
|
||||||
eventTypeNID, err = d.statements.selectEventTypeNID(eventType)
|
eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return eventTypeNID, err
|
return eventTypeNID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) assignStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
|
func (d *Database) assignStateKeyNID(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||||
|
) (types.EventStateKeyNID, error) {
|
||||||
// Check if we already have a numeric ID in the database.
|
// Check if we already have a numeric ID in the database.
|
||||||
eventStateKeyNID, err := d.statements.selectEventStateKeyNID(txn, eventStateKey)
|
eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// We don't have a numeric ID so insert one into the database.
|
// We don't have a numeric ID so insert one into the database.
|
||||||
eventStateKeyNID, err = d.statements.insertEventStateKeyNID(txn, eventStateKey)
|
eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// We raced with another insert so run the select again.
|
// We raced with another insert so run the select again.
|
||||||
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(txn, eventStateKey)
|
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return eventStateKeyNID, err
|
return eventStateKeyNID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateEntriesForEventIDs implements input.EventDatabase
|
// StateEntriesForEventIDs implements input.EventDatabase
|
||||||
func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) {
|
func (d *Database) StateEntriesForEventIDs(
|
||||||
return d.statements.bulkSelectStateEventByID(eventIDs)
|
ctx context.Context, eventIDs []string,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
return d.statements.bulkSelectStateEventByID(ctx, eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventTypeNIDs implements state.RoomStateDatabase
|
// EventTypeNIDs implements state.RoomStateDatabase
|
||||||
func (d *Database) EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error) {
|
func (d *Database) EventTypeNIDs(
|
||||||
return d.statements.bulkSelectEventTypeNID(eventTypes)
|
ctx context.Context, eventTypes []string,
|
||||||
|
) (map[string]types.EventTypeNID, error) {
|
||||||
|
return d.statements.bulkSelectEventTypeNID(ctx, eventTypes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventStateKeyNIDs implements state.RoomStateDatabase
|
// EventStateKeyNIDs implements state.RoomStateDatabase
|
||||||
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
|
func (d *Database) EventStateKeyNIDs(
|
||||||
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
|
ctx context.Context, eventStateKeys []string,
|
||||||
|
) (map[string]types.EventStateKeyNID, error) {
|
||||||
|
return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventStateKeys implements query.RoomserverQueryAPIDatabase
|
// EventStateKeys implements query.RoomserverQueryAPIDatabase
|
||||||
func (d *Database) EventStateKeys(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) {
|
func (d *Database) EventStateKeys(
|
||||||
return d.statements.bulkSelectEventStateKey(eventStateKeyNIDs)
|
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||||
|
) (map[types.EventStateKeyNID]string, error) {
|
||||||
|
return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventNIDs implements query.RoomserverQueryAPIDatabase
|
// EventNIDs implements query.RoomserverQueryAPIDatabase
|
||||||
func (d *Database) EventNIDs(eventIDs []string) (map[string]types.EventNID, error) {
|
func (d *Database) EventNIDs(
|
||||||
return d.statements.bulkSelectEventNID(eventIDs)
|
ctx context.Context, eventIDs []string,
|
||||||
|
) (map[string]types.EventNID, error) {
|
||||||
|
return d.statements.bulkSelectEventNID(ctx, eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Events implements input.EventDatabase
|
// Events implements input.EventDatabase
|
||||||
func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) {
|
func (d *Database) Events(
|
||||||
eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs)
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) ([]types.Event, error) {
|
||||||
|
eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -191,78 +213,98 @@ func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddState implements input.EventDatabase
|
// AddState implements input.EventDatabase
|
||||||
func (d *Database) AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) {
|
func (d *Database) AddState(
|
||||||
|
ctx context.Context,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
|
state []types.StateEntry,
|
||||||
|
) (types.StateSnapshotNID, error) {
|
||||||
if len(state) > 0 {
|
if len(state) > 0 {
|
||||||
stateBlockNID, err := d.statements.selectNextStateBlockNID()
|
stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if err = d.statements.bulkInsertStateData(stateBlockNID, state); err != nil {
|
if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
|
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.statements.insertState(roomNID, stateBlockNIDs)
|
return d.statements.insertState(ctx, roomNID, stateBlockNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetState implements input.EventDatabase
|
// SetState implements input.EventDatabase
|
||||||
func (d *Database) SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error {
|
func (d *Database) SetState(
|
||||||
return d.statements.updateEventState(eventNID, stateNID)
|
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||||
|
) error {
|
||||||
|
return d.statements.updateEventState(ctx, eventNID, stateNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateAtEventIDs implements input.EventDatabase
|
// StateAtEventIDs implements input.EventDatabase
|
||||||
func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) {
|
func (d *Database) StateAtEventIDs(
|
||||||
return d.statements.bulkSelectStateAtEventByID(eventIDs)
|
ctx context.Context, eventIDs []string,
|
||||||
|
) ([]types.StateAtEvent, error) {
|
||||||
|
return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateBlockNIDs implements state.RoomStateDatabase
|
// StateBlockNIDs implements state.RoomStateDatabase
|
||||||
func (d *Database) StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) {
|
func (d *Database) StateBlockNIDs(
|
||||||
return d.statements.bulkSelectStateBlockNIDs(stateNIDs)
|
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||||
|
) ([]types.StateBlockNIDList, error) {
|
||||||
|
return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateEntries implements state.RoomStateDatabase
|
// StateEntries implements state.RoomStateDatabase
|
||||||
func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
|
func (d *Database) StateEntries(
|
||||||
return d.statements.bulkSelectStateBlockEntries(stateBlockNIDs)
|
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
||||||
|
) ([]types.StateEntryList, error) {
|
||||||
|
return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SnapshotNIDFromEventID implements state.RoomStateDatabase
|
// SnapshotNIDFromEventID implements state.RoomStateDatabase
|
||||||
func (d *Database) SnapshotNIDFromEventID(eventID string) (types.StateSnapshotNID, error) {
|
func (d *Database) SnapshotNIDFromEventID(
|
||||||
_, stateNID, err := d.statements.selectEvent(eventID)
|
ctx context.Context, eventID string,
|
||||||
|
) (types.StateSnapshotNID, error) {
|
||||||
|
_, stateNID, err := d.statements.selectEvent(ctx, eventID)
|
||||||
return stateNID, err
|
return stateNID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventIDs implements input.RoomEventDatabase
|
// EventIDs implements input.RoomEventDatabase
|
||||||
func (d *Database) EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
func (d *Database) EventIDs(
|
||||||
return d.statements.bulkSelectEventID(eventNIDs)
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) (map[types.EventNID]string, error) {
|
||||||
|
return d.statements.bulkSelectEventID(ctx, eventNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestEventsForUpdate implements input.EventDatabase
|
// GetLatestEventsForUpdate implements input.EventDatabase
|
||||||
func (d *Database) GetLatestEventsForUpdate(roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) {
|
func (d *Database) GetLatestEventsForUpdate(
|
||||||
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
|
) (types.RoomRecentEventsUpdater, error) {
|
||||||
txn, err := d.db.Begin()
|
txn, err := d.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := d.statements.selectLatestEventsNIDsForUpdate(txn, roomNID)
|
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
|
||||||
|
d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
txn.Rollback()
|
txn.Rollback()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(txn, eventNIDs)
|
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
txn.Rollback()
|
txn.Rollback()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var lastEventIDSent string
|
var lastEventIDSent string
|
||||||
if lastEventNIDSent != 0 {
|
if lastEventNIDSent != 0 {
|
||||||
lastEventIDSent, err = d.statements.selectEventID(txn, lastEventNIDSent)
|
lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
txn.Rollback()
|
txn.Rollback()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &roomRecentEventsUpdater{
|
return &roomRecentEventsUpdater{
|
||||||
transaction{txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -293,7 +335,7 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN
|
||||||
// StorePreviousEvents implements types.RoomRecentEventsUpdater
|
// StorePreviousEvents implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
||||||
for _, ref := range previousEventReferences {
|
for _, ref := range previousEventReferences {
|
||||||
if err := u.d.statements.insertPreviousEvent(u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -302,7 +344,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p
|
||||||
|
|
||||||
// IsReferenced implements types.RoomRecentEventsUpdater
|
// IsReferenced implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
||||||
err := u.d.statements.selectPreviousEventExists(u.txn, eventReference.EventID, eventReference.EventSHA256)
|
err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -321,26 +363,26 @@ func (u *roomRecentEventsUpdater) SetLatestEvents(
|
||||||
for i := range latest {
|
for i := range latest {
|
||||||
eventNIDs[i] = latest[i].EventNID
|
eventNIDs[i] = latest[i].EventNID
|
||||||
}
|
}
|
||||||
return u.d.statements.updateLatestEventNIDs(u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
|
return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
|
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
|
||||||
return u.d.statements.selectEventSentToOutput(u.txn, eventNID)
|
return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
||||||
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
||||||
return u.d.statements.updateEventSentToOutput(u.txn, eventNID)
|
return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) {
|
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) {
|
||||||
return u.d.membershipUpdaterTxn(u.txn, u.roomNID, targetUserNID)
|
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoomNID implements query.RoomserverQueryAPIDB
|
// RoomNID implements query.RoomserverQueryAPIDB
|
||||||
func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
|
func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) {
|
||||||
roomNID, err := d.statements.selectRoomNID(nil, roomID)
|
roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
@ -348,16 +390,18 @@ func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
|
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
|
||||||
func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) {
|
func (d *Database) LatestEventIDs(
|
||||||
eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(roomNID)
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
|
) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) {
|
||||||
|
eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, 0, err
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
references, err := d.statements.bulkSelectEventReference(eventNIDs)
|
references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, 0, err
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
depth, err := d.statements.selectMaxEventDepth(eventNIDs)
|
depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, 0, err
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
|
@ -366,40 +410,48 @@ func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.Ev
|
||||||
|
|
||||||
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
|
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
|
||||||
func (d *Database) GetInvitesForUser(
|
func (d *Database) GetInvitesForUser(
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
ctx context.Context,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
targetUserNID types.EventStateKeyNID,
|
||||||
) (senderUserIDs []types.EventStateKeyNID, err error) {
|
) (senderUserIDs []types.EventStateKeyNID, err error) {
|
||||||
return d.statements.selectInviteActiveForUserInRoom(targetUserNID, roomNID)
|
return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRoomAlias implements alias.RoomserverAliasAPIDB
|
// SetRoomAlias implements alias.RoomserverAliasAPIDB
|
||||||
func (d *Database) SetRoomAlias(alias string, roomID string) error {
|
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string) error {
|
||||||
return d.statements.insertRoomAlias(alias, roomID)
|
return d.statements.insertRoomAlias(ctx, alias, roomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoomIDFromAlias implements alias.RoomserverAliasAPIDB
|
// GetRoomIDFromAlias implements alias.RoomserverAliasAPIDB
|
||||||
func (d *Database) GetRoomIDFromAlias(alias string) (string, error) {
|
func (d *Database) GetRoomIDFromAlias(ctx context.Context, alias string) (string, error) {
|
||||||
return d.statements.selectRoomIDFromAlias(alias)
|
return d.statements.selectRoomIDFromAlias(ctx, alias)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAliasesFromRoomID implements alias.RoomserverAliasAPIDB
|
// GetAliasesFromRoomID implements alias.RoomserverAliasAPIDB
|
||||||
func (d *Database) GetAliasesFromRoomID(roomID string) ([]string, error) {
|
func (d *Database) GetAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) {
|
||||||
return d.statements.selectAliasesFromRoomID(roomID)
|
return d.statements.selectAliasesFromRoomID(ctx, roomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
|
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
|
||||||
func (d *Database) RemoveRoomAlias(alias string) error {
|
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
|
||||||
return d.statements.deleteRoomAlias(alias)
|
return d.statements.deleteRoomAlias(ctx, alias)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateEntriesForTuples implements state.RoomStateDatabase
|
// StateEntriesForTuples implements state.RoomStateDatabase
|
||||||
func (d *Database) StateEntriesForTuples(
|
func (d *Database) StateEntriesForTuples(
|
||||||
stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple,
|
ctx context.Context,
|
||||||
|
stateBlockNIDs []types.StateBlockNID,
|
||||||
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntryList, error) {
|
) ([]types.StateEntryList, error) {
|
||||||
return d.statements.bulkSelectFilteredStateBlockEntries(stateBlockNIDs, stateKeyTuples)
|
return d.statements.bulkSelectFilteredStateBlockEntries(
|
||||||
|
ctx, stateBlockNIDs, stateKeyTuples,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MembershipUpdater implements input.RoomEventDatabase
|
// MembershipUpdater implements input.RoomEventDatabase
|
||||||
func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.MembershipUpdater, error) {
|
func (d *Database) MembershipUpdater(
|
||||||
|
ctx context.Context, roomID, targetUserID string,
|
||||||
|
) (types.MembershipUpdater, error) {
|
||||||
txn, err := d.db.Begin()
|
txn, err := d.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -411,17 +463,17 @@ func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.Members
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
roomNID, err := d.assignRoomNID(txn, roomID)
|
roomNID, err := d.assignRoomNID(ctx, txn, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUserNID, err := d.assignStateKeyNID(txn, targetUserID)
|
targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
updater, err := d.membershipUpdaterTxn(txn, roomNID, targetUserNID)
|
updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -439,20 +491,23 @@ type membershipUpdater struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) membershipUpdaterTxn(
|
func (d *Database) membershipUpdaterTxn(
|
||||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
targetUserNID types.EventStateKeyNID,
|
||||||
) (types.MembershipUpdater, error) {
|
) (types.MembershipUpdater, error) {
|
||||||
|
|
||||||
if err := d.statements.insertMembership(txn, roomNID, targetUserNID); err != nil {
|
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
membership, err := d.statements.selectMembershipForUpdate(txn, roomNID, targetUserNID)
|
membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &membershipUpdater{
|
return &membershipUpdater{
|
||||||
transaction{txn}, d, roomNID, targetUserNID, membership,
|
transaction{ctx, txn}, d, roomNID, targetUserNID, membership,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -473,19 +528,19 @@ func (u *membershipUpdater) IsLeave() bool {
|
||||||
|
|
||||||
// SetToInvite implements types.MembershipUpdater
|
// SetToInvite implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
|
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
|
||||||
senderUserNID, err := u.d.assignStateKeyNID(u.txn, event.Sender())
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
inserted, err := u.d.statements.insertInviteEvent(
|
inserted, err := u.d.statements.insertInviteEvent(
|
||||||
u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
|
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
if u.membership != membershipStateInvite {
|
if u.membership != membershipStateInvite {
|
||||||
if err = u.d.statements.updateMembership(
|
if err = u.d.statements.updateMembership(
|
||||||
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -497,7 +552,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
|
||||||
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
|
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
|
||||||
var inviteEventIDs []string
|
var inviteEventIDs []string
|
||||||
|
|
||||||
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID)
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -505,7 +560,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||||
// If this is a join event update, there is no invite to update
|
// If this is a join event update, there is no invite to update
|
||||||
if !isUpdate {
|
if !isUpdate {
|
||||||
inviteEventIDs, err = u.d.statements.updateInviteRetired(
|
inviteEventIDs, err = u.d.statements.updateInviteRetired(
|
||||||
u.txn, u.roomNID, u.targetUserNID,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -513,14 +568,15 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up the NID of the new join event
|
// Look up the NID of the new join event
|
||||||
nIDs, err := u.d.EventNIDs([]string{eventID})
|
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.membership != membershipStateJoin || isUpdate {
|
if u.membership != membershipStateJoin || isUpdate {
|
||||||
if err = u.d.statements.updateMembership(
|
if err = u.d.statements.updateMembership(
|
||||||
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateJoin, nIDs[eventID],
|
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||||
|
membershipStateJoin, nIDs[eventID],
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -531,26 +587,27 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||||
|
|
||||||
// SetToLeave implements types.MembershipUpdater
|
// SetToLeave implements types.MembershipUpdater
|
||||||
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
|
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
|
||||||
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID)
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
inviteEventIDs, err := u.d.statements.updateInviteRetired(
|
inviteEventIDs, err := u.d.statements.updateInviteRetired(
|
||||||
u.txn, u.roomNID, u.targetUserNID,
|
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look up the NID of the new leave event
|
// Look up the NID of the new leave event
|
||||||
nIDs, err := u.d.EventNIDs([]string{eventID})
|
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.membership != membershipStateLeaveOrBan {
|
if u.membership != membershipStateLeaveOrBan {
|
||||||
if err = u.d.statements.updateMembership(
|
if err = u.d.statements.updateMembership(
|
||||||
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateLeaveOrBan, nIDs[eventID],
|
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||||
|
membershipStateLeaveOrBan, nIDs[eventID],
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -559,19 +616,18 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMembership implements query.RoomserverQueryAPIDB
|
// GetMembership implements query.RoomserverQueryAPIDB
|
||||||
func (d *Database) GetMembership(roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
|
func (d *Database) GetMembership(
|
||||||
txn, err := d.db.Begin()
|
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
|
||||||
if err != nil {
|
) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
|
||||||
return
|
requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID)
|
||||||
}
|
|
||||||
defer txn.Commit()
|
|
||||||
|
|
||||||
requestSenderUserNID, err := d.assignStateKeyNID(txn, requestSenderUserID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
senderMembershipEventNID, senderMembership, err := d.statements.selectMembershipFromRoomAndTarget(roomNID, requestSenderUserNID)
|
senderMembershipEventNID, senderMembership, err :=
|
||||||
|
d.statements.selectMembershipFromRoomAndTarget(
|
||||||
|
ctx, roomNID, requestSenderUserNID,
|
||||||
|
)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
// The user has never been a member of that room
|
// The user has never been a member of that room
|
||||||
return 0, false, nil
|
return 0, false, nil
|
||||||
|
@ -583,15 +639,20 @@ func (d *Database) GetMembership(roomNID types.RoomNID, requestSenderUserID stri
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
|
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
|
||||||
func (d *Database) GetMembershipEventNIDsForRoom(roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) {
|
func (d *Database) GetMembershipEventNIDsForRoom(
|
||||||
|
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
|
||||||
|
) ([]types.EventNID, error) {
|
||||||
if joinOnly {
|
if joinOnly {
|
||||||
return d.statements.selectMembershipsFromRoomAndMembership(roomNID, membershipStateJoin)
|
return d.statements.selectMembershipsFromRoomAndMembership(
|
||||||
|
ctx, roomNID, membershipStateJoin,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.statements.selectMembershipsFromRoom(roomNID)
|
return d.statements.selectMembershipsFromRoom(ctx, roomNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
type transaction struct {
|
type transaction struct {
|
||||||
|
ctx context.Context
|
||||||
txn *sql.Tx
|
txn *sql.Tx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue