diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 6087bda0c..ff6a0900e 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -780,7 +780,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, r.ServerName, "", appserviceID, req.RemoteAddr, + req.Context(), userAPI, r.Username, r.ServerName, "", "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) @@ -800,7 +800,7 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.ServerName, r.Password, "", req.RemoteAddr, + req.Context(), userAPI, r.Username, r.ServerName, "", r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) @@ -824,10 +824,10 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.ClientUserAPI, - username string, serverName gomatrixserverlib.ServerName, + username string, serverName gomatrixserverlib.ServerName, displayName string, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, - displayName, deviceID *string, + deviceDisplayName, deviceID *string, accType userapi.AccountType, ) util.JSONResponse { if username == "" { @@ -887,12 +887,28 @@ func completeRegistration( } } + if displayName != "" { + nameReq := userapi.PerformUpdateDisplayNameRequest{ + Localpart: username, + ServerName: serverName, + DisplayName: displayName, + } + var nameRes userapi.PerformUpdateDisplayNameResponse + err = userAPI.SetDisplayName(ctx, &nameReq, &nameRes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("failed to set display name: " + err.Error()), + } + } + } + var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ Localpart: username, ServerName: serverName, AccessToken: token, - DeviceDisplayName: displayName, + DeviceDisplayName: deviceDisplayName, DeviceID: deviceID, IPAddr: ipAddr, UserAgent: userAgent, @@ -1077,5 +1093,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.DisplayName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/routing/register_secret.go b/clientapi/routing/register_secret.go index 1a974b77a..f384b604a 100644 --- a/clientapi/routing/register_secret.go +++ b/clientapi/routing/register_secret.go @@ -18,12 +18,13 @@ import ( ) type SharedSecretRegistrationRequest struct { - User string `json:"username"` - Password string `json:"password"` - Nonce string `json:"nonce"` - MacBytes []byte - MacStr string `json:"mac"` - Admin bool `json:"admin"` + User string `json:"username"` + Password string `json:"password"` + Nonce string `json:"nonce"` + MacBytes []byte + MacStr string `json:"mac"` + Admin bool `json:"admin"` + DisplayName string `json:"displayname,omitempty"` } func NewSharedSecretRegistrationRequest(reader io.ReadCloser) (*SharedSecretRegistrationRequest, error) { diff --git a/clientapi/routing/register_secret_test.go b/clientapi/routing/register_secret_test.go index a2ed35853..ca265d237 100644 --- a/clientapi/routing/register_secret_test.go +++ b/clientapi/routing/register_secret_test.go @@ -10,7 +10,7 @@ import ( func TestSharedSecretRegister(t *testing.T) { // these values have come from a local synapse instance to ensure compatibility - jsonStr := []byte(`{"admin":false,"mac":"f1ba8d37123866fd659b40de4bad9b0f8965c565","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice"}`) + jsonStr := []byte(`{"admin":false,"mac":"f1ba8d37123866fd659b40de4bad9b0f8965c565","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) sharedSecret := "dendritetest" req, err := NewSharedSecretRegistrationRequest(io.NopCloser(bytes.NewBuffer(jsonStr))) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index b8fd19e90..bccc1b79b 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -18,6 +18,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "reflect" @@ -35,7 +36,10 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" + "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/assert" ) var ( @@ -570,3 +574,92 @@ func Test_register(t *testing.T) { } }) } + +func TestRegisterUserWithDisplayName(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.Global.ServerName = "server" + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + deviceName, deviceID := "deviceName", "deviceID" + expectedDisplayName := "DisplayName" + response := completeRegistration( + base.Context(), + userAPI, + "user", + "server", + expectedDisplayName, + "password", + "", + "localhost", + "user agent", + "session", + false, + &deviceName, + &deviceID, + api.AccountTypeAdmin, + ) + + assert.Equal(t, http.StatusOK, response.Code) + + req := api.QueryProfileRequest{UserID: "@user:server"} + var res api.QueryProfileResponse + err := userAPI.QueryProfile(base.Context(), &req, &res) + assert.NoError(t, err) + assert.Equal(t, expectedDisplayName, res.DisplayName) + }) +} + +func TestRegisterAdminUsingSharedSecret(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.Global.ServerName = "server" + sharedSecret := "dendritetest" + base.Cfg.ClientAPI.RegistrationSharedSecret = sharedSecret + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + + expectedDisplayName := "rabbit" + jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) + req, err := NewSharedSecretRegistrationRequest(io.NopCloser(bytes.NewBuffer(jsonStr))) + assert.NoError(t, err) + if err != nil { + t.Fatalf("failed to read request: %s", err) + } + + r := NewSharedSecretRegistration(sharedSecret) + + // force the nonce to be known + r.nonces.Set(req.Nonce, true, cache.DefaultExpiration) + + _, err = r.IsValidMacLogin(req.Nonce, req.User, req.Password, req.Admin, req.MacBytes) + assert.NoError(t, err) + + body := &bytes.Buffer{} + err = json.NewEncoder(body).Encode(req) + assert.NoError(t, err) + ssrr := httptest.NewRequest(http.MethodPost, "/", body) + + response := handleSharedSecretRegistration( + &base.Cfg.ClientAPI, + userAPI, + r, + ssrr, + ) + assert.Equal(t, http.StatusOK, response.Code) + + profilReq := api.QueryProfileRequest{UserID: "@alice:server"} + var profileRes api.QueryProfileResponse + err = userAPI.QueryProfile(base.Context(), &profilReq, &profileRes) + assert.NoError(t, err) + assert.Equal(t, expectedDisplayName, profileRes.DisplayName) + }) +}