diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/login.go b/src/github.com/matrix-org/dendrite/clientapi/routing/login.go index e0a4e6327..3804da47e 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/login.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/login.go @@ -16,13 +16,13 @@ package routing import ( "net/http" - "strings" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -82,24 +82,11 @@ func Login( util.GetLogger(req.Context()).WithField("user", r.User).Info("Processing login request") - // r.User can either be a user ID or just the localpart... or other things maybe. - localpart := r.User - if strings.HasPrefix(r.User, "@") { - var domain gomatrixserverlib.ServerName - var err error - localpart, domain, err = gomatrixserverlib.SplitID('@', r.User) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Invalid username"), - } - } - - if domain != cfg.Matrix.ServerName { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("User ID not ours"), - } + localpart, err := userutil.ParseUsernameParam(r.User, &cfg.Matrix.ServerName) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(err.Error()), } } diff --git a/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil.go b/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil.go new file mode 100644 index 000000000..de2d1959f --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil.go @@ -0,0 +1,43 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userutil + +import ( + "errors" + "strings" + + "github.com/matrix-org/gomatrixserverlib" +) + +// ParseUsernameParam extracts localpart from usernameParam. +// usernameParam can either be a user ID or just the localpart/username. +// If serverName is passed, it is verified against the domain obtained from usernameParam (if present) +// Returns error in case of invalid usernameParam. +func ParseUsernameParam(usernameParam string, expectedServerName *gomatrixserverlib.ServerName) (string, error) { + localpart := usernameParam + + if strings.HasPrefix(usernameParam, "@") { + lp, domain, err := gomatrixserverlib.SplitID('@', usernameParam) + + if err != nil { + return "", errors.New("Invalid username") + } + + if expectedServerName != nil && domain != *expectedServerName { + return "", errors.New("User ID does not belong to this server") + } + + localpart = lp + } + return localpart, nil +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil_test.go b/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil_test.go new file mode 100644 index 000000000..2628642fb --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil_test.go @@ -0,0 +1,71 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userutil + +import ( + "testing" + + "github.com/matrix-org/gomatrixserverlib" +) + +var ( + localpart = "somelocalpart" + serverName gomatrixserverlib.ServerName = "someservername" + invalidServerName gomatrixserverlib.ServerName = "invalidservername" + goodUserID = "@" + localpart + ":" + string(serverName) + badUserID = "@bad:user:name@noservername:" +) + +// TestGoodUserID checks that correct localpart is returned for a valid user ID. +func TestGoodUserID(t *testing.T) { + lp, err := ParseUsernameParam(goodUserID, &serverName) + + if err != nil { + t.Error("User ID Parsing failed for ", goodUserID, " with error: ", err.Error()) + } + + if lp != localpart { + t.Error("Incorrect username, returned: ", lp, " should be: ", localpart) + } +} + +// TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. +func TestWithLocalpartOnly(t *testing.T) { + lp, err := ParseUsernameParam(localpart, &serverName) + + if err != nil { + t.Error("User ID Parsing failed for ", localpart, " with error: ", err.Error()) + } + + if lp != localpart { + t.Error("Incorrect username, returned: ", lp, " should be: ", localpart) + } +} + +// TestIncorrectDomain checks for error when there's server name mismatch. +func TestIncorrectDomain(t *testing.T) { + _, err := ParseUsernameParam(goodUserID, &invalidServerName) + + if err == nil { + t.Error("Invalid Domain should return an error") + } +} + +// TestBadUserID checks that ParseUsernameParam fails for invalid user ID +func TestBadUserID(t *testing.T) { + _, err := ParseUsernameParam(badUserID, &serverName) + + if err == nil { + t.Error("Illegal User ID should return an error") + } +}