2022-12-02 16:42:23 +01:00
|
|
|
package appservice_test
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"net/http"
|
|
|
|
"net/http/httptest"
|
|
|
|
"reflect"
|
|
|
|
"regexp"
|
|
|
|
"strings"
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
"github.com/matrix-org/dendrite/appservice"
|
|
|
|
"github.com/matrix-org/dendrite/appservice/api"
|
|
|
|
"github.com/matrix-org/dendrite/roomserver"
|
|
|
|
"github.com/matrix-org/dendrite/setup/config"
|
|
|
|
"github.com/matrix-org/dendrite/test"
|
|
|
|
"github.com/matrix-org/dendrite/userapi"
|
|
|
|
|
|
|
|
"github.com/matrix-org/dendrite/test/testrig"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestAppserviceInternalAPI(t *testing.T) {
|
|
|
|
|
|
|
|
// Set expected results
|
|
|
|
existingProtocol := "irc"
|
|
|
|
wantLocationResponse := []api.ASLocationResponse{{Protocol: existingProtocol, Fields: []byte("{}")}}
|
|
|
|
wantUserResponse := []api.ASUserResponse{{Protocol: existingProtocol, Fields: []byte("{}")}}
|
|
|
|
wantProtocolResponse := api.ASProtocolResponse{Instances: []api.ProtocolInstance{{Fields: []byte("{}")}}}
|
|
|
|
wantProtocolResult := map[string]api.ASProtocolResponse{
|
|
|
|
existingProtocol: wantProtocolResponse,
|
|
|
|
}
|
|
|
|
|
|
|
|
// create a dummy AS url, handling some cases
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
switch {
|
|
|
|
case strings.Contains(r.URL.Path, "location"):
|
|
|
|
// Check if we've got an existing protocol, if so, return a proper response.
|
|
|
|
if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol {
|
|
|
|
if err := json.NewEncoder(w).Encode(wantLocationResponse); err != nil {
|
|
|
|
t.Fatalf("failed to encode response: %s", err)
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode([]api.ASLocationResponse{}); err != nil {
|
|
|
|
t.Fatalf("failed to encode response: %s", err)
|
|
|
|
}
|
|
|
|
return
|
|
|
|
case strings.Contains(r.URL.Path, "user"):
|
|
|
|
if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol {
|
|
|
|
if err := json.NewEncoder(w).Encode(wantUserResponse); err != nil {
|
|
|
|
t.Fatalf("failed to encode response: %s", err)
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode([]api.UserResponse{}); err != nil {
|
|
|
|
t.Fatalf("failed to encode response: %s", err)
|
|
|
|
}
|
|
|
|
return
|
|
|
|
case strings.Contains(r.URL.Path, "protocol"):
|
|
|
|
if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol {
|
|
|
|
if err := json.NewEncoder(w).Encode(wantProtocolResponse); err != nil {
|
|
|
|
t.Fatalf("failed to encode response: %s", err)
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(nil); err != nil {
|
|
|
|
t.Fatalf("failed to encode response: %s", err)
|
|
|
|
}
|
|
|
|
return
|
|
|
|
default:
|
|
|
|
t.Logf("hit location: %s", r.URL.Path)
|
|
|
|
}
|
|
|
|
}))
|
|
|
|
|
|
|
|
// The test cases to run
|
|
|
|
runCases := func(t *testing.T, testAPI api.AppServiceInternalAPI) {
|
|
|
|
t.Run("UserIDExists", func(t *testing.T) {
|
|
|
|
testUserIDExists(t, testAPI, "@as-testing:test", true)
|
|
|
|
testUserIDExists(t, testAPI, "@as1-testing:test", false)
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("AliasExists", func(t *testing.T) {
|
|
|
|
testAliasExists(t, testAPI, "@asroom-testing:test", true)
|
|
|
|
testAliasExists(t, testAPI, "@asroom1-testing:test", false)
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("Locations", func(t *testing.T) {
|
|
|
|
testLocations(t, testAPI, existingProtocol, wantLocationResponse)
|
|
|
|
testLocations(t, testAPI, "abc", nil)
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("User", func(t *testing.T) {
|
|
|
|
testUser(t, testAPI, existingProtocol, wantUserResponse)
|
|
|
|
testUser(t, testAPI, "abc", nil)
|
|
|
|
})
|
|
|
|
|
|
|
|
t.Run("Protocols", func(t *testing.T) {
|
|
|
|
testProtocol(t, testAPI, existingProtocol, wantProtocolResult)
|
|
|
|
testProtocol(t, testAPI, existingProtocol, wantProtocolResult) // tests the cache
|
|
|
|
testProtocol(t, testAPI, "", wantProtocolResult) // tests getting all protocols
|
|
|
|
testProtocol(t, testAPI, "abc", nil)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2022-12-05 15:09:59 +01:00
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
2022-12-05 16:00:02 +01:00
|
|
|
base, closeBase := testrig.CreateBaseDendrite(t, dbType)
|
2022-12-05 15:09:59 +01:00
|
|
|
defer closeBase()
|
|
|
|
|
|
|
|
// Create a dummy application service
|
|
|
|
base.Cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{
|
|
|
|
{
|
|
|
|
ID: "someID",
|
|
|
|
URL: srv.URL,
|
|
|
|
ASToken: "",
|
|
|
|
HSToken: "",
|
|
|
|
SenderLocalpart: "senderLocalPart",
|
|
|
|
NamespaceMap: map[string][]config.ApplicationServiceNamespace{
|
|
|
|
"users": {{RegexpObject: regexp.MustCompile("as-.*")}},
|
|
|
|
"aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}},
|
|
|
|
},
|
|
|
|
Protocols: []string{existingProtocol},
|
|
|
|
},
|
2022-12-02 16:42:23 +01:00
|
|
|
}
|
|
|
|
|
2022-12-05 15:09:59 +01:00
|
|
|
// Create required internal APIs
|
|
|
|
rsAPI := roomserver.NewInternalAPI(base)
|
2023-02-20 14:58:03 +01:00
|
|
|
usrAPI := userapi.NewInternalAPI(base, rsAPI, nil)
|
2022-12-05 15:09:59 +01:00
|
|
|
asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI)
|
|
|
|
|
2023-02-14 12:47:47 +01:00
|
|
|
runCases(t, asAPI)
|
2022-12-05 15:09:59 +01:00
|
|
|
})
|
2022-12-02 16:42:23 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) {
|
|
|
|
ctx := context.Background()
|
|
|
|
userResp := &api.UserIDExistsResponse{}
|
|
|
|
|
|
|
|
if err := asAPI.UserIDExists(ctx, &api.UserIDExistsRequest{
|
|
|
|
UserID: userID,
|
|
|
|
}, userResp); err != nil {
|
|
|
|
t.Errorf("failed to get userID: %s", err)
|
|
|
|
}
|
|
|
|
if userResp.UserIDExists != wantExists {
|
|
|
|
t.Errorf("unexpected result for UserIDExists(%s): %v, expected %v", userID, userResp.UserIDExists, wantExists)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func testAliasExists(t *testing.T, asAPI api.AppServiceInternalAPI, alias string, wantExists bool) {
|
|
|
|
ctx := context.Background()
|
|
|
|
aliasResp := &api.RoomAliasExistsResponse{}
|
|
|
|
|
|
|
|
if err := asAPI.RoomAliasExists(ctx, &api.RoomAliasExistsRequest{
|
|
|
|
Alias: alias,
|
|
|
|
}, aliasResp); err != nil {
|
|
|
|
t.Errorf("failed to get alias: %s", err)
|
|
|
|
}
|
|
|
|
if aliasResp.AliasExists != wantExists {
|
|
|
|
t.Errorf("unexpected result for RoomAliasExists(%s): %v, expected %v", alias, aliasResp.AliasExists, wantExists)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func testLocations(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASLocationResponse) {
|
|
|
|
ctx := context.Background()
|
|
|
|
locationResp := &api.LocationResponse{}
|
|
|
|
|
|
|
|
if err := asAPI.Locations(ctx, &api.LocationRequest{
|
|
|
|
Protocol: proto,
|
|
|
|
}, locationResp); err != nil {
|
|
|
|
t.Errorf("failed to get locations: %s", err)
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(locationResp.Locations, wantResult) {
|
|
|
|
t.Errorf("unexpected result for Locations(%s): %+v, expected %+v", proto, locationResp.Locations, wantResult)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func testUser(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASUserResponse) {
|
|
|
|
ctx := context.Background()
|
|
|
|
userResp := &api.UserResponse{}
|
|
|
|
|
|
|
|
if err := asAPI.User(ctx, &api.UserRequest{
|
|
|
|
Protocol: proto,
|
|
|
|
}, userResp); err != nil {
|
|
|
|
t.Errorf("failed to get user: %s", err)
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(userResp.Users, wantResult) {
|
|
|
|
t.Errorf("unexpected result for User(%s): %+v, expected %+v", proto, userResp.Users, wantResult)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func testProtocol(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult map[string]api.ASProtocolResponse) {
|
|
|
|
ctx := context.Background()
|
|
|
|
protoResp := &api.ProtocolResponse{}
|
|
|
|
|
|
|
|
if err := asAPI.Protocols(ctx, &api.ProtocolRequest{
|
|
|
|
Protocol: proto,
|
|
|
|
}, protoResp); err != nil {
|
|
|
|
t.Errorf("failed to get Protocols: %s", err)
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(protoResp.Protocols, wantResult) {
|
|
|
|
t.Errorf("unexpected result for Protocols(%s): %+v, expected %+v", proto, protoResp.Protocols[proto], wantResult)
|
|
|
|
}
|
|
|
|
}
|