mirror of
https://github.com/matrix-org/dendrite
synced 2024-12-14 13:03:51 +01:00
Ensure only one transaction is used for RS input per room (#2178)
* Ensure the input API only uses a single transaction * Remove more of the dead query API call * Tidy up * Fix tests hopefully * Don't do unnecessary work for rooms that don't exist * Improve error, fix another case where transaction wasn't used properly * Add a unit test for checking single transaction on RS input API * Fix logic oops when deciding whether to use a transaction in storeEvent
This commit is contained in:
parent
a4e7d471af
commit
5106cc807c
13 changed files with 211 additions and 214 deletions
|
@ -93,11 +93,10 @@ func (o *testEDUProducer) InputCrossSigningKeyUpdate(
|
||||||
|
|
||||||
type testRoomserverAPI struct {
|
type testRoomserverAPI struct {
|
||||||
api.RoomserverInternalAPITrace
|
api.RoomserverInternalAPITrace
|
||||||
inputRoomEvents []api.InputRoomEvent
|
inputRoomEvents []api.InputRoomEvent
|
||||||
queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse
|
queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse
|
||||||
queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse
|
queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse
|
||||||
queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse
|
queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse
|
||||||
queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testRoomserverAPI) InputRoomEvents(
|
func (t *testRoomserverAPI) InputRoomEvents(
|
||||||
|
@ -140,20 +139,6 @@ func (t *testRoomserverAPI) QueryStateAfterEvents(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query the state after a list of events in a room from the room server.
|
|
||||||
func (t *testRoomserverAPI) QueryMissingAuthPrevEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryMissingAuthPrevEventsRequest,
|
|
||||||
response *api.QueryMissingAuthPrevEventsResponse,
|
|
||||||
) error {
|
|
||||||
response.RoomVersion = testRoomVersion
|
|
||||||
res := t.queryMissingAuthPrevEvents(request)
|
|
||||||
response.RoomExists = res.RoomExists
|
|
||||||
response.MissingAuthEventIDs = res.MissingAuthEventIDs
|
|
||||||
response.MissingPrevEventIDs = res.MissingPrevEventIDs
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query a list of events by event ID.
|
// Query a list of events by event ID.
|
||||||
func (t *testRoomserverAPI) QueryEventsByID(
|
func (t *testRoomserverAPI) QueryEventsByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
@ -312,15 +297,7 @@ func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomat
|
||||||
// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on
|
// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on
|
||||||
// to the roomserver. It's the most basic test possible.
|
// to the roomserver. It's the most basic test possible.
|
||||||
func TestBasicTransaction(t *testing.T) {
|
func TestBasicTransaction(t *testing.T) {
|
||||||
rsAPI := &testRoomserverAPI{
|
rsAPI := &testRoomserverAPI{}
|
||||||
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
|
|
||||||
return api.QueryMissingAuthPrevEventsResponse{
|
|
||||||
RoomExists: true,
|
|
||||||
MissingAuthEventIDs: []string{},
|
|
||||||
MissingPrevEventIDs: []string{},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
pdus := []json.RawMessage{
|
pdus := []json.RawMessage{
|
||||||
testData[len(testData)-1], // a message event
|
testData[len(testData)-1], // a message event
|
||||||
}
|
}
|
||||||
|
@ -332,15 +309,7 @@ func TestBasicTransaction(t *testing.T) {
|
||||||
// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver
|
// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver
|
||||||
// as it does the auth check.
|
// as it does the auth check.
|
||||||
func TestTransactionFailAuthChecks(t *testing.T) {
|
func TestTransactionFailAuthChecks(t *testing.T) {
|
||||||
rsAPI := &testRoomserverAPI{
|
rsAPI := &testRoomserverAPI{}
|
||||||
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
|
|
||||||
return api.QueryMissingAuthPrevEventsResponse{
|
|
||||||
RoomExists: true,
|
|
||||||
MissingAuthEventIDs: []string{},
|
|
||||||
MissingPrevEventIDs: []string{},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
pdus := []json.RawMessage{
|
pdus := []json.RawMessage{
|
||||||
testData[len(testData)-1], // a message event
|
testData[len(testData)-1], // a message event
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,13 +83,6 @@ type RoomserverInternalAPI interface {
|
||||||
response *QueryStateAfterEventsResponse,
|
response *QueryStateAfterEventsResponse,
|
||||||
) error
|
) error
|
||||||
|
|
||||||
// Query whether the roomserver is missing any auth or prev events.
|
|
||||||
QueryMissingAuthPrevEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
request *QueryMissingAuthPrevEventsRequest,
|
|
||||||
response *QueryMissingAuthPrevEventsResponse,
|
|
||||||
) error
|
|
||||||
|
|
||||||
// Query a list of events by event ID.
|
// Query a list of events by event ID.
|
||||||
QueryEventsByID(
|
QueryEventsByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
@ -129,16 +129,6 @@ func (t *RoomserverInternalAPITrace) QueryStateAfterEvents(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *RoomserverInternalAPITrace) QueryMissingAuthPrevEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
req *QueryMissingAuthPrevEventsRequest,
|
|
||||||
res *QueryMissingAuthPrevEventsResponse,
|
|
||||||
) error {
|
|
||||||
err := t.Impl.QueryMissingAuthPrevEvents(ctx, req, res)
|
|
||||||
util.GetLogger(ctx).WithError(err).Infof("QueryMissingAuthPrevEvents req=%+v res=%+v", js(req), js(res))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *RoomserverInternalAPITrace) QueryEventsByID(
|
func (t *RoomserverInternalAPITrace) QueryEventsByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *QueryEventsByIDRequest,
|
req *QueryEventsByIDRequest,
|
||||||
|
|
|
@ -83,27 +83,6 @@ type QueryStateAfterEventsResponse struct {
|
||||||
StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"`
|
StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryMissingAuthPrevEventsRequest struct {
|
|
||||||
// The room ID to query the state in.
|
|
||||||
RoomID string `json:"room_id"`
|
|
||||||
// The list of auth events to check the existence of.
|
|
||||||
AuthEventIDs []string `json:"auth_event_ids"`
|
|
||||||
// The list of previous events to check the existence of.
|
|
||||||
PrevEventIDs []string `json:"prev_event_ids"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type QueryMissingAuthPrevEventsResponse struct {
|
|
||||||
// Does the room exist on this roomserver?
|
|
||||||
// If the room doesn't exist all other fields will be empty.
|
|
||||||
RoomExists bool `json:"room_exists"`
|
|
||||||
// The room version of the room.
|
|
||||||
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
|
|
||||||
// The event IDs of the auth events that we don't know locally.
|
|
||||||
MissingAuthEventIDs []string `json:"missing_auth_event_ids"`
|
|
||||||
// The event IDs of the previous events that we don't know locally.
|
|
||||||
MissingPrevEventIDs []string `json:"missing_prev_event_ids"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryEventsByIDRequest is a request to QueryEventsByID
|
// QueryEventsByIDRequest is a request to QueryEventsByID
|
||||||
type QueryEventsByIDRequest struct {
|
type QueryEventsByIDRequest struct {
|
||||||
// The event IDs to look up.
|
// The event IDs to look up.
|
||||||
|
|
|
@ -128,20 +128,24 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
missingRes := &api.QueryMissingAuthPrevEventsResponse{}
|
// Don't waste time processing the event if the room doesn't exist.
|
||||||
serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{}
|
// A room entry locally will only be created in response to a create
|
||||||
if event.Type() != gomatrixserverlib.MRoomCreate || !event.StateKeyEquals("") {
|
// event.
|
||||||
missingReq := &api.QueryMissingAuthPrevEventsRequest{
|
isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("")
|
||||||
RoomID: event.RoomID(),
|
if !updater.RoomExists() && !isCreateEvent {
|
||||||
AuthEventIDs: event.AuthEventIDs(),
|
return rollbackTransaction, fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
|
||||||
PrevEventIDs: event.PrevEventIDs(),
|
}
|
||||||
}
|
|
||||||
if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil {
|
var missingAuth, missingPrev bool
|
||||||
return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err)
|
serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{}
|
||||||
}
|
if !isCreateEvent {
|
||||||
|
missingAuthIDs, missingPrevIDs, err := updater.MissingAuthPrevEvents(ctx, event)
|
||||||
|
if err != nil {
|
||||||
|
return rollbackTransaction, fmt.Errorf("updater.MissingAuthPrevEvents: %w", err)
|
||||||
|
}
|
||||||
|
missingAuth = len(missingAuthIDs) > 0
|
||||||
|
missingPrev = !input.HasState && len(missingPrevIDs) > 0
|
||||||
}
|
}
|
||||||
missingAuth := len(missingRes.MissingAuthEventIDs) > 0
|
|
||||||
missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0
|
|
||||||
|
|
||||||
if missingAuth || missingPrev {
|
if missingAuth || missingPrev {
|
||||||
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
|
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
|
||||||
|
@ -246,14 +250,13 @@ func (r *Inputer) processRoomEvent(
|
||||||
missingState := missingStateReq{
|
missingState := missingStateReq{
|
||||||
origin: input.Origin,
|
origin: input.Origin,
|
||||||
inputer: r,
|
inputer: r,
|
||||||
queryer: r.Queryer,
|
|
||||||
db: updater,
|
db: updater,
|
||||||
federation: r.FSAPI,
|
federation: r.FSAPI,
|
||||||
keys: r.KeyRing,
|
keys: r.KeyRing,
|
||||||
roomsMu: internal.NewMutexByRoom(),
|
roomsMu: internal.NewMutexByRoom(),
|
||||||
servers: serverRes.ServerNames,
|
servers: serverRes.ServerNames,
|
||||||
hadEvents: map[string]bool{},
|
hadEvents: map[string]bool{},
|
||||||
haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{},
|
haveEvents: map[string]*gomatrixserverlib.Event{},
|
||||||
}
|
}
|
||||||
if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
|
if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
|
||||||
// Something went wrong with retrieving the missing state, so we can't
|
// Something went wrong with retrieving the missing state, so we can't
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/internal/query"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -27,14 +27,13 @@ type missingStateReq struct {
|
||||||
origin gomatrixserverlib.ServerName
|
origin gomatrixserverlib.ServerName
|
||||||
db *shared.RoomUpdater
|
db *shared.RoomUpdater
|
||||||
inputer *Inputer
|
inputer *Inputer
|
||||||
queryer *query.Queryer
|
|
||||||
keys gomatrixserverlib.JSONVerifier
|
keys gomatrixserverlib.JSONVerifier
|
||||||
federation fedapi.FederationInternalAPI
|
federation fedapi.FederationInternalAPI
|
||||||
roomsMu *internal.MutexByRoom
|
roomsMu *internal.MutexByRoom
|
||||||
servers []gomatrixserverlib.ServerName
|
servers []gomatrixserverlib.ServerName
|
||||||
hadEvents map[string]bool
|
hadEvents map[string]bool
|
||||||
hadEventsMutex sync.Mutex
|
hadEventsMutex sync.Mutex
|
||||||
haveEvents map[string]*gomatrixserverlib.HeaderedEvent
|
haveEvents map[string]*gomatrixserverlib.Event
|
||||||
haveEventsMutex sync.Mutex
|
haveEventsMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -326,20 +325,20 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion
|
||||||
for i := range respState.StateEvents {
|
for i := range respState.StateEvents {
|
||||||
se := respState.StateEvents[i]
|
se := respState.StateEvents[i]
|
||||||
if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) {
|
if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) {
|
||||||
respState.StateEvents[i] = h.Unwrap()
|
respState.StateEvents[i] = h
|
||||||
addedToState = true
|
addedToState = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !addedToState {
|
if !addedToState {
|
||||||
respState.StateEvents = append(respState.StateEvents, h.Unwrap())
|
respState.StateEvents = append(respState.StateEvents, h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return respState, false, nil
|
return respState, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent {
|
func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixserverlib.Event {
|
||||||
t.haveEventsMutex.Lock()
|
t.haveEventsMutex.Lock()
|
||||||
defer t.haveEventsMutex.Unlock()
|
defer t.haveEventsMutex.Unlock()
|
||||||
if cached, exists := t.haveEvents[ev.EventID()]; exists {
|
if cached, exists := t.haveEvents[ev.EventID()]; exists {
|
||||||
|
@ -350,32 +349,49 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState {
|
func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState {
|
||||||
var res api.QueryStateAfterEventsResponse
|
var res parsedRespState
|
||||||
err := t.queryer.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{
|
roomInfo, err := t.db.RoomInfo(ctx, roomID)
|
||||||
RoomID: roomID,
|
if err != nil {
|
||||||
PrevEventIDs: []string{eventID},
|
|
||||||
}, &res)
|
|
||||||
if err != nil || !res.PrevEventsExist {
|
|
||||||
util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally, prev exists=%v", eventID, res.PrevEventsExist)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents))
|
roomState := state.NewStateResolution(t.db, roomInfo)
|
||||||
for i, ev := range res.StateEvents {
|
stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID})
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to get state after %s locally", eventID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, stateAtEvents)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to load combined state after %s locally", eventID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
stateEventNIDs := make([]types.EventNID, 0, len(stateEntries))
|
||||||
|
for _, entry := range stateEntries {
|
||||||
|
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
||||||
|
}
|
||||||
|
stateEvents, err := t.db.Events(ctx, stateEventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to load state events locally")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
res.StateEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents))
|
||||||
|
for _, ev := range stateEvents {
|
||||||
// set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this
|
// set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this
|
||||||
// processEvent request, which is better for memory.
|
// processEvent request, which is better for memory.
|
||||||
stateEvents[i] = t.cacheAndReturn(ev)
|
res.StateEvents = append(res.StateEvents, t.cacheAndReturn(ev.Event))
|
||||||
t.hadEvent(ev.EventID())
|
t.hadEvent(ev.EventID())
|
||||||
}
|
}
|
||||||
// we should never access res.StateEvents again so we delete it here to make GC faster
|
|
||||||
res.StateEvents = nil
|
|
||||||
|
|
||||||
var authEvents []*gomatrixserverlib.Event
|
// encourage GC
|
||||||
|
stateEvents, stateEventNIDs, stateEntries, stateAtEvents = nil, nil, nil, nil // nolint:ineffassign
|
||||||
|
|
||||||
missingAuthEvents := map[string]bool{}
|
missingAuthEvents := map[string]bool{}
|
||||||
|
res.AuthEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)*3)
|
||||||
for _, ev := range stateEvents {
|
for _, ev := range stateEvents {
|
||||||
t.haveEventsMutex.Lock()
|
t.haveEventsMutex.Lock()
|
||||||
for _, ae := range ev.AuthEventIDs() {
|
for _, ae := range ev.AuthEventIDs() {
|
||||||
if aev, ok := t.haveEvents[ae]; ok {
|
if aev, ok := t.haveEvents[ae]; ok {
|
||||||
authEvents = append(authEvents, aev.Unwrap())
|
res.AuthEvents = append(res.AuthEvents, aev)
|
||||||
} else {
|
} else {
|
||||||
missingAuthEvents[ae] = true
|
missingAuthEvents[ae] = true
|
||||||
}
|
}
|
||||||
|
@ -389,25 +405,18 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room
|
||||||
for evID := range missingAuthEvents {
|
for evID := range missingAuthEvents {
|
||||||
missingEventList = append(missingEventList, evID)
|
missingEventList = append(missingEventList, evID)
|
||||||
}
|
}
|
||||||
queryReq := api.QueryEventsByIDRequest{
|
|
||||||
EventIDs: missingEventList,
|
|
||||||
}
|
|
||||||
util.GetLogger(ctx).WithField("count", len(missingEventList)).Debugf("Fetching missing auth events")
|
util.GetLogger(ctx).WithField("count", len(missingEventList)).Debugf("Fetching missing auth events")
|
||||||
var queryRes api.QueryEventsByIDResponse
|
events, err := t.db.EventsFromIDs(ctx, missingEventList)
|
||||||
if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for i, ev := range queryRes.Events {
|
for i, ev := range events {
|
||||||
authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap())
|
res.AuthEvents = append(res.AuthEvents, t.cacheAndReturn(events[i].Event))
|
||||||
t.hadEvent(ev.EventID())
|
t.hadEvent(ev.EventID())
|
||||||
}
|
}
|
||||||
queryRes.Events = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &parsedRespState{
|
return &res
|
||||||
StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents),
|
|
||||||
AuthEvents: authEvents,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what
|
// lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what
|
||||||
|
@ -448,7 +457,7 @@ retryAllowedState:
|
||||||
return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2)
|
return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2)
|
||||||
}
|
}
|
||||||
util.GetLogger(ctx).Tracef("fetched event %s", missing.AuthEventID)
|
util.GetLogger(ctx).Tracef("fetched event %s", missing.AuthEventID)
|
||||||
resolvedStateEvents = append(resolvedStateEvents, h.Unwrap())
|
resolvedStateEvents = append(resolvedStateEvents, h)
|
||||||
goto retryAllowedState
|
goto retryAllowedState
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
@ -513,7 +522,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve
|
||||||
logger.Debugf("get_missing_events returned %d events", len(missingResp.Events))
|
logger.Debugf("get_missing_events returned %d events", len(missingResp.Events))
|
||||||
missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events))
|
missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events))
|
||||||
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
|
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
|
||||||
missingEvents = append(missingEvents, t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap())
|
missingEvents = append(missingEvents, t.cacheAndReturn(ev))
|
||||||
}
|
}
|
||||||
|
|
||||||
// topologically sort and sanity check that we are making forward progress
|
// topologically sort and sanity check that we are making forward progress
|
||||||
|
@ -602,11 +611,11 @@ func (t *missingStateReq) lookupMissingStateViaState(
|
||||||
// We load these as trusted as we called state.Check before which loaded them as untrusted.
|
// We load these as trusted as we called state.Check before which loaded them as untrusted.
|
||||||
for i, evJSON := range state.AuthEvents {
|
for i, evJSON := range state.AuthEvents {
|
||||||
ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion)
|
ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion)
|
||||||
parsedState.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()
|
parsedState.AuthEvents[i] = t.cacheAndReturn(ev)
|
||||||
}
|
}
|
||||||
for i, evJSON := range state.StateEvents {
|
for i, evJSON := range state.StateEvents {
|
||||||
ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion)
|
ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion)
|
||||||
parsedState.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()
|
parsedState.StateEvents[i] = t.cacheAndReturn(ev)
|
||||||
}
|
}
|
||||||
return parsedState, nil
|
return parsedState, nil
|
||||||
}
|
}
|
||||||
|
@ -634,23 +643,22 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
|
||||||
}
|
}
|
||||||
t.haveEventsMutex.Unlock()
|
t.haveEventsMutex.Unlock()
|
||||||
|
|
||||||
// fetch as many as we can from the roomserver
|
events, err := t.db.EventsFromIDs(ctx, missingEventList)
|
||||||
queryReq := api.QueryEventsByIDRequest{
|
if err != nil {
|
||||||
EventIDs: missingEventList,
|
return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err)
|
||||||
}
|
}
|
||||||
var queryRes api.QueryEventsByIDResponse
|
|
||||||
if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
|
for i, ev := range events {
|
||||||
return nil, err
|
events[i].Event = t.cacheAndReturn(events[i].Event)
|
||||||
}
|
|
||||||
for i, ev := range queryRes.Events {
|
|
||||||
queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i])
|
|
||||||
t.hadEvent(ev.EventID())
|
t.hadEvent(ev.EventID())
|
||||||
evID := queryRes.Events[i].EventID()
|
evID := events[i].EventID()
|
||||||
if missing[evID] {
|
if missing[evID] {
|
||||||
delete(missing, evID)
|
delete(missing, evID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
queryRes.Events = nil // allow it to be GCed
|
|
||||||
|
// encourage GC
|
||||||
|
events = nil // nolint:ineffassign
|
||||||
|
|
||||||
concurrentRequests := 8
|
concurrentRequests := 8
|
||||||
missingCount := len(missing)
|
missingCount := len(missing)
|
||||||
|
@ -704,7 +712,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
|
||||||
|
|
||||||
// Define what we'll do in order to fetch the missing event ID.
|
// Define what we'll do in order to fetch the missing event ID.
|
||||||
fetch := func(missingEventID string) {
|
fetch := func(missingEventID string) {
|
||||||
var h *gomatrixserverlib.HeaderedEvent
|
var h *gomatrixserverlib.Event
|
||||||
h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false)
|
h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false)
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case verifySigError:
|
case verifySigError:
|
||||||
|
@ -759,7 +767,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(
|
||||||
logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i])
|
logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
respState.StateEvents = append(respState.StateEvents, ev.Unwrap())
|
respState.StateEvents = append(respState.StateEvents, ev)
|
||||||
}
|
}
|
||||||
for i := range stateIDs.AuthEventIDs {
|
for i := range stateIDs.AuthEventIDs {
|
||||||
ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]]
|
ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]]
|
||||||
|
@ -767,7 +775,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(
|
||||||
logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i])
|
logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap())
|
respState.AuthEvents = append(respState.AuthEvents, ev)
|
||||||
}
|
}
|
||||||
// We purposefully do not do auth checks on the returned events, as they will still
|
// We purposefully do not do auth checks on the returned events, as they will still
|
||||||
// be processed in the exact same way, just as a 'rejected' event
|
// be processed in the exact same way, just as a 'rejected' event
|
||||||
|
@ -775,17 +783,14 @@ func (t *missingStateReq) createRespStateFromStateIDs(
|
||||||
return &respState, nil
|
return &respState, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) {
|
func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) {
|
||||||
if localFirst {
|
if localFirst {
|
||||||
// fetch from the roomserver
|
// fetch from the roomserver
|
||||||
queryReq := api.QueryEventsByIDRequest{
|
events, err := t.db.EventsFromIDs(ctx, []string{missingEventID})
|
||||||
EventIDs: []string{missingEventID},
|
if err != nil {
|
||||||
}
|
|
||||||
var queryRes api.QueryEventsByIDResponse
|
|
||||||
if err := t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
|
|
||||||
util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
|
util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
|
||||||
} else if len(queryRes.Events) == 1 {
|
} else if len(events) == 1 {
|
||||||
return queryRes.Events[0], nil
|
return events[0].Event, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var event *gomatrixserverlib.Event
|
var event *gomatrixserverlib.Event
|
||||||
|
@ -822,7 +827,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
|
||||||
util.GetLogger(ctx).WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
|
util.GetLogger(ctx).WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
|
||||||
return nil, verifySigError{event.EventID(), err}
|
return nil, verifySigError{event.EventID(), err}
|
||||||
}
|
}
|
||||||
return t.cacheAndReturn(event.Headered(roomVersion)), nil
|
return t.cacheAndReturn(event), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error {
|
func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error {
|
||||||
|
|
93
roomserver/internal/input/input_test.go
Normal file
93
roomserver/internal/input/input_test.go
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
package input_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/input"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
func psqlConnectionString() config.DataSource {
|
||||||
|
user := os.Getenv("POSTGRES_USER")
|
||||||
|
if user == "" {
|
||||||
|
user = "dendrite"
|
||||||
|
}
|
||||||
|
dbName := os.Getenv("POSTGRES_DB")
|
||||||
|
if dbName == "" {
|
||||||
|
dbName = "dendrite"
|
||||||
|
}
|
||||||
|
connStr := fmt.Sprintf(
|
||||||
|
"user=%s dbname=%s sslmode=disable", user, dbName,
|
||||||
|
)
|
||||||
|
password := os.Getenv("POSTGRES_PASSWORD")
|
||||||
|
if password != "" {
|
||||||
|
connStr += fmt.Sprintf(" password=%s", password)
|
||||||
|
}
|
||||||
|
host := os.Getenv("POSTGRES_HOST")
|
||||||
|
if host != "" {
|
||||||
|
connStr += fmt.Sprintf(" host=%s", host)
|
||||||
|
}
|
||||||
|
return config.DataSource(connStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingleTransactionOnInput(t *testing.T) {
|
||||||
|
deadline, _ := t.Deadline()
|
||||||
|
if max := time.Now().Add(time.Second * 3); deadline.After(max) {
|
||||||
|
deadline = max
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
event, err := gomatrixserverlib.NewEventFromTrustedJSON(
|
||||||
|
[]byte(`{"auth_events":[],"content":{"creator":"@neilalexander:dendrite.matrix.org","room_version":"6"},"depth":1,"hashes":{"sha256":"jqOqdNEH5r0NiN3xJtj0u5XUVmRqq9YvGbki1wxxuuM"},"origin":"dendrite.matrix.org","origin_server_ts":1644595362726,"prev_events":[],"prev_state":[],"room_id":"!jSZZRknA6GkTBXNP:dendrite.matrix.org","sender":"@neilalexander:dendrite.matrix.org","signatures":{"dendrite.matrix.org":{"ed25519:6jB2aB":"bsQXO1wketf1OSe9xlndDIWe71W9KIundc6rBw4KEZdGPW7x4Tv4zDWWvbxDsG64sS2IPWfIm+J0OOozbrWIDw"}},"state_key":"","type":"m.room.create"}`),
|
||||||
|
false, gomatrixserverlib.RoomVersionV6,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
in := api.InputRoomEvent{
|
||||||
|
Kind: api.KindOutlier, // don't panic if we generate an output event
|
||||||
|
Event: event.Headered(gomatrixserverlib.RoomVersionV6),
|
||||||
|
}
|
||||||
|
cache, err := caching.NewInMemoryLRUCache(false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
db, err := storage.Open(
|
||||||
|
&config.DatabaseOptions{
|
||||||
|
ConnectionString: psqlConnectionString(),
|
||||||
|
MaxOpenConnections: 1,
|
||||||
|
MaxIdleConnections: 1,
|
||||||
|
},
|
||||||
|
cache,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("PostgreSQL not available (%s), skipping", err)
|
||||||
|
t.SkipNow()
|
||||||
|
}
|
||||||
|
inputter := &input.Inputer{
|
||||||
|
DB: db,
|
||||||
|
}
|
||||||
|
res := &api.InputRoomEventsResponse{}
|
||||||
|
inputter.InputRoomEvents(
|
||||||
|
ctx,
|
||||||
|
&api.InputRoomEventsRequest{
|
||||||
|
InputRoomEvents: []api.InputRoomEvent{in},
|
||||||
|
Asynchronous: false,
|
||||||
|
},
|
||||||
|
res,
|
||||||
|
)
|
||||||
|
// If we fail here then it's because we've hit the test deadline,
|
||||||
|
// so we probably deadlocked
|
||||||
|
if err := res.Err(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -125,39 +125,6 @@ func (r *Queryer) QueryStateAfterEvents(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryMissingAuthPrevEvents implements api.RoomserverInternalAPI
|
|
||||||
func (r *Queryer) QueryMissingAuthPrevEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryMissingAuthPrevEventsRequest,
|
|
||||||
response *api.QueryMissingAuthPrevEventsResponse,
|
|
||||||
) error {
|
|
||||||
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if info == nil {
|
|
||||||
return errors.New("room doesn't exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
response.RoomExists = !info.IsStub
|
|
||||||
response.RoomVersion = info.RoomVersion
|
|
||||||
|
|
||||||
for _, authEventID := range request.AuthEventIDs {
|
|
||||||
if nids, err := r.DB.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 {
|
|
||||||
response.MissingAuthEventIDs = append(response.MissingAuthEventIDs, authEventID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prevEventID := range request.PrevEventIDs {
|
|
||||||
state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID})
|
|
||||||
if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) {
|
|
||||||
response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryEventsByID implements api.RoomserverInternalAPI
|
// QueryEventsByID implements api.RoomserverInternalAPI
|
||||||
func (r *Queryer) QueryEventsByID(
|
func (r *Queryer) QueryEventsByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
@ -40,7 +40,6 @@ const (
|
||||||
// Query operations
|
// Query operations
|
||||||
RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState"
|
RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState"
|
||||||
RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents"
|
RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents"
|
||||||
RoomserverQueryMissingAuthPrevEventsPath = "/roomserver/queryMissingAuthPrevEvents"
|
|
||||||
RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID"
|
RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID"
|
||||||
RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser"
|
RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser"
|
||||||
RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom"
|
RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom"
|
||||||
|
@ -302,19 +301,6 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents(
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryStateAfterEvents implements RoomserverQueryAPI
|
|
||||||
func (h *httpRoomserverInternalAPI) QueryMissingAuthPrevEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryMissingAuthPrevEventsRequest,
|
|
||||||
response *api.QueryMissingAuthPrevEventsResponse,
|
|
||||||
) error {
|
|
||||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingAuthPrevEvents")
|
|
||||||
defer span.Finish()
|
|
||||||
|
|
||||||
apiURL := h.roomserverURL + RoomserverQueryMissingAuthPrevEventsPath
|
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryEventsByID implements RoomserverQueryAPI
|
// QueryEventsByID implements RoomserverQueryAPI
|
||||||
func (h *httpRoomserverInternalAPI) QueryEventsByID(
|
func (h *httpRoomserverInternalAPI) QueryEventsByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
@ -149,20 +149,6 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
internalAPIMux.Handle(
|
|
||||||
RoomserverQueryMissingAuthPrevEventsPath,
|
|
||||||
httputil.MakeInternalAPI("queryMissingAuthPrevEvents", func(req *http.Request) util.JSONResponse {
|
|
||||||
var request api.QueryMissingAuthPrevEventsRequest
|
|
||||||
var response api.QueryMissingAuthPrevEventsResponse
|
|
||||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
|
||||||
return util.ErrorResponse(err)
|
|
||||||
}
|
|
||||||
if err := r.QueryMissingAuthPrevEvents(req.Context(), &request, &response); err != nil {
|
|
||||||
return util.ErrorResponse(err)
|
|
||||||
}
|
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
internalAPIMux.Handle(
|
internalAPIMux.Handle(
|
||||||
RoomserverQueryEventsByIDPath,
|
RoomserverQueryEventsByIDPath,
|
||||||
httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {
|
||||||
|
|
|
@ -76,7 +76,8 @@ func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
|
||||||
func (s *eventJSONStatements) InsertEventJSON(
|
func (s *eventJSONStatements) InsertEventJSON(
|
||||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
|
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
|
||||||
) error {
|
) error {
|
||||||
_, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON)
|
stmt := sqlutil.TxStmt(txn, s.insertEventJSONStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, int64(eventNID), eventJSON)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ type RoomUpdater struct {
|
||||||
latestEvents []types.StateAtEventAndReference
|
latestEvents []types.StateAtEventAndReference
|
||||||
lastEventIDSent string
|
lastEventIDSent string
|
||||||
currentStateSnapshotNID types.StateSnapshotNID
|
currentStateSnapshotNID types.StateSnapshotNID
|
||||||
|
roomExists bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func rollback(txn *sql.Tx) {
|
func rollback(txn *sql.Tx) {
|
||||||
|
@ -33,7 +34,7 @@ func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *typ
|
||||||
// succeed, processing a create event which creates the room, or it won't.
|
// succeed, processing a create event which creates the room, or it won't.
|
||||||
if roomInfo == nil {
|
if roomInfo == nil {
|
||||||
return &RoomUpdater{
|
return &RoomUpdater{
|
||||||
transaction{ctx, txn}, d, nil, nil, "", 0,
|
transaction{ctx, txn}, d, nil, nil, "", 0, false,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,10 +58,15 @@ func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *typ
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &RoomUpdater{
|
return &RoomUpdater{
|
||||||
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, true,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoomExists returns true if the room exists and false otherwise.
|
||||||
|
func (u *RoomUpdater) RoomExists() bool {
|
||||||
|
return u.roomExists
|
||||||
|
}
|
||||||
|
|
||||||
// Implements sqlutil.Transaction
|
// Implements sqlutil.Transaction
|
||||||
func (u *RoomUpdater) Commit() error {
|
func (u *RoomUpdater) Commit() error {
|
||||||
if u.txn == nil { // SQLite mode probably
|
if u.txn == nil { // SQLite mode probably
|
||||||
|
@ -97,6 +103,25 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
|
||||||
return u.currentStateSnapshotNID
|
return u.currentStateSnapshotNID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *RoomUpdater) MissingAuthPrevEvents(
|
||||||
|
ctx context.Context, e *gomatrixserverlib.Event,
|
||||||
|
) (missingAuth, missingPrev []string, err error) {
|
||||||
|
for _, authEventID := range e.AuthEventIDs() {
|
||||||
|
if nids, err := u.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 {
|
||||||
|
missingAuth = append(missingAuth, authEventID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prevEventID := range e.PrevEventIDs() {
|
||||||
|
state, err := u.StateAtEventIDs(ctx, []string{prevEventID})
|
||||||
|
if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) {
|
||||||
|
missingPrev = append(missingPrev, prevEventID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
|
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
|
||||||
func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
||||||
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
|
||||||
|
|
|
@ -553,7 +553,7 @@ func (d *Database) storeEvent(
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
var txn *sql.Tx
|
var txn *sql.Tx
|
||||||
if updater != nil {
|
if updater != nil && updater.txn != nil {
|
||||||
txn = updater.txn
|
txn = updater.txn
|
||||||
}
|
}
|
||||||
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
||||||
|
|
Loading…
Reference in a new issue