// @author Couchbase <info@couchbase.com> // @copyright 2018 Couchbase, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package scramsha provides implementation of client side SCRAM-SHA // via Http according to https://tools.ietf.org/html/rfc7804 package scramsha import ( "encoding/base64" "github.com/pkg/errors" "io" "io/ioutil" "net/http" "strings" ) // consts used to parse scramsha response from target const ( WWWAuthenticate = "WWW-Authenticate" AuthenticationInfo = "Authentication-Info" Authorization = "Authorization" DataPrefix = "data=" SidPrefix = "sid=" ) // Request provides implementation of http request that can be retried type Request struct { body io.ReadSeeker // Embed an HTTP request directly. This makes a *Request act exactly // like an *http.Request so that all meta methods are supported. *http.Request } type lenReader interface { Len() int } // NewRequest creates http request that can be retried func NewRequest(method, url string, body io.ReadSeeker) (*Request, error) { // Wrap the body in a noop ReadCloser if non-nil. This prevents the // reader from being closed by the HTTP client. var rcBody io.ReadCloser if body != nil { rcBody = ioutil.NopCloser(body) } // Make the request with the noop-closer for the body. httpReq, err := http.NewRequest(method, url, rcBody) if err != nil { return nil, err } // Check if we can set the Content-Length automatically. if lr, ok := body.(lenReader); ok { httpReq.ContentLength = int64(lr.Len()) } return &Request{body, httpReq}, nil } func encode(str string) string { return base64.StdEncoding.EncodeToString([]byte(str)) } func decode(str string) (string, error) { bytes, err := base64.StdEncoding.DecodeString(str) if err != nil { return "", errors.Errorf("Cannot base64 decode %s", str) } return string(bytes), err } func trimPrefix(s, prefix string) (string, error) { l := len(s) trimmed := strings.TrimPrefix(s, prefix) if l == len(trimmed) { return trimmed, errors.Errorf("Prefix %s not found in %s", prefix, s) } return trimmed, nil } func drainBody(resp *http.Response) { defer resp.Body.Close() io.Copy(ioutil.Discard, resp.Body) } // DoScramSha performs SCRAM-SHA handshake via Http func DoScramSha(req *Request, username string, password string, client *http.Client) (*http.Response, error) { method := "SCRAM-SHA-512" s, err := NewScramSha("SCRAM-SHA512") if err != nil { return nil, errors.Wrap(err, "Unable to initialize SCRAM-SHA handler") } message, err := s.GetStartRequest(username) if err != nil { return nil, err } encodedMessage := method + " " + DataPrefix + encode(message) req.Header.Set(Authorization, encodedMessage) res, err := client.Do(req.Request) if err != nil { return nil, errors.Wrap(err, "Problem sending SCRAM-SHA start"+ "request") } if res.StatusCode != http.StatusUnauthorized { return res, nil } authHeader := res.Header.Get(WWWAuthenticate) if authHeader == "" { drainBody(res) return nil, errors.Errorf("Header %s is not populated in "+ "SCRAM-SHA start response", WWWAuthenticate) } authHeader, err = trimPrefix(authHeader, method+" ") if err != nil { if strings.HasPrefix(authHeader, "Basic ") { // user not found return res, nil } drainBody(res) return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+ "start response %s", authHeader) } drainBody(res) sid, response, err := parseSidAndData(authHeader) if err != nil { return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+ "start response %s", authHeader) } err = s.HandleStartResponse(response) if err != nil { return nil, errors.Wrapf(err, "Error parsing SCRAM-SHA start "+ "response %s", response) } message = s.GetFinalRequest(password) encodedMessage = method + " " + SidPrefix + sid + "," + DataPrefix + encode(message) req.Header.Set(Authorization, encodedMessage) // rewind request body so it can be resent again if req.body != nil { if _, err = req.body.Seek(0, 0); err != nil { return nil, errors.Errorf("Failed to seek body: %v", err) } } res, err = client.Do(req.Request) if err != nil { return nil, errors.Wrap(err, "Problem sending SCRAM-SHA final"+ "request") } if res.StatusCode == http.StatusUnauthorized { // TODO retrieve and return error return res, nil } if res.StatusCode >= http.StatusInternalServerError { // in this case we cannot expect server to set headers properly return res, nil } authHeader = res.Header.Get(AuthenticationInfo) if authHeader == "" { drainBody(res) return nil, errors.Errorf("Header %s is not populated in "+ "SCRAM-SHA final response", AuthenticationInfo) } finalSid, response, err := parseSidAndData(authHeader) if err != nil { drainBody(res) return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+ "final response %s", authHeader) } if finalSid != sid { drainBody(res) return nil, errors.Errorf("Sid %s returned by server "+ "doesn't match the original sid %s", finalSid, sid) } err = s.HandleFinalResponse(response) if err != nil { drainBody(res) return nil, errors.Wrapf(err, "Error handling SCRAM-SHA final server response %s", response) } return res, nil } func parseSidAndData(authHeader string) (string, string, error) { sidIndex := strings.Index(authHeader, SidPrefix) if sidIndex < 0 { return "", "", errors.Errorf("Cannot find %s in %s", SidPrefix, authHeader) } sidEndIndex := strings.Index(authHeader, ",") if sidEndIndex < 0 { return "", "", errors.Errorf("Cannot find ',' in %s", authHeader) } sid := authHeader[sidIndex+len(SidPrefix) : sidEndIndex] dataIndex := strings.Index(authHeader, DataPrefix) if dataIndex < 0 { return "", "", errors.Errorf("Cannot find %s in %s", DataPrefix, authHeader) } data, err := decode(authHeader[dataIndex+len(DataPrefix):]) if err != nil { return "", "", err } return sid, data, nil }