[codegen/hcl2] Fix descent in RewriteConversions. (#4614)

Descend into anonymous function expressions, the operands to conditional
expressions, and the value in for expressions.
This commit is contained in:
Pat Gavlin 2020-05-13 08:25:26 -07:00 committed by GitHub
parent 71910a1c52
commit 647b6627a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 0 deletions

View file

@ -40,18 +40,35 @@ func sameSchemaTypes(xt, yt model.Type) bool {
func RewriteConversions(x model.Expression, to model.Type) model.Expression {
switch x := x.(type) {
case *model.AnonymousFunctionExpression:
x.Body = RewriteConversions(x.Body, to)
case *model.BinaryOpExpression:
x.LeftOperand = RewriteConversions(x.LeftOperand, model.InputType(x.LeftOperandType()))
x.RightOperand = RewriteConversions(x.RightOperand, model.InputType(x.RightOperandType()))
case *model.ConditionalExpression:
x.Condition = RewriteConversions(x.Condition, model.InputType(model.BoolType))
x.TrueResult = RewriteConversions(x.TrueResult, to)
x.FalseResult = RewriteConversions(x.FalseResult, to)
diags := x.Typecheck(false)
contract.Assert(len(diags) == 0)
case *model.ForExpression:
traverserType := model.NumberType
if x.Key != nil {
traverserType = model.StringType
x.Key = RewriteConversions(x.Key, model.InputType(model.StringType))
}
if x.Condition != nil {
x.Condition = RewriteConversions(x.Condition, model.InputType(model.BoolType))
}
valueType, diags := to.Traverse(model.MakeTraverser(traverserType))
contract.Ignore(diags)
x.Value = RewriteConversions(x.Value, valueType.(model.Type))
diags = x.Typecheck(false)
contract.Assert(len(diags) == 0)
case *model.FunctionCallExpression:
args := x.Args
for _, param := range x.Signature.Parameters {

View file

@ -48,9 +48,51 @@ func TestRewriteConversions(t *testing.T) {
"a": model.StringType,
}, &schema.ObjectType{})),
},
{
input: `{a: "1" + 2}`,
output: `{a: __convert( "1") + 2}`,
to: model.NewObjectType(map[string]model.Type{
"a": model.NumberType,
}),
},
{
input: `[{a: "b"}]`,
output: "[\n __convert({a: \"b\"})]",
to: model.NewListType(model.NewObjectType(map[string]model.Type{
"a": model.StringType,
}, &schema.ObjectType{})),
},
{
input: `[for v in ["b"]: {a: v}]`,
output: `[for v in ["b"]: __convert( {a: v})]`,
to: model.NewListType(model.NewObjectType(map[string]model.Type{
"a": model.StringType,
}, &schema.ObjectType{})),
},
{
input: `true ? {a: "b"} : {a: "c"}`,
output: `true ? __convert( {a: "b"}) : __convert( {a: "c"})`,
to: model.NewObjectType(map[string]model.Type{
"a": model.StringType,
}, &schema.ObjectType{}),
},
{
input: `!"true"`,
output: `!__convert("true")`,
to: model.BoolType,
},
{
input: `["a"][i]`,
output: `["a"][__convert(i)]`,
to: model.StringType,
},
}
scope := model.NewRootScope(syntax.None)
scope.Define("i", &model.Variable{
Name: "i",
VariableType: model.StringType,
})
for _, c := range cases {
expr, diags := model.BindExpressionText(c.input, scope, hcl.Pos{})
assert.Len(t, diags, 0)
@ -63,3 +105,40 @@ func TestRewriteConversions(t *testing.T) {
assert.Equal(t, c.output, fmt.Sprintf("%v", expr))
}
}
func TestRewriteConversionsAfterApply(t *testing.T) {
cases := []struct {
input, output string
}{
{
input: `f({id: v.id})`,
output: `__apply(v,eval(v, f(__convert({id: v.id}))))`,
},
}
scope := model.NewRootScope(syntax.None)
scope.DefineFunction("f", model.NewFunction(model.StaticFunctionSignature{
Parameters: []model.Parameter{{
Name: "args",
Type: model.NewObjectType(map[string]model.Type{
"id": model.StringType,
}, &schema.ObjectType{}),
}},
ReturnType: model.DynamicType,
}))
scope.Define("v", &model.Variable{
Name: "v",
VariableType: model.NewOutputType(model.NewObjectType(map[string]model.Type{
"id": model.StringType,
})),
})
for _, c := range cases {
expr, diags := model.BindExpressionText(c.input, scope, hcl.Pos{})
assert.Len(t, diags, 0)
expr, _ = RewriteApplies(expr, nameInfo(0), false)
expr = RewriteConversions(expr, expr.Type())
assert.Equal(t, c.output, fmt.Sprintf("%v", expr))
}
}