0
0
Fork 0
mirror of https://github.com/matrix-org/dendrite synced 2024-12-14 16:13:49 +01:00

Fix some cases where accepting invites over federation doesn't work (#1028)

* Handle cases where accepting invites doesn't work for historic rooms

* Rewrite pairUpChanges

* Review comments
This commit is contained in:
Neil Alexander 2020-05-14 14:58:47 +01:00 committed by GitHub
parent 8adc128225
commit 640a0265df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 79 deletions

View file

@ -231,8 +231,7 @@ func updateToLeaveMembership(
return updates, nil return updates, nil
} }
// membershipChanges pairs up the membership state changes from a sorted list // membershipChanges pairs up the membership state changes.
// of state removed and a sorted list of state added.
func membershipChanges(removed, added []types.StateEntry) []stateChange { func membershipChanges(removed, added []types.StateEntry) []stateChange {
changes := pairUpChanges(removed, added) changes := pairUpChanges(removed, added)
var result []stateChange var result []stateChange
@ -251,64 +250,39 @@ type stateChange struct {
} }
// pairUpChanges pairs up the state events added and removed for each type, // pairUpChanges pairs up the state events added and removed for each type,
// state key tuple. Assumes that removed and added are sorted. // state key tuple.
func pairUpChanges(removed, added []types.StateEntry) []stateChange { func pairUpChanges(removed, added []types.StateEntry) []stateChange {
var ai int tuples := make(map[types.StateKeyTuple]stateChange)
var ri int changes := []stateChange{}
var result []stateChange
for { // First, go through the newly added state entries.
switch { for _, add := range added {
case ai == len(added): if change, ok := tuples[add.StateKeyTuple]; ok {
// We've reached the end of the added entries. // If we already have an entry, update it.
// The rest of the removed list are events that were removed without change.addedEventNID = add.EventNID
// an event with the same state key being added. tuples[add.StateKeyTuple] = change
for _, s := range removed[ri:] { } else {
result = append(result, stateChange{ // Otherwise, create a new entry.
StateKeyTuple: s.StateKeyTuple, tuples[add.StateKeyTuple] = stateChange{add.StateKeyTuple, 0, add.EventNID}
removedEventNID: s.EventNID,
})
}
return result
case ri == len(removed):
// We've reached the end of the removed entries.
// The rest of the added list are events that were added without
// an event with the same state key being removed.
for _, s := range added[ai:] {
result = append(result, stateChange{
StateKeyTuple: s.StateKeyTuple,
addedEventNID: s.EventNID,
})
}
return result
case added[ai].StateKeyTuple == removed[ri].StateKeyTuple:
// The tuple is in both lists so an event with that key is being
// removed and another event with the same key is being added.
result = append(result, stateChange{
StateKeyTuple: added[ai].StateKeyTuple,
removedEventNID: removed[ri].EventNID,
addedEventNID: added[ai].EventNID,
})
ai++
ri++
case added[ai].StateKeyTuple.LessThan(removed[ri].StateKeyTuple):
// The lists are sorted so the added entry being less than the
// removed entry means that the added event was added without an
// event with the same key being removed.
result = append(result, stateChange{
StateKeyTuple: added[ai].StateKeyTuple,
addedEventNID: added[ai].EventNID,
})
ai++
default:
// Reaching the default case implies that the removed entry is less
// than the added entry. Since the lists are sorted this means that
// the removed event was removed without an event with the same
// key being added.
result = append(result, stateChange{
StateKeyTuple: removed[ai].StateKeyTuple,
removedEventNID: removed[ri].EventNID,
})
ri++
} }
} }
// Now go through the removed state entries.
for _, remove := range removed {
if change, ok := tuples[remove.StateKeyTuple]; ok {
// If we already have an entry, update it.
change.removedEventNID = remove.EventNID
tuples[remove.StateKeyTuple] = change
} else {
// Otherwise, create a new entry.
tuples[remove.StateKeyTuple] = stateChange{remove.StateKeyTuple, remove.EventNID, 0}
}
}
// Now return the changes as an array.
for _, change := range tuples {
changes = append(changes, change)
}
return changes
} }

View file

@ -121,6 +121,22 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
return fmt.Errorf("eb.SetContent: %w", err) return fmt.Errorf("eb.SetContent: %w", err)
} }
// First work out if this is in response to an existing invite.
// If it is then we avoid the situation where we might think we
// know about a room in the following section but don't know the
// latest state as all of our users have left.
isInvitePending, inviteSender, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID)
if err == nil && isInvitePending {
// Add the server of the person who invited us to the server list,
// as they should be a fairly good bet.
if _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender); ierr == nil {
req.ServerNames = append(req.ServerNames, inviterDomain)
}
// Perform a federated room join.
return r.performFederatedJoinRoomByID(ctx, req, res)
}
// Try to construct an actual join event from the template. // Try to construct an actual join event from the template.
// If this succeeds then it is a sign that the room already exists // If this succeeds then it is a sign that the room already exists
// locally on the homeserver. // locally on the homeserver.
@ -178,6 +194,22 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
return fmt.Errorf("Room ID %q does not exist", req.RoomIDOrAlias) return fmt.Errorf("Room ID %q does not exist", req.RoomIDOrAlias)
} }
// Perform a federated room join.
return r.performFederatedJoinRoomByID(ctx, req, res)
default:
// Something else went wrong.
return fmt.Errorf("Error joining local room: %q", err)
}
return nil
}
func (r *RoomserverInternalAPI) performFederatedJoinRoomByID(
ctx context.Context,
req *api.PerformJoinRequest,
res *api.PerformJoinResponse, // nolint:unparam
) error {
// Try joining by all of the supplied server names. // Try joining by all of the supplied server names.
fedReq := fsAPI.PerformJoinRequest{ fedReq := fsAPI.PerformJoinRequest{
RoomID: req.RoomIDOrAlias, // the room ID to try and join RoomID: req.RoomIDOrAlias, // the room ID to try and join
@ -186,14 +218,9 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
Content: req.Content, // the membership event content Content: req.Content, // the membership event content
} }
fedRes := fsAPI.PerformJoinResponse{} fedRes := fsAPI.PerformJoinResponse{}
err = r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes) if err := r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes); err != nil {
if err != nil {
return fmt.Errorf("Error joining federated room: %q", err) return fmt.Errorf("Error joining federated room: %q", err)
} }
default:
return fmt.Errorf("Error joining room %q: %w", req.RoomIDOrAlias, err)
}
return nil return nil
} }

