[codegen/hcl2] Fix descent in RewriteConversions. (#4614)
Descend into anonymous function expressions, the operands to conditional expressions, and the value in for expressions.
This commit is contained in:
parent
71910a1c52
commit
647b6627a2
|
@ -40,18 +40,35 @@ func sameSchemaTypes(xt, yt model.Type) bool {
|
|||
|
||||
func RewriteConversions(x model.Expression, to model.Type) model.Expression {
|
||||
switch x := x.(type) {
|
||||
case *model.AnonymousFunctionExpression:
|
||||
x.Body = RewriteConversions(x.Body, to)
|
||||
case *model.BinaryOpExpression:
|
||||
x.LeftOperand = RewriteConversions(x.LeftOperand, model.InputType(x.LeftOperandType()))
|
||||
x.RightOperand = RewriteConversions(x.RightOperand, model.InputType(x.RightOperandType()))
|
||||
case *model.ConditionalExpression:
|
||||
x.Condition = RewriteConversions(x.Condition, model.InputType(model.BoolType))
|
||||
x.TrueResult = RewriteConversions(x.TrueResult, to)
|
||||
x.FalseResult = RewriteConversions(x.FalseResult, to)
|
||||
|
||||
diags := x.Typecheck(false)
|
||||
contract.Assert(len(diags) == 0)
|
||||
case *model.ForExpression:
|
||||
traverserType := model.NumberType
|
||||
if x.Key != nil {
|
||||
traverserType = model.StringType
|
||||
x.Key = RewriteConversions(x.Key, model.InputType(model.StringType))
|
||||
}
|
||||
if x.Condition != nil {
|
||||
x.Condition = RewriteConversions(x.Condition, model.InputType(model.BoolType))
|
||||
}
|
||||
|
||||
valueType, diags := to.Traverse(model.MakeTraverser(traverserType))
|
||||
contract.Ignore(diags)
|
||||
|
||||
x.Value = RewriteConversions(x.Value, valueType.(model.Type))
|
||||
|
||||
diags = x.Typecheck(false)
|
||||
contract.Assert(len(diags) == 0)
|
||||
case *model.FunctionCallExpression:
|
||||
args := x.Args
|
||||
for _, param := range x.Signature.Parameters {
|
||||
|
|
|
@ -48,9 +48,51 @@ func TestRewriteConversions(t *testing.T) {
|
|||
"a": model.StringType,
|
||||
}, &schema.ObjectType{})),
|
||||
},
|
||||
{
|
||||
input: `{a: "1" + 2}`,
|
||||
output: `{a: __convert( "1") + 2}`,
|
||||
to: model.NewObjectType(map[string]model.Type{
|
||||
"a": model.NumberType,
|
||||
}),
|
||||
},
|
||||
{
|
||||
input: `[{a: "b"}]`,
|
||||
output: "[\n __convert({a: \"b\"})]",
|
||||
to: model.NewListType(model.NewObjectType(map[string]model.Type{
|
||||
"a": model.StringType,
|
||||
}, &schema.ObjectType{})),
|
||||
},
|
||||
{
|
||||
input: `[for v in ["b"]: {a: v}]`,
|
||||
output: `[for v in ["b"]: __convert( {a: v})]`,
|
||||
to: model.NewListType(model.NewObjectType(map[string]model.Type{
|
||||
"a": model.StringType,
|
||||
}, &schema.ObjectType{})),
|
||||
},
|
||||
{
|
||||
input: `true ? {a: "b"} : {a: "c"}`,
|
||||
output: `true ? __convert( {a: "b"}) : __convert( {a: "c"})`,
|
||||
to: model.NewObjectType(map[string]model.Type{
|
||||
"a": model.StringType,
|
||||
}, &schema.ObjectType{}),
|
||||
},
|
||||
{
|
||||
input: `!"true"`,
|
||||
output: `!__convert("true")`,
|
||||
to: model.BoolType,
|
||||
},
|
||||
{
|
||||
input: `["a"][i]`,
|
||||
output: `["a"][__convert(i)]`,
|
||||
to: model.StringType,
|
||||
},
|
||||
}
|
||||
|
||||
scope := model.NewRootScope(syntax.None)
|
||||
scope.Define("i", &model.Variable{
|
||||
Name: "i",
|
||||
VariableType: model.StringType,
|
||||
})
|
||||
for _, c := range cases {
|
||||
expr, diags := model.BindExpressionText(c.input, scope, hcl.Pos{})
|
||||
assert.Len(t, diags, 0)
|
||||
|
@ -63,3 +105,40 @@ func TestRewriteConversions(t *testing.T) {
|
|||
assert.Equal(t, c.output, fmt.Sprintf("%v", expr))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteConversionsAfterApply(t *testing.T) {
|
||||
cases := []struct {
|
||||
input, output string
|
||||
}{
|
||||
{
|
||||
input: `f({id: v.id})`,
|
||||
output: `__apply(v,eval(v, f(__convert({id: v.id}))))`,
|
||||
},
|
||||
}
|
||||
|
||||
scope := model.NewRootScope(syntax.None)
|
||||
scope.DefineFunction("f", model.NewFunction(model.StaticFunctionSignature{
|
||||
Parameters: []model.Parameter{{
|
||||
Name: "args",
|
||||
Type: model.NewObjectType(map[string]model.Type{
|
||||
"id": model.StringType,
|
||||
}, &schema.ObjectType{}),
|
||||
}},
|
||||
ReturnType: model.DynamicType,
|
||||
}))
|
||||
scope.Define("v", &model.Variable{
|
||||
Name: "v",
|
||||
VariableType: model.NewOutputType(model.NewObjectType(map[string]model.Type{
|
||||
"id": model.StringType,
|
||||
})),
|
||||
})
|
||||
|
||||
for _, c := range cases {
|
||||
expr, diags := model.BindExpressionText(c.input, scope, hcl.Pos{})
|
||||
assert.Len(t, diags, 0)
|
||||
|
||||
expr, _ = RewriteApplies(expr, nameInfo(0), false)
|
||||
expr = RewriteConversions(expr, expr.Type())
|
||||
assert.Equal(t, c.output, fmt.Sprintf("%v", expr))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue