diff --git a/build/scripts/find-lint.sh b/build/scripts/find-lint.sh index af87e14d7..e3564ae38 100755 --- a/build/scripts/find-lint.sh +++ b/build/scripts/find-lint.sh @@ -33,7 +33,7 @@ echo "Looking for lint..." # Capture exit code to ensure go.{mod,sum} is restored before exiting exit_code=0 -golangci-lint run $args || exit_code=1 +PATH="$PATH:${GOPATH:-~/go}/bin" golangci-lint run $args || exit_code=1 # Restore go.{mod,sum} mv go.mod.bak go.mod && mv go.sum.bak go.sum diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index c850bf91e..575c5377f 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -42,6 +42,7 @@ type DeviceDatabase interface { type AccountDatabase interface { // Look up the account matching the given localpart. GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error) } // VerifyUserFromRequest authenticates the HTTP request, diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index da0324251..f01e48f80 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -10,4 +10,5 @@ const ( LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" + LoginTypeToken = "m.login.token" ) diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go new file mode 100644 index 000000000..1c14c6fbd --- /dev/null +++ b/clientapi/auth/login.go @@ -0,0 +1,83 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "encoding/json" + "io" + "io/ioutil" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// LoginFromJSONReader performs authentication given a login request body reader and +// some context. It returns the basic login information and a cleanup function to be +// called after authorization has completed, with the result of the authorization. +// If the final return value is non-nil, an error occurred and the cleanup function +// is nil. +func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountDatabase, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { + reqBytes, err := ioutil.ReadAll(r) + if err != nil { + err := &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + } + return nil, nil, err + } + + var header struct { + Type string `json:"type"` + } + if err := json.Unmarshal(reqBytes, &header); err != nil { + err := &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + } + return nil, nil, err + } + + var typ Type + switch header.Type { + case authtypes.LoginTypePassword: + typ = &LoginTypePassword{ + GetAccountByPassword: accountDB.GetAccountByPassword, + Config: cfg, + } + case authtypes.LoginTypeToken: + typ = &LoginTypeToken{ + UserAPI: userAPI, + Config: cfg, + } + default: + err := util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unhandled login type: " + header.Type), + } + return nil, nil, &err + } + + return typ.LoginFromJSON(ctx, reqBytes) +} + +// UserInternalAPIForLogin contains the aspects of UserAPI required for logging in. +type UserInternalAPIForLogin interface { + uapi.LoginTokenInternalAPI +} diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go new file mode 100644 index 000000000..e295f8f07 --- /dev/null +++ b/clientapi/auth/login_test.go @@ -0,0 +1,194 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "database/sql" + "net/http" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +func TestLoginFromJSONReader(t *testing.T) { + ctx := context.Background() + + tsts := []struct { + Name string + Body string + + WantUsername string + WantDeviceID string + WantDeletedTokens []string + }{ + { + Name: "passwordWorks", + Body: `{ + "type": "m.login.password", + "identifier": { "type": "m.id.user", "user": "alice" }, + "password": "herpassword", + "device_id": "adevice" + }`, + WantUsername: "alice", + WantDeviceID: "adevice", + }, + { + Name: "tokenWorks", + Body: `{ + "type": "m.login.token", + "token": "atoken", + "device_id": "adevice" + }`, + WantUsername: "@auser:example.com", + WantDeviceID: "adevice", + WantDeletedTokens: []string{"atoken"}, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + var accountDB fakeAccountDB + var userAPI fakeUserInternalAPI + cfg := &config.ClientAPI{ + Matrix: &config.Global{ + ServerName: serverName, + }, + } + login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + if err != nil { + t.Fatalf("LoginFromJSONReader failed: %+v", err) + } + cleanup(ctx, &util.JSONResponse{Code: http.StatusOK}) + + if login.Username() != tst.WantUsername { + t.Errorf("Username: got %q, want %q", login.Username(), tst.WantUsername) + } + + if login.DeviceID == nil { + if tst.WantDeviceID != "" { + t.Errorf("DeviceID: got %v, want %q", login.DeviceID, tst.WantDeviceID) + } + } else { + if *login.DeviceID != tst.WantDeviceID { + t.Errorf("DeviceID: got %q, want %q", *login.DeviceID, tst.WantDeviceID) + } + } + + if !reflect.DeepEqual(userAPI.DeletedTokens, tst.WantDeletedTokens) { + t.Errorf("DeletedTokens: got %+v, want %+v", userAPI.DeletedTokens, tst.WantDeletedTokens) + } + }) + } +} + +func TestBadLoginFromJSONReader(t *testing.T) { + ctx := context.Background() + + tsts := []struct { + Name string + Body string + + WantErrCode string + }{ + {Name: "empty", WantErrCode: "M_BAD_JSON"}, + { + Name: "badUnmarshal", + Body: `badsyntaxJSON`, + WantErrCode: "M_BAD_JSON", + }, + { + Name: "badPassword", + Body: `{ + "type": "m.login.password", + "identifier": { "type": "m.id.user", "user": "alice" }, + "password": "invalidpassword", + "device_id": "adevice" + }`, + WantErrCode: "M_FORBIDDEN", + }, + { + Name: "badToken", + Body: `{ + "type": "m.login.token", + "token": "invalidtoken", + "device_id": "adevice" + }`, + WantErrCode: "M_FORBIDDEN", + }, + { + Name: "badType", + Body: `{ + "type": "m.login.invalid", + "device_id": "adevice" + }`, + WantErrCode: "M_INVALID_ARGUMENT_VALUE", + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + var accountDB fakeAccountDB + var userAPI fakeUserInternalAPI + cfg := &config.ClientAPI{ + Matrix: &config.Global{ + ServerName: serverName, + }, + } + _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + if errRes == nil { + cleanup(ctx, nil) + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) + } else if merr, ok := errRes.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != tst.WantErrCode { + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) + } + }) + } +} + +type fakeAccountDB struct { + AccountDatabase +} + +func (*fakeAccountDB) GetAccountByPassword(ctx context.Context, localpart, password string) (*uapi.Account, error) { + if password == "invalidpassword" { + return nil, sql.ErrNoRows + } + + return &uapi.Account{}, nil +} + +type fakeUserInternalAPI struct { + UserInternalAPIForLogin + + DeletedTokens []string +} + +func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *uapi.PerformLoginTokenDeletionRequest, res *uapi.PerformLoginTokenDeletionResponse) error { + ua.DeletedTokens = append(ua.DeletedTokens, req.Token) + return nil +} + +func (*fakeUserInternalAPI) QueryLoginToken(ctx context.Context, req *uapi.QueryLoginTokenRequest, res *uapi.QueryLoginTokenResponse) error { + if req.Token == "invalidtoken" { + return nil + } + + res.Data = &uapi.LoginTokenData{UserID: "@auser:example.com"} + return nil +} diff --git a/clientapi/auth/login_token.go b/clientapi/auth/login_token.go new file mode 100644 index 000000000..845eb5de9 --- /dev/null +++ b/clientapi/auth/login_token.go @@ -0,0 +1,83 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// LoginTypeToken describes how to authenticate with a login token. +type LoginTypeToken struct { + UserAPI uapi.LoginTokenInternalAPI + Config *config.ClientAPI +} + +// Name implements Type. +func (t *LoginTypeToken) Name() string { + return authtypes.LoginTypeToken +} + +// LoginFromJSON implements Type. The cleanup function deletes the token from +// the database on success. +func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r loginTokenRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + var res uapi.QueryLoginTokenResponse + if err := t.UserAPI.QueryLoginToken(ctx, &uapi.QueryLoginTokenRequest{Token: r.Token}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("UserAPI.QueryLoginToken failed") + jsonErr := jsonerror.InternalServerError() + return nil, nil, &jsonErr + } + if res.Data == nil { + return nil, nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("invalid login token"), + } + } + + r.Login.Identifier.Type = "m.id.user" + r.Login.Identifier.User = res.Data.UserID + + cleanup := func(ctx context.Context, authRes *util.JSONResponse) { + if authRes == nil { + util.GetLogger(ctx).Error("No JSONResponse provided to LoginTokenType cleanup function") + return + } + if authRes.Code == http.StatusOK { + var res uapi.PerformLoginTokenDeletionResponse + if err := t.UserAPI.PerformLoginTokenDeletion(ctx, &uapi.PerformLoginTokenDeletionRequest{Token: r.Token}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("UserAPI.PerformLoginTokenDeletion failed") + } + } + } + return &r.Login, cleanup, nil +} + +// loginTokenRequest struct to hold the possible parameters from an HTTP request. +type loginTokenRequest struct { + Login + Token string `json:"token"` +} diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 9179d8da1..18cf94979 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -20,6 +20,8 @@ import ( "net/http" "strings" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -41,16 +43,26 @@ type LoginTypePassword struct { } func (t *LoginTypePassword) Name() string { - return "m.login.password" + return authtypes.LoginTypePassword } -func (t *LoginTypePassword) Request() interface{} { - return &PasswordRequest{} +func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r PasswordRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + login, err := t.Login(ctx, &r) + if err != nil { + return nil, nil, err + } + + return login, func(context.Context, *util.JSONResponse) {}, nil } func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { r := req.(*PasswordRequest) - username := r.Username() + username := strings.ToLower(r.Username()) if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 30469fc47..9cab7956c 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -32,22 +32,24 @@ import ( type Type interface { // Name returns the name of the auth type e.g `m.login.password` Name() string - // Request returns a pointer to a new request body struct to unmarshal into. - Request() interface{} // Login with the auth type, returning an error response on failure. // Not all types support login, only m.login.password and m.login.token // See https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login - // `req` is guaranteed to be the type returned from Request() // This function will be called when doing login and when doing 'sudo' style // actions e.g deleting devices. The response must be a 401 as per: // "If the homeserver decides that an attempt on a stage was unsuccessful, but the // client may make a second attempt, it returns the same HTTP status 401 response as above, // with the addition of the standard errcode and error fields describing the error." - Login(ctx context.Context, req interface{}) (login *Login, errRes *util.JSONResponse) + // + // The returned cleanup function must be non-nil on success, and will be called after + // authorization has been completed. Its argument is the final result of authorization. + LoginFromJSON(ctx context.Context, reqBytes []byte) (login *Login, cleanup LoginCleanupFunc, errRes *util.JSONResponse) // TODO: Extend to support Register() flow // Register(ctx context.Context, sessionID string, req interface{}) } +type LoginCleanupFunc func(context.Context, *util.JSONResponse) + // LoginIdentifier represents identifier types // https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types type LoginIdentifier struct { @@ -61,11 +63,8 @@ type LoginIdentifier struct { // Login represents the shared fields used in all forms of login/sudo endpoints. type Login struct { - Type string `json:"type"` - Identifier LoginIdentifier `json:"identifier"` - User string `json:"user"` // deprecated in favour of identifier - Medium string `json:"medium"` // deprecated in favour of identifier - Address string `json:"address"` // deprecated in favour of identifier + LoginIdentifier // Flat fields deprecated in favour of `identifier`. + Identifier LoginIdentifier `json:"identifier"` // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") // Thus a pointer is needed to differentiate between the two @@ -111,12 +110,11 @@ type UserInteractive struct { Sessions map[string][]string } -func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI) *UserInteractive { +func NewUserInteractive(accountDB AccountDatabase, cfg *config.ClientAPI) *UserInteractive { typePassword := &LoginTypePassword{ - GetAccountByPassword: getAccByPass, + GetAccountByPassword: accountDB.GetAccountByPassword, Config: cfg, } - // TODO: Add SSO login return &UserInteractive{ Completed: []string{}, Flows: []userInteractiveFlow{ @@ -236,18 +234,13 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device * } } - r := loginType.Request() - if err := json.Unmarshal([]byte(gjson.GetBytes(bodyBytes, "auth").Raw), r); err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), - } + login, cleanup, resErr := loginType.LoginFromJSON(ctx, []byte(gjson.GetBytes(bodyBytes, "auth").Raw)) + if resErr != nil { + return nil, u.ResponseWithChallenge(sessionID, resErr.JSON) } - login, resErr := loginType.Login(ctx, r) - if resErr == nil { - u.AddCompletedStage(sessionID, authType) - // TODO: Check if there's more stages to go and return an error - return login, nil - } - return nil, u.ResponseWithChallenge(sessionID, resErr.JSON) + + u.AddCompletedStage(sessionID, authType) + cleanup(ctx, nil) + // TODO: Check if there's more stages to go and return an error + return login, nil } diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 0b7df3545..76d161a74 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -24,7 +24,11 @@ var ( } ) -func getAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { +type fakeAccountDatabase struct { + AccountDatabase +} + +func (*fakeAccountDatabase) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { acc, ok := lookup[localpart+" "+plaintextPassword] if !ok { return nil, fmt.Errorf("unknown user/password") @@ -38,7 +42,7 @@ func setup() *UserInteractive { ServerName: serverName, }, } - return NewUserInteractive(getAccountByPassword, cfg) + return NewUserInteractive(&fakeAccountDatabase{}, cfg) } func TestUserInteractiveChallenge(t *testing.T) { diff --git a/clientapi/httputil/httputil.go b/clientapi/httputil/httputil.go index 29d7b0b37..b47701368 100644 --- a/clientapi/httputil/httputil.go +++ b/clientapi/httputil/httputil.go @@ -36,6 +36,10 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon return &resp } + return UnmarshalJSON(body, iface) +} + +func UnmarshalJSON(body []byte, iface interface{}) *util.JSONResponse { if !utf8.Valid(body) { return &util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 589efe0b2..b48b9e93b 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -19,7 +19,6 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -65,21 +64,14 @@ func Login( JSON: passwordLogin(), } } else if req.Method == http.MethodPost { - typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountDB.GetAccountByPassword, - Config: cfg, - } - r := typePassword.Request() - resErr := httputil.UnmarshalJSONRequest(req, r) - if resErr != nil { - return *resErr - } - login, authErr := typePassword.Login(req.Context(), r) + login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, accountDB, userAPI, cfg) if authErr != nil { return *authErr } // make a device/access token - return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + authErr2 := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + cleanup(req.Context(), &authErr2) + return authErr2 } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 9263c66bb..732066166 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -62,7 +62,7 @@ func Setup( mscCfg *config.MSCs, ) { rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) - userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg) + userInteractiveAuth := auth.NewUserInteractive(accountDB, cfg) unstableFeatures := map[string]bool{ "org.matrix.e2e_cross_signing": true, diff --git a/userapi/api/api.go b/userapi/api/api.go index 04609659c..46a13d971 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -24,6 +24,8 @@ import ( // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { + LoginTokenInternalAPI + InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go new file mode 100644 index 000000000..f3aa037e4 --- /dev/null +++ b/userapi/api/api_logintoken.go @@ -0,0 +1,69 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "time" +) + +type LoginTokenInternalAPI interface { + // PerformLoginTokenCreation creates a new login token and associates it with the provided data. + PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error + + // PerformLoginTokenDeletion ensures the token doesn't exist. Success + // is returned even if the token didn't exist, or was already deleted. + PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error + + // QueryLoginToken returns the data associated with a login token. If + // the token is not valid, success is returned, but res.Data == nil. + QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error +} + +// LoginTokenData is the data that can be retrieved given a login token. This is +// provided by the calling code. +type LoginTokenData struct { + // UserID is the full mxid of the user. + UserID string +} + +// LoginTokenMetadata contains metadata created and maintained by the User API. +type LoginTokenMetadata struct { + Token string + Expiration time.Time +} + +type PerformLoginTokenCreationRequest struct { + Data LoginTokenData +} + +type PerformLoginTokenCreationResponse struct { + Metadata LoginTokenMetadata +} + +type PerformLoginTokenDeletionRequest struct { + Token string +} + +type PerformLoginTokenDeletionResponse struct{} + +type QueryLoginTokenRequest struct { + Token string +} + +type QueryLoginTokenResponse struct { + // Data is nil if the token was invalid. + Data *LoginTokenData +} diff --git a/userapi/api/api_trace_logintoken.go b/userapi/api/api_trace_logintoken.go new file mode 100644 index 000000000..e60dae594 --- /dev/null +++ b/userapi/api/api_trace_logintoken.go @@ -0,0 +1,39 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/util" +) + +func (t *UserInternalAPITrace) PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error { + err := t.Impl.PerformLoginTokenCreation(ctx, req, res) + util.GetLogger(ctx).Infof("PerformLoginTokenCreation req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error { + err := t.Impl.PerformLoginTokenDeletion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformLoginTokenDeletion req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error { + err := t.Impl.QueryLoginToken(ctx, req, res) + util.GetLogger(ctx).Infof("QueryLoginToken req=%+v res=%+v", js(req), js(res)) + return err +} diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go new file mode 100644 index 000000000..86ffc58f3 --- /dev/null +++ b/userapi/internal/api_logintoken.go @@ -0,0 +1,78 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// PerformLoginTokenCreation creates a new login token and associates it with the provided data. +func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *api.PerformLoginTokenCreationRequest, res *api.PerformLoginTokenCreationResponse) error { + util.GetLogger(ctx).WithField("user_id", req.Data.UserID).Info("PerformLoginTokenCreation") + _, domain, err := gomatrixserverlib.SplitID('@', req.Data.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) + } + tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data) + if err != nil { + return err + } + res.Metadata = *tokenMeta + return nil +} + +// PerformLoginTokenDeletion ensures the token doesn't exist. +func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { + util.GetLogger(ctx).Info("PerformLoginTokenDeletion") + return a.DeviceDB.RemoveLoginToken(ctx, req.Token) +} + +// QueryLoginToken returns the data associated with a login token. If +// the token is not valid, success is returned, but res.Data == nil. +func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { + tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token) + if err != nil { + res.Data = nil + if err == sql.ErrNoRows { + return nil + } + return err + } + localpart, domain, err := gomatrixserverlib.SplitID('@', tokenData.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) + } + if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil { + res.Data = nil + if err == sql.ErrNoRows { + return nil + } + return err + } + res.Data = tokenData + return nil +} diff --git a/userapi/inthttp/client_logintoken.go b/userapi/inthttp/client_logintoken.go new file mode 100644 index 000000000..366a97099 --- /dev/null +++ b/userapi/inthttp/client_logintoken.go @@ -0,0 +1,65 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/opentracing/opentracing-go" +) + +const ( + PerformLoginTokenCreationPath = "/userapi/performLoginTokenCreation" + PerformLoginTokenDeletionPath = "/userapi/performLoginTokenDeletion" + QueryLoginTokenPath = "/userapi/queryLoginToken" +) + +func (h *httpUserInternalAPI) PerformLoginTokenCreation( + ctx context.Context, + request *api.PerformLoginTokenCreationRequest, + response *api.PerformLoginTokenCreationResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformLoginTokenCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformLoginTokenDeletion( + ctx context.Context, + request *api.PerformLoginTokenDeletionRequest, + response *api.PerformLoginTokenDeletionResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformLoginTokenDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryLoginToken( + ctx context.Context, + request *api.QueryLoginTokenRequest, + response *api.QueryLoginTokenResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken") + defer span.Finish() + + apiURL := h.apiURL + QueryLoginTokenPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index ac05bcd09..d00ee042c 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -27,6 +27,8 @@ import ( // nolint: gocyclo func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { + addRoutesLoginToken(internalAPIMux, s) + internalAPIMux.Handle(PerformAccountCreationPath, httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { request := api.PerformAccountCreationRequest{} diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go new file mode 100644 index 000000000..1f2eb34b9 --- /dev/null +++ b/userapi/inthttp/server_logintoken.go @@ -0,0 +1,68 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// addRoutesLoginToken adds routes for all login token API calls. +func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { + internalAPIMux.Handle(PerformLoginTokenCreationPath, + httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformLoginTokenCreationRequest{} + response := api.PerformLoginTokenCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformLoginTokenDeletionPath, + httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformLoginTokenDeletionRequest{} + response := api.PerformLoginTokenDeletionResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryLoginTokenPath, + httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse { + request := api.QueryLoginTokenRequest{} + response := api.QueryLoginTokenResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 95fe99f33..8ff91cf1c 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -38,4 +38,15 @@ type Database interface { RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + + // CreateLoginToken generates a token, stores and returns it. The lifetime is + // determined by the loginTokenLifetime given to the Database constructor. + CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) + + // RemoveLoginToken removes the named token (and may clean up other expired tokens). + RemoveLoginToken(ctx context.Context, token string) error + + // GetLoginTokenDataByToken returns the data associated with the given token. + // May return sql.ErrNoRows. + GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) } diff --git a/userapi/storage/devices/postgres/logintoken_table.go b/userapi/storage/devices/postgres/logintoken_table.go new file mode 100644 index 000000000..f601fc7db --- /dev/null +++ b/userapi/storage/devices/postgres/logintoken_table.go @@ -0,0 +1,93 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectByTokenStmt *sql.Stmt +} + +// execSchema ensures tables and indices exist. +func (s *loginTokenStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS login_tokens ( + -- The random value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- When the token expires + token_expires_at TIMESTAMP NOT NULL, + + -- The mxid for this account + user_id TEXT NOT NULL +); + +-- This index allows efficient garbage collection of expired tokens. +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +`) + return err +} + +// prepare runs statement preparation. +func (s *loginTokenStatements) prepare(db *sql.DB) error { + return sqlutil.StatementList{ + {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, + {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, + {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + }.Prepare(db) +} + +// insert adds an already generated token to the database. +func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { + stmt := sqlutil.TxStmt(txn, s.insertStmt) + _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) + return err +} + +// deleteByToken removes the named token. +// +// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. +// The login_tokens_expiration_idx index should make that efficient. +func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { + stmt := sqlutil.TxStmt(txn, s.deleteStmt) + res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n > 1 { + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely additional expired token)", n, n-1) + } + return nil +} + +// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. +func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + var data api.LoginTokenData + err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + if err != nil { + return nil, err + } + + return &data, nil +} diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 485234331..fd9d513f1 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -19,6 +19,7 @@ import ( "crypto/rand" "database/sql" "encoding/base64" + "time" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -27,28 +28,38 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// The length of generated device IDs -var deviceIDByteLength = 6 +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) // Database represents a device database. type Database struct { - db *sql.DB - devices devicesStatements + db *sql.DB + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration } // NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } - d := devicesStatements{} + var d devicesStatements + var lt loginTokenStatements // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns if err = d.execSchema(db); err != nil { return nil, err } + if err = lt.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() deltas.LoadLastSeenTSIP(m) if err = m.RunDeltas(db, dbProperties); err != nil { @@ -58,8 +69,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.prepare(db, serverName); err != nil { return nil, err } + if err = lt.prepare(db); err != nil { + return nil, err + } - return &Database{db, d}, nil + return &Database{db, d, lt, loginTokenLifetime}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -210,3 +224,47 @@ func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) } + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/devices/sqlite3/logintoken_table.go b/userapi/storage/devices/sqlite3/logintoken_table.go new file mode 100644 index 000000000..75ef272f8 --- /dev/null +++ b/userapi/storage/devices/sqlite3/logintoken_table.go @@ -0,0 +1,93 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectByTokenStmt *sql.Stmt +} + +// execSchema ensures tables and indices exist. +func (s *loginTokenStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS login_tokens ( + -- The random value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- When the token expires + token_expires_at TIMESTAMP NOT NULL, + + -- The mxid for this account + user_id TEXT NOT NULL +); + +-- This index allows efficient garbage collection of expired tokens. +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +`) + return err +} + +// prepare runs statement preparation. +func (s *loginTokenStatements) prepare(db *sql.DB) error { + return sqlutil.StatementList{ + {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, + {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, + {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + }.Prepare(db) +} + +// insert adds an already generated token to the database. +func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { + stmt := sqlutil.TxStmt(txn, s.insertStmt) + _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) + return err +} + +// deleteByToken removes the named token. +// +// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. +// The login_tokens_expiration_idx index should make that efficient. +func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { + stmt := sqlutil.TxStmt(txn, s.deleteStmt) + res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n > 1 { + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely additional expired token)", n, n-1) + } + return nil +} + +// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. +func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + var data api.LoginTokenData + err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + if err != nil { + return nil, err + } + + return &data, nil +} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 538644837..6e90413be 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -19,6 +19,7 @@ import ( "crypto/rand" "database/sql" "encoding/base64" + "time" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -27,30 +28,41 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// The length of generated device IDs -var deviceIDByteLength = 6 +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + + loginTokenByteLength = 32 +) // Database represents a device database. type Database struct { - db *sql.DB - writer sqlutil.Writer - devices devicesStatements + db *sql.DB + writer sqlutil.Writer + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration } // NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } writer := sqlutil.NewExclusiveWriter() - d := devicesStatements{} + var d devicesStatements + var lt loginTokenStatements // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns if err = d.execSchema(db); err != nil { return nil, err } + if err = lt.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() deltas.LoadLastSeenTSIP(m) if err = m.RunDeltas(db, dbProperties); err != nil { @@ -59,7 +71,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.prepare(db, writer, serverName); err != nil { return nil, err } - return &Database{db, writer, d}, nil + if err = lt.prepare(db); err != nil { + return nil, err + } + return &Database{db, writer, d, lt, loginTokenLifetime}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -210,3 +225,47 @@ func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) } + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go index 3c2034300..15cf8150c 100644 --- a/userapi/storage/devices/storage.go +++ b/userapi/storage/devices/storage.go @@ -19,6 +19,7 @@ package devices import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" @@ -27,13 +28,14 @@ import ( ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (Database, error) { +// and sets postgres connection parameters. loginTokenLifetime determines how long a +// login token from CreateLoginToken is valid. +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName) + return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName) + return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go index f360f9857..3de7880b9 100644 --- a/userapi/storage/devices/storage_wasm.go +++ b/userapi/storage/devices/storage_wasm.go @@ -16,6 +16,7 @@ package devices import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" @@ -25,10 +26,11 @@ import ( func NewDatabase( dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, + loginTokenLifetime time.Duration, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName) + return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/userapi.go b/userapi/userapi.go index 74702020a..c7e1f6674 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -15,6 +15,8 @@ package userapi import ( + "time" + "github.com/gorilla/mux" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -26,6 +28,13 @@ import ( "github.com/sirupsen/logrus" ) +// defaultLoginTokenLifetime determines how old a valid token may be. +// +// NOTSPEC: The current spec says "SHOULD be limited to around five +// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. +// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). +const defaultLoginTokenLifetime = 2 * time.Minute + // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { @@ -37,11 +46,21 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { func NewInternalAPI( accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { - deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName) + deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } + return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI) +} + +func newInternalAPI( + accountDB accounts.Database, + deviceDB devices.Database, + cfg *config.UserAPI, + appServices []config.ApplicationService, + keyAPI keyapi.KeyInternalAPI, +) api.UserInternalAPI { return &internal.UserInternalAPI{ AccountDB: accountDB, DeviceDB: deviceDB, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 0141258e6..266f5ed58 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -1,4 +1,18 @@ -package userapi_test +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userapi import ( "context" @@ -6,15 +20,16 @@ import ( "net/http" "reflect" "testing" + "time" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" ) @@ -23,31 +38,41 @@ const ( serverName = gomatrixserverlib.ServerName("example.com") ) -func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) { - accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ - ConnectionString: "file::memory:", - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) +type apiTestOpts struct { + loginTokenLifetime time.Duration +} + +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) { + if opts.loginTokenLifetime == 0 { + opts.loginTokenLifetime = defaultLoginTokenLifetime + } + dbopts := &config.DatabaseOptions{ + ConnectionString: "file::memory:", + MaxOpenConnections: 1, + MaxIdleConnections: 1, + } + accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) if err != nil { t.Fatalf("failed to create account DB: %s", err) } + deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime) + if err != nil { + t.Fatalf("failed to create device DB: %s", err) + } + cfg := &config.UserAPI{ - DeviceDatabase: config.DatabaseOptions{ - ConnectionString: "file::memory:", - MaxOpenConnections: 1, - MaxIdleConnections: 1, - }, Matrix: &config.Global{ ServerName: serverName, }, } - return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB + return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - userAPI, accountDB := MustMakeInternalAPI(t) + userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") if err != nil { t.Fatalf("failed to make account: %s", err) @@ -106,7 +131,7 @@ func TestQueryProfile(t *testing.T) { t.Run("HTTP API", func(t *testing.T) { router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI) + AddInternalRoutes(router, userAPI) apiURL, cancel := test.ListenAndServe(t, router, false) defer cancel() httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) @@ -119,3 +144,115 @@ func TestQueryProfile(t *testing.T) { runCases(userAPI) }) } + +func TestLoginToken(t *testing.T) { + ctx := context.Background() + + t.Run("tokenLoginFlow", func(t *testing.T) { + userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) + + _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "") + if err != nil { + t.Fatalf("failed to make account: %s", err) + } + + t.Log("Creating a login token like the SSO callback would...") + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + if cresp.Metadata.Token == "" { + t.Errorf("PerformLoginTokenCreation Token: got %q, want non-empty", cresp.Metadata.Token) + } + if cresp.Metadata.Expiration.Before(time.Now()) { + t.Errorf("PerformLoginTokenCreation Expiration: got %v, want non-expired", cresp.Metadata.Expiration) + } + + t.Log("Querying the login token like /login with m.login.token would...") + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data == nil { + t.Errorf("QueryLoginToken Data: got %v, want non-nil", qresp.Data) + } else if want := "@auser:example.com"; qresp.Data.UserID != want { + t.Errorf("QueryLoginToken UserID: got %q, want %q", qresp.Data.UserID, want) + } + + t.Log("Deleting the login token like /login with m.login.token would...") + + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) + + t.Run("expiredTokenIsNotReturned", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}) + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) + + t.Run("deleteWorks", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) + + t.Run("deleteUnknownIsNoOp", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) +}