Fix Type
suffix Go codegen for cross-package references (#4110)
* Avoid adding "Type" suffix unnecessarily * Fix `Type` suffix for cross-package references * Fix mixxing imports and format code
This commit is contained in:
parent
984f8590f8
commit
65899569ee
|
@ -22,6 +22,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io"
|
||||
"path"
|
||||
"reflect"
|
||||
|
@ -31,6 +32,7 @@ import (
|
|||
"unicode"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/pulumi/pulumi/pkg/codegen/schema"
|
||||
"github.com/pulumi/pulumi/pkg/util/contract"
|
||||
)
|
||||
|
@ -96,6 +98,7 @@ type pkgContext struct {
|
|||
functionNames map[*schema.Function]string
|
||||
needsUtils bool
|
||||
tool string
|
||||
packages map[string]*pkgContext
|
||||
|
||||
// Name overrides set in GoInfo
|
||||
modToPkg map[string]string // Module name -> package name
|
||||
|
@ -123,14 +126,18 @@ func (pkg *pkgContext) tokenToType(tok string) string {
|
|||
mod = override
|
||||
}
|
||||
|
||||
// If the package containing the type's token already has a resource with the
|
||||
// same name, add a `Type` suffix.
|
||||
modPkg := pkg.getPkg(mod)
|
||||
name = title(name)
|
||||
if modPkg.names.has(name) {
|
||||
name += "Type"
|
||||
}
|
||||
|
||||
if mod == pkg.mod {
|
||||
name := title(name)
|
||||
if pkg.names.has(name) {
|
||||
name += "Type"
|
||||
}
|
||||
return name
|
||||
}
|
||||
return strings.Replace(mod, "/", "", -1) + "." + title(name)
|
||||
return strings.Replace(mod, "/", "", -1) + "." + name
|
||||
}
|
||||
|
||||
func tokenToName(tok string) string {
|
||||
|
@ -749,16 +756,16 @@ func (pkg *pkgContext) genTypeRegistrations(w io.Writer, types []*schema.ObjectT
|
|||
fmt.Fprintf(w, "}\n")
|
||||
}
|
||||
|
||||
func (pkg *pkgContext) getTypeImports(t schema.Type, recurse bool, imports stringSet) {
|
||||
func (pkg *pkgContext) getTypeImports(t schema.Type, recurse bool, imports stringSet, seen map[schema.Type]struct{}) {
|
||||
if _, ok := seen[t]; ok {
|
||||
return
|
||||
}
|
||||
seen[t] = struct{}{}
|
||||
switch t := t.(type) {
|
||||
case *schema.ArrayType:
|
||||
if recurse {
|
||||
pkg.getTypeImports(t.ElementType, false, imports)
|
||||
}
|
||||
pkg.getTypeImports(t.ElementType, recurse, imports, seen)
|
||||
case *schema.MapType:
|
||||
if recurse {
|
||||
pkg.getTypeImports(t.ElementType, false, imports)
|
||||
}
|
||||
pkg.getTypeImports(t.ElementType, recurse, imports, seen)
|
||||
case *schema.ObjectType:
|
||||
mod := pkg.pkg.TokenToModule(t.Token)
|
||||
if override, ok := pkg.modToPkg[mod]; ok {
|
||||
|
@ -770,28 +777,27 @@ func (pkg *pkgContext) getTypeImports(t schema.Type, recurse bool, imports strin
|
|||
|
||||
for _, p := range t.Properties {
|
||||
if recurse {
|
||||
pkg.getTypeImports(p.Type, false, imports)
|
||||
pkg.getTypeImports(p.Type, recurse, imports, seen)
|
||||
}
|
||||
}
|
||||
case *schema.UnionType:
|
||||
for _, e := range t.ElementTypes {
|
||||
if recurse {
|
||||
pkg.getTypeImports(e, false, imports)
|
||||
}
|
||||
pkg.getTypeImports(e, recurse, imports, seen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pkg *pkgContext) getImports(member interface{}, imports stringSet) {
|
||||
seen := map[schema.Type]struct{}{}
|
||||
switch member := member.(type) {
|
||||
case *schema.ObjectType:
|
||||
pkg.getTypeImports(member, true, imports)
|
||||
pkg.getTypeImports(member, true, imports, seen)
|
||||
case *schema.Resource:
|
||||
for _, p := range member.Properties {
|
||||
pkg.getTypeImports(p.Type, false, imports)
|
||||
pkg.getTypeImports(p.Type, false, imports, seen)
|
||||
}
|
||||
for _, p := range member.InputProperties {
|
||||
pkg.getTypeImports(p.Type, false, imports)
|
||||
pkg.getTypeImports(p.Type, false, imports, seen)
|
||||
|
||||
if p.IsRequired {
|
||||
imports.add("github.com/pkg/errors")
|
||||
|
@ -799,14 +805,14 @@ func (pkg *pkgContext) getImports(member interface{}, imports stringSet) {
|
|||
}
|
||||
case *schema.Function:
|
||||
if member.Inputs != nil {
|
||||
pkg.getTypeImports(member.Inputs, false, imports)
|
||||
pkg.getTypeImports(member.Inputs, false, imports, seen)
|
||||
}
|
||||
if member.Outputs != nil {
|
||||
pkg.getTypeImports(member.Outputs, false, imports)
|
||||
pkg.getTypeImports(member.Outputs, false, imports, seen)
|
||||
}
|
||||
case []*schema.Property:
|
||||
for _, p := range member {
|
||||
pkg.getTypeImports(p.Type, false, imports)
|
||||
pkg.getTypeImports(p.Type, false, imports, seen)
|
||||
}
|
||||
default:
|
||||
return
|
||||
|
@ -918,6 +924,17 @@ func (pkg *pkgContext) genPackageRegistration(w io.Writer) {
|
|||
fmt.Fprintf(w, "}\n")
|
||||
}
|
||||
|
||||
func (pkg *pkgContext) getPkg(mod string) *pkgContext {
|
||||
if override, ok := pkg.modToPkg[mod]; ok {
|
||||
mod = override
|
||||
}
|
||||
pack, ok := pkg.packages[mod]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return pack
|
||||
}
|
||||
|
||||
// GoInfo holds information required to generate the Go SDK from a schema.
|
||||
type GoInfo struct {
|
||||
// Base path for package imports
|
||||
|
@ -965,6 +982,7 @@ func GeneratePackage(tool string, pkg *schema.Package) (map[string][]byte, error
|
|||
tool: tool,
|
||||
modToPkg: goInfo.ModuleToPackage,
|
||||
pkgImportAliases: goInfo.PackageImportAliases,
|
||||
packages: packages,
|
||||
}
|
||||
packages[mod] = pack
|
||||
}
|
||||
|
@ -1061,7 +1079,14 @@ func GeneratePackage(tool string, pkg *schema.Package) (map[string][]byte, error
|
|||
if _, ok := files[relPath]; ok {
|
||||
panic(errors.Errorf("duplicate file: %s", relPath))
|
||||
}
|
||||
files[relPath] = []byte(contents)
|
||||
|
||||
// Run Go formatter on the code before saving to disk
|
||||
formattedSource, err := format.Source([]byte(contents))
|
||||
if err != nil {
|
||||
panic(errors.Errorf("invalid Go source code:\n\n%s", contents))
|
||||
}
|
||||
|
||||
files[relPath] = formattedSource
|
||||
}
|
||||
|
||||
name, registerPackage := pkg.Name, pkg.Provider != nil
|
||||
|
|
Loading…
Reference in a new issue