// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package relayapi_test import ( "crypto/ed25519" "encoding/hex" "encoding/json" "net/http" "net/http/httptest" "testing" "time" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) func TestCreateNewRelayInternalAPI(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { cfg, processCtx, close := testrig.CreateConfig(t, dbType) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) defer close() cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, caches) assert.NotNil(t, relayAPI) }) } func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { cfg, processCtx, close := testrig.CreateConfig(t, dbType) if dbType == test.DBTypeSQLite { cfg.RelayAPI.Database.ConnectionString = "file:" } else { cfg.RelayAPI.Database.ConnectionString = "test" } defer close() cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) assert.Panics(t, func() { relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, nil) }) }) } func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { cfg, _, close := testrig.CreateConfig(t, dbType) defer close() routers := httputil.NewRouters() assert.Panics(t, func() { relayapi.AddPublicRoutes(routers, cfg, nil, nil) }) }) } func createGetRelayTxnHTTPRequest(serverName spec.ServerName, userID string) *http.Request { _, sk, _ := ed25519.GenerateKey(nil) keyID := signing.KeyID pk := sk.Public().(ed25519.PublicKey) origin := spec.ServerName(hex.EncodeToString(pk)) req := fclient.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID) content := fclient.RelayEntry{EntryID: 0} req.SetContent(content) req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) httpreq, _ := req.HTTPRequest() vars := map[string]string{"userID": userID} httpreq = mux.SetURLVars(httpreq, vars) return httpreq } type sendRelayContent struct { PDUs []json.RawMessage `json:"pdus"` EDUs []gomatrixserverlib.EDU `json:"edus"` } func createSendRelayTxnHTTPRequest(serverName spec.ServerName, txnID string, userID string) *http.Request { _, sk, _ := ed25519.GenerateKey(nil) keyID := signing.KeyID pk := sk.Public().(ed25519.PublicKey) origin := spec.ServerName(hex.EncodeToString(pk)) req := fclient.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID) content := sendRelayContent{} req.SetContent(content) req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) httpreq, _ := req.HTTPRequest() vars := map[string]string{"userID": userID, "txnID": txnID} httpreq = mux.SetURLVars(httpreq, vars) return httpreq } func TestCreateRelayPublicRoutes(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { cfg, processCtx, close := testrig.CreateConfig(t, dbType) defer close() routers := httputil.NewRouters() caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, caches) assert.NotNil(t, relayAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() relayapi.AddPublicRoutes(routers, cfg, keyRing, relayAPI) testCases := []struct { name string req *http.Request wantCode int }{ { name: "relay_txn invalid user id", req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "user:local"), wantCode: 400, }, { name: "relay_txn valid user id", req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "@user:local"), wantCode: 200, }, { name: "send_relay invalid user id", req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "user:local"), wantCode: 400, }, { name: "send_relay valid user id", req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "@user:local"), wantCode: 200, }, } for _, tc := range testCases { w := httptest.NewRecorder() routers.Federation.ServeHTTP(w, tc.req) if w.Code != tc.wantCode { t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) } } }) } func TestDisableRelayPublicRoutes(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { cfg, processCtx, close := testrig.CreateConfig(t, dbType) defer close() routers := httputil.NewRouters() caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, false, caches) assert.NotNil(t, relayAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() relayapi.AddPublicRoutes(routers, cfg, keyRing, relayAPI) testCases := []struct { name string req *http.Request wantCode int }{ { name: "relay_txn valid user id", req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "@user:local"), wantCode: 404, }, { name: "send_relay valid user id", req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "@user:local"), wantCode: 404, }, } for _, tc := range testCases { w := httptest.NewRecorder() routers.Federation.ServeHTTP(w, tc.req) if w.Code != tc.wantCode { t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) } } }) }