Fix Type suffix Go codegen for cross-package references (#4110)

* Avoid adding "Type" suffix unnecessarily

* Fix `Type` suffix for cross-package references

* Fix mixxing imports and format code
This commit is contained in:
Luke Hoban 2020-03-19 08:32:40 -07:00 committed by GitHub
parent 984f8590f8
commit 65899569ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -22,6 +22,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"go/format"
"io"
"path"
"reflect"
@ -31,6 +32,7 @@ import (
"unicode"
"github.com/pkg/errors"
"github.com/pulumi/pulumi/pkg/codegen/schema"
"github.com/pulumi/pulumi/pkg/util/contract"
)
@ -96,6 +98,7 @@ type pkgContext struct {
functionNames map[*schema.Function]string
needsUtils bool
tool string
packages map[string]*pkgContext
// Name overrides set in GoInfo
modToPkg map[string]string // Module name -> package name
@ -123,14 +126,18 @@ func (pkg *pkgContext) tokenToType(tok string) string {
mod = override
}
// If the package containing the type's token already has a resource with the
// same name, add a `Type` suffix.
modPkg := pkg.getPkg(mod)
name = title(name)
if modPkg.names.has(name) {
name += "Type"
}
if mod == pkg.mod {
name := title(name)
if pkg.names.has(name) {
name += "Type"
}
return name
}
return strings.Replace(mod, "/", "", -1) + "." + title(name)
return strings.Replace(mod, "/", "", -1) + "." + name
}
func tokenToName(tok string) string {
@ -749,16 +756,16 @@ func (pkg *pkgContext) genTypeRegistrations(w io.Writer, types []*schema.ObjectT
fmt.Fprintf(w, "}\n")
}
func (pkg *pkgContext) getTypeImports(t schema.Type, recurse bool, imports stringSet) {
func (pkg *pkgContext) getTypeImports(t schema.Type, recurse bool, imports stringSet, seen map[schema.Type]struct{}) {
if _, ok := seen[t]; ok {
return
}
seen[t] = struct{}{}
switch t := t.(type) {
case *schema.ArrayType:
if recurse {
pkg.getTypeImports(t.ElementType, false, imports)
}
pkg.getTypeImports(t.ElementType, recurse, imports, seen)
case *schema.MapType:
if recurse {
pkg.getTypeImports(t.ElementType, false, imports)
}
pkg.getTypeImports(t.ElementType, recurse, imports, seen)
case *schema.ObjectType:
mod := pkg.pkg.TokenToModule(t.Token)
if override, ok := pkg.modToPkg[mod]; ok {
@ -770,28 +777,27 @@ func (pkg *pkgContext) getTypeImports(t schema.Type, recurse bool, imports strin
for _, p := range t.Properties {
if recurse {
pkg.getTypeImports(p.Type, false, imports)
pkg.getTypeImports(p.Type, recurse, imports, seen)
}
}
case *schema.UnionType:
for _, e := range t.ElementTypes {
if recurse {
pkg.getTypeImports(e, false, imports)
}
pkg.getTypeImports(e, recurse, imports, seen)
}
}
}
func (pkg *pkgContext) getImports(member interface{}, imports stringSet) {
seen := map[schema.Type]struct{}{}
switch member := member.(type) {
case *schema.ObjectType:
pkg.getTypeImports(member, true, imports)
pkg.getTypeImports(member, true, imports, seen)
case *schema.Resource:
for _, p := range member.Properties {
pkg.getTypeImports(p.Type, false, imports)
pkg.getTypeImports(p.Type, false, imports, seen)
}
for _, p := range member.InputProperties {
pkg.getTypeImports(p.Type, false, imports)
pkg.getTypeImports(p.Type, false, imports, seen)
if p.IsRequired {
imports.add("github.com/pkg/errors")
@ -799,14 +805,14 @@ func (pkg *pkgContext) getImports(member interface{}, imports stringSet) {
}
case *schema.Function:
if member.Inputs != nil {
pkg.getTypeImports(member.Inputs, false, imports)
pkg.getTypeImports(member.Inputs, false, imports, seen)
}
if member.Outputs != nil {
pkg.getTypeImports(member.Outputs, false, imports)
pkg.getTypeImports(member.Outputs, false, imports, seen)
}
case []*schema.Property:
for _, p := range member {
pkg.getTypeImports(p.Type, false, imports)
pkg.getTypeImports(p.Type, false, imports, seen)
}
default:
return
@ -918,6 +924,17 @@ func (pkg *pkgContext) genPackageRegistration(w io.Writer) {
fmt.Fprintf(w, "}\n")
}
func (pkg *pkgContext) getPkg(mod string) *pkgContext {
if override, ok := pkg.modToPkg[mod]; ok {
mod = override
}
pack, ok := pkg.packages[mod]
if !ok {
return nil
}
return pack
}
// GoInfo holds information required to generate the Go SDK from a schema.
type GoInfo struct {
// Base path for package imports
@ -965,6 +982,7 @@ func GeneratePackage(tool string, pkg *schema.Package) (map[string][]byte, error
tool: tool,
modToPkg: goInfo.ModuleToPackage,
pkgImportAliases: goInfo.PackageImportAliases,
packages: packages,
}
packages[mod] = pack
}
@ -1061,7 +1079,14 @@ func GeneratePackage(tool string, pkg *schema.Package) (map[string][]byte, error
if _, ok := files[relPath]; ok {
panic(errors.Errorf("duplicate file: %s", relPath))
}
files[relPath] = []byte(contents)
// Run Go formatter on the code before saving to disk
formattedSource, err := format.Source([]byte(contents))
if err != nil {
panic(errors.Errorf("invalid Go source code:\n\n%s", contents))
}
files[relPath] = formattedSource
}
name, registerPackage := pkg.Name, pkg.Provider != nil