View file

@ -38,7 +38,7 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID(
) error { ) error {
// If there's an invite outstanding for the room then respond to // If there's an invite outstanding for the room then respond to
// that. // that.
isInvitePending, senderUser, err := r.isInvitePending(ctx, req, res) isInvitePending, senderUser, err := r.isInvitePending(ctx, req.RoomID, req.UserID)
if err == nil && isInvitePending { if err == nil && isInvitePending {
return r.performRejectInvite(ctx, req, res, senderUser) return r.performRejectInvite(ctx, req, res, senderUser)
} }
@ -160,23 +160,22 @@ func (r *RoomserverInternalAPI) performRejectInvite(
func (r *RoomserverInternalAPI) isInvitePending( func (r *RoomserverInternalAPI) isInvitePending(
ctx context.Context, ctx context.Context,
req *api.PerformLeaveRequest, roomID, userID string,
res *api.PerformLeaveResponse, // nolint:unparam
) (bool, string, error) { ) (bool, string, error) {
// Look up the room NID for the supplied room ID. // Look up the room NID for the supplied room ID.
roomNID, err := r.DB.RoomNID(ctx, req.RoomID) roomNID, err := r.DB.RoomNID(ctx, roomID)
if err != nil { if err != nil {
return false, "", fmt.Errorf("r.DB.RoomNID: %w", err) return false, "", fmt.Errorf("r.DB.RoomNID: %w", err)
} }
// Look up the state key NID for the supplied user ID. // Look up the state key NID for the supplied user ID.
targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{req.UserID}) targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID})
if err != nil { if err != nil {
return false, "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) return false, "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
} }
targetUserNID, targetUserFound := targetUserNIDs[req.UserID] targetUserNID, targetUserFound := targetUserNIDs[userID]
if !targetUserFound { if !targetUserFound {
return false, "", fmt.Errorf("missing NID for user %q (%+v)", req.UserID, targetUserNIDs) return false, "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
} }
// Let's see if we have an event active for the user in the room. If // Let's see if we have an event active for the user in the room. If