support implicit flow in web-identity.go example (#12600)

when a client secret is not provided,
automatically assume implicit flow
for authentication and invoke
relevant code accordingly.
This commit is contained in:
Harshavardhana 2021-06-30 07:43:04 -07:00 committed by GitHub
parent 4575291f8a
commit 3137dc2eb3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -20,6 +20,7 @@
package main
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
@ -108,9 +109,34 @@ func init() {
flag.IntVar(&port, "port", 8080, "Port")
}
func implicitFlowURL(c oauth2.Config, state string) string {
var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL)
v := url.Values{
"response_type": {"id_token"},
"response_mode": {"form_post"},
"client_id": {c.ClientID},
}
if c.RedirectURL != "" {
v.Set("redirect_uri", c.RedirectURL)
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
v.Set("state", state)
v.Set("nonce", state)
if strings.Contains(c.Endpoint.AuthURL, "?") {
buf.WriteByte('&')
} else {
buf.WriteByte('?')
}
buf.WriteString(v.Encode())
return buf.String()
}
func main() {
flag.Parse()
if clientID == "" || clientSec == "" {
if clientID == "" {
flag.PrintDefaults()
return
}
@ -148,29 +174,47 @@ func main() {
http.NotFound(w, r)
return
}
http.Redirect(w, r, config.AuthCodeURL(state), http.StatusFound)
if clientSec != "" {
http.Redirect(w, r, config.AuthCodeURL(state), http.StatusFound)
} else {
http.Redirect(w, r, implicitFlowURL(config, state), http.StatusFound)
}
})
http.HandleFunc("/oauth2/callback", func(w http.ResponseWriter, r *http.Request) {
log.Printf("%s %s", r.Method, r.RequestURI)
if r.URL.Query().Get("state") != state {
if err := r.ParseForm(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if r.Form.Get("state") != state {
http.Error(w, "state did not match", http.StatusBadRequest)
return
}
getWebTokenExpiry := func() (*credentials.WebIdentityToken, error) {
oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code"))
if err != nil {
return nil, err
}
if !oauth2Token.Valid() {
return nil, errors.New("invalid token")
var getWebTokenExpiry func() (*credentials.WebIdentityToken, error)
if clientSec == "" {
getWebTokenExpiry = func() (*credentials.WebIdentityToken, error) {
return &credentials.WebIdentityToken{
Token: r.Form.Get("id_token"),
}, nil
}
} else {
getWebTokenExpiry = func() (*credentials.WebIdentityToken, error) {
oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code"))
if err != nil {
return nil, err
}
if !oauth2Token.Valid() {
return nil, errors.New("invalid token")
}
return &credentials.WebIdentityToken{
Token: oauth2Token.Extra("id_token").(string),
Expiry: int(oauth2Token.Expiry.Sub(time.Now().UTC()).Seconds()),
}, nil
return &credentials.WebIdentityToken{
Token: oauth2Token.Extra("id_token").(string),
Expiry: int(oauth2Token.Expiry.Sub(time.Now().UTC()).Seconds()),
}, nil
}
}
sts, err := credentials.NewSTSWebIdentity(stsEndpoint, getWebTokenExpiry)