0
0
Fork 0
mirror of https://github.com/matrix-org/dendrite synced 2024-11-18 07:40:53 +01:00

Check peek state response and refactor checking send_join response (#1732)

This commit is contained in:
Kegsay 2021-01-22 17:16:35 +00:00 committed by GitHub
parent 6757b67a32
commit ef9d5ad4fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 88 additions and 132 deletions

View file

@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/federationsender/internal/perform"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
@ -218,9 +217,9 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
// Sanity-check the join response to ensure that it has a create // Sanity-check the join response to ensure that it has a create
// event, that the room version is known, etc. // event, that the room version is known, etc.
if err := sanityCheckSendJoinResponse(respSendJoin); err != nil { if err := sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil {
cancel() cancel()
return fmt.Errorf("sanityCheckSendJoinResponse: %w", err) return fmt.Errorf("sanityCheckAuthChain: %w", err)
} }
// Process the join response in a goroutine. The idea here is // Process the join response in a goroutine. The idea here is
@ -231,11 +230,9 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
go func() { go func() {
defer cancel() defer cancel()
// Check that the send_join response was valid. // TODO: Can we expand Check here to return a list of missing auth
joinCtx := perform.JoinContext(r.federation, r.keyRing) // events rather than failing one at a time?
respState, err := joinCtx.CheckSendJoinResponse( respState, err := respSendJoin.Check(ctx, r.keyRing, event, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName))
ctx, event, serverName, respSendJoin,
)
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"room_id": roomID, "room_id": roomID,
@ -402,8 +399,18 @@ func (r *FederationSenderInternalAPI) performOutboundPeekUsingServer(
return fmt.Errorf("respPeek.RoomVersion.EventFormat: %w", err) return fmt.Errorf("respPeek.RoomVersion.EventFormat: %w", err)
} }
// TODO: authenticate the state returned (check its auth events etc) // we have the peek state now so let's process regardless of whether upstream gives up
ctx = context.Background()
respState := respPeek.ToRespState()
// authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse() // the equivalent of CheckSendJoinResponse()
if err = sanityCheckAuthChain(respState.AuthEvents); err != nil {
return fmt.Errorf("sanityCheckAuthChain: %w", err)
}
if err = respState.Check(ctx, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)); err != nil {
return fmt.Errorf("Error checking state returned from peeking: %w", err)
}
// If we've got this far, the remote server is peeking. // If we've got this far, the remote server is peeking.
if renewing { if renewing {
@ -416,7 +423,6 @@ func (r *FederationSenderInternalAPI) performOutboundPeekUsingServer(
} }
} }
respState := respPeek.ToRespState()
// logrus.Warnf("got respPeek %#v", respPeek) // logrus.Warnf("got respPeek %#v", respPeek)
// Send the newly returned state to the roomserver to update our local view. // Send the newly returned state to the roomserver to update our local view.
if err = roomserverAPI.SendEventWithState( if err = roomserverAPI.SendEventWithState(
@ -607,9 +613,9 @@ func (r *FederationSenderInternalAPI) PerformBroadcastEDU(
return nil return nil
} }
func sanityCheckSendJoinResponse(respSendJoin gomatrixserverlib.RespSendJoin) error { func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error {
// sanity check we have a create event and it has a known room version // sanity check we have a create event and it has a known room version
for _, ev := range respSendJoin.AuthEvents { for _, ev := range authChain {
if ev.Type() == gomatrixserverlib.MRoomCreate && ev.StateKeyEquals("") { if ev.Type() == gomatrixserverlib.MRoomCreate && ev.StateKeyEquals("") {
// make sure the room version is known // make sure the room version is known
content := ev.Content() content := ev.Content()
@ -627,12 +633,12 @@ func sanityCheckSendJoinResponse(respSendJoin gomatrixserverlib.RespSendJoin) er
} }
knownVersions := gomatrixserverlib.RoomVersions() knownVersions := gomatrixserverlib.RoomVersions()
if _, ok := knownVersions[gomatrixserverlib.RoomVersion(verBody.Version)]; !ok { if _, ok := knownVersions[gomatrixserverlib.RoomVersion(verBody.Version)]; !ok {
return fmt.Errorf("send_join m.room.create event has an unknown room version: %s", verBody.Version) return fmt.Errorf("auth chain m.room.create event has an unknown room version: %s", verBody.Version)
} }
return nil return nil
} }
} }
return fmt.Errorf("send_join response is missing m.room.create event") return fmt.Errorf("auth chain response is missing m.room.create event")
} }
func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion { func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion {
@ -656,3 +662,71 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder
} }
return gomatrixserverlib.RoomVersionV4 return gomatrixserverlib.RoomVersionV4
} }
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided
func federatedAuthProvider(
ctx context.Context, federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName,
) gomatrixserverlib.AuthChainProvider {
// A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join.
retries := map[string][]*gomatrixserverlib.Event{}
// Define a function which we can pass to Check to retrieve missing
// auth events inline. This greatly increases our chances of not having
// to repeat the entire set of checks just for a missing event or two.
return func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) {
returning := []*gomatrixserverlib.Event{}
// See if we have retry entries for each of the supplied event IDs.
for _, eventID := range eventIDs {
// If we've already satisfied a request for this event ID before then
// just append the results. We won't retry the request.
if retry, ok := retries[eventID]; ok {
if retry == nil {
return nil, fmt.Errorf("missingAuth: not retrying failed event ID %q", eventID)
}
returning = append(returning, retry...)
continue
}
// Make a note of the fact that we tried to do something with this
// event ID, even if we don't succeed.
retries[eventID] = nil
// Try to retrieve the event from the server that sent us the send
// join response.
tx, txerr := federation.GetEvent(ctx, server, eventID)
if txerr != nil {
return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr)
}
// For each event returned, add it to the set of return events. We
// also will populate the retries, in case someone asks for this
// event ID again.
for _, pdu := range tx.PDUs {
// Try to parse the event.
ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
if everr != nil {
return nil, fmt.Errorf("missingAuth gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr)
}
// Check the signatures of the event.
if res, err := gomatrixserverlib.VerifyEventSignatures(ctx, []*gomatrixserverlib.Event{ev}, keyRing); err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
} else {
for _, err := range res {
if err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
}
}
}
// If the event is OK then add it to the results and the retry map.
returning = append(returning, ev)
retries[ev.EventID()] = append(retries[ev.EventID()], ev)
}
}
return returning, nil
}
}

View file

@ -1,118 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package perform
import (
"context"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
// This file contains helpers for the PerformJoin function.
type joinContext struct {
federation *gomatrixserverlib.FederationClient
keyRing *gomatrixserverlib.KeyRing
}
// Returns a new join context.
func JoinContext(f *gomatrixserverlib.FederationClient, k *gomatrixserverlib.KeyRing) *joinContext {
return &joinContext{
federation: f,
keyRing: k,
}
}
// checkSendJoinResponse checks that all of the signatures are correct
// and that the join is allowed by the supplied state.
func (r joinContext) CheckSendJoinResponse(
ctx context.Context,
event *gomatrixserverlib.Event,
server gomatrixserverlib.ServerName,
respSendJoin gomatrixserverlib.RespSendJoin,
) (*gomatrixserverlib.RespState, error) {
// A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join.
retries := map[string][]*gomatrixserverlib.Event{}
// Define a function which we can pass to Check to retrieve missing
// auth events inline. This greatly increases our chances of not having
// to repeat the entire set of checks just for a missing event or two.
missingAuth := func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) {
returning := []*gomatrixserverlib.Event{}
// See if we have retry entries for each of the supplied event IDs.
for _, eventID := range eventIDs {
// If we've already satisfied a request for this event ID before then
// just append the results. We won't retry the request.
if retry, ok := retries[eventID]; ok {
if retry == nil {
return nil, fmt.Errorf("missingAuth: not retrying failed event ID %q", eventID)
}
returning = append(returning, retry...)
continue
}
// Make a note of the fact that we tried to do something with this
// event ID, even if we don't succeed.
retries[event.EventID()] = nil
// Try to retrieve the event from the server that sent us the send
// join response.
tx, txerr := r.federation.GetEvent(ctx, server, eventID)
if txerr != nil {
return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr)
}
// For each event returned, add it to the set of return events. We
// also will populate the retries, in case someone asks for this
// event ID again.
for _, pdu := range tx.PDUs {
// Try to parse the event.
ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
if everr != nil {
return nil, fmt.Errorf("missingAuth gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr)
}
// Check the signatures of the event.
if res, err := gomatrixserverlib.VerifyEventSignatures(ctx, []*gomatrixserverlib.Event{ev}, r.keyRing); err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
} else {
for _, err := range res {
if err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
}
}
}
// If the event is OK then add it to the results and the retry map.
returning = append(returning, ev)
retries[event.EventID()] = append(retries[event.EventID()], ev)
retries[ev.EventID()] = append(retries[ev.EventID()], ev)
}
}
return returning, nil
}
// TODO: Can we expand Check here to return a list of missing auth
// events rather than failing one at a time?
rs, err := respSendJoin.Check(ctx, r.keyRing, event, missingAuth)
if err != nil {
return nil, fmt.Errorf("respSendJoin: %w", err)
}
return rs, nil
}