| package hil |
| |
| import ( |
| "fmt" |
| "sync" |
| |
| "github.com/hashicorp/hil/ast" |
| ) |
| |
| // TypeCheck implements ast.Visitor for type checking an AST tree. |
| // It requires some configuration to look up the type of nodes. |
| // |
| // It also optionally will not type error and will insert an implicit |
| // type conversions for specific types if specified by the Implicit |
| // field. Note that this is kind of organizationally weird to put into |
| // this structure but we'd rather do that than duplicate the type checking |
| // logic multiple times. |
| type TypeCheck struct { |
| Scope ast.Scope |
| |
| // Implicit is a map of implicit type conversions that we can do, |
| // and that shouldn't error. The key of the first map is the from type, |
| // the key of the second map is the to type, and the final string |
| // value is the function to call (which must be registered in the Scope). |
| Implicit map[ast.Type]map[ast.Type]string |
| |
| // Stack of types. This shouldn't be used directly except by implementations |
| // of TypeCheckNode. |
| Stack []ast.Type |
| |
| err error |
| lock sync.Mutex |
| } |
| |
| // TypeCheckNode is the interface that must be implemented by any |
| // ast.Node that wants to support type-checking. If the type checker |
| // encounters a node that doesn't implement this, it will error. |
| type TypeCheckNode interface { |
| TypeCheck(*TypeCheck) (ast.Node, error) |
| } |
| |
| func (v *TypeCheck) Visit(root ast.Node) error { |
| v.lock.Lock() |
| defer v.lock.Unlock() |
| defer v.reset() |
| root.Accept(v.visit) |
| |
| // If the resulting type is unknown, then just let the whole thing go. |
| if v.err == errExitUnknown { |
| v.err = nil |
| } |
| |
| return v.err |
| } |
| |
| func (v *TypeCheck) visit(raw ast.Node) ast.Node { |
| if v.err != nil { |
| return raw |
| } |
| |
| var result ast.Node |
| var err error |
| switch n := raw.(type) { |
| case *ast.Arithmetic: |
| tc := &typeCheckArithmetic{n} |
| result, err = tc.TypeCheck(v) |
| case *ast.Call: |
| tc := &typeCheckCall{n} |
| result, err = tc.TypeCheck(v) |
| case *ast.Conditional: |
| tc := &typeCheckConditional{n} |
| result, err = tc.TypeCheck(v) |
| case *ast.Index: |
| tc := &typeCheckIndex{n} |
| result, err = tc.TypeCheck(v) |
| case *ast.Output: |
| tc := &typeCheckOutput{n} |
| result, err = tc.TypeCheck(v) |
| case *ast.LiteralNode: |
| tc := &typeCheckLiteral{n} |
| result, err = tc.TypeCheck(v) |
| case *ast.VariableAccess: |
| tc := &typeCheckVariableAccess{n} |
| result, err = tc.TypeCheck(v) |
| default: |
| tc, ok := raw.(TypeCheckNode) |
| if !ok { |
| err = fmt.Errorf("unknown node for type check: %#v", raw) |
| break |
| } |
| |
| result, err = tc.TypeCheck(v) |
| } |
| |
| if err != nil { |
| pos := raw.Pos() |
| v.err = fmt.Errorf("At column %d, line %d: %s", |
| pos.Column, pos.Line, err) |
| } |
| |
| return result |
| } |
| |
| type typeCheckArithmetic struct { |
| n *ast.Arithmetic |
| } |
| |
| func (tc *typeCheckArithmetic) TypeCheck(v *TypeCheck) (ast.Node, error) { |
| // The arguments are on the stack in reverse order, so pop them off. |
| exprs := make([]ast.Type, len(tc.n.Exprs)) |
| for i, _ := range tc.n.Exprs { |
| exprs[len(tc.n.Exprs)-1-i] = v.StackPop() |
| } |
| |
| // If any operand is unknown then our result is automatically unknown |
| for _, ty := range exprs { |
| if ty == ast.TypeUnknown { |
| v.StackPush(ast.TypeUnknown) |
| return tc.n, nil |
| } |
| } |
| |
| switch tc.n.Op { |
| case ast.ArithmeticOpLogicalAnd, ast.ArithmeticOpLogicalOr: |
| return tc.checkLogical(v, exprs) |
| case ast.ArithmeticOpEqual, ast.ArithmeticOpNotEqual, |
| ast.ArithmeticOpLessThan, ast.ArithmeticOpGreaterThan, |
| ast.ArithmeticOpGreaterThanOrEqual, ast.ArithmeticOpLessThanOrEqual: |
| return tc.checkComparison(v, exprs) |
| default: |
| return tc.checkNumeric(v, exprs) |
| } |
| |
| } |
| |
| func (tc *typeCheckArithmetic) checkNumeric(v *TypeCheck, exprs []ast.Type) (ast.Node, error) { |
| // Determine the resulting type we want. We do this by going over |
| // every expression until we find one with a type we recognize. |
| // We do this because the first expr might be a string ("var.foo") |
| // and we need to know what to implicit to. |
| mathFunc := "__builtin_IntMath" |
| mathType := ast.TypeInt |
| for _, v := range exprs { |
| // We assume int math but if we find ANY float, the entire |
| // expression turns into floating point math. |
| if v == ast.TypeFloat { |
| mathFunc = "__builtin_FloatMath" |
| mathType = v |
| break |
| } |
| } |
| |
| // Verify the args |
| for i, arg := range exprs { |
| if arg != mathType { |
| cn := v.ImplicitConversion(exprs[i], mathType, tc.n.Exprs[i]) |
| if cn != nil { |
| tc.n.Exprs[i] = cn |
| continue |
| } |
| |
| return nil, fmt.Errorf( |
| "operand %d should be %s, got %s", |
| i+1, mathType, arg) |
| } |
| } |
| |
| // Modulo doesn't work for floats |
| if mathType == ast.TypeFloat && tc.n.Op == ast.ArithmeticOpMod { |
| return nil, fmt.Errorf("modulo cannot be used with floats") |
| } |
| |
| // Return type |
| v.StackPush(mathType) |
| |
| // Replace our node with a call to the proper function. This isn't |
| // type checked but we already verified types. |
| args := make([]ast.Node, len(tc.n.Exprs)+1) |
| args[0] = &ast.LiteralNode{ |
| Value: tc.n.Op, |
| Typex: ast.TypeInt, |
| Posx: tc.n.Pos(), |
| } |
| copy(args[1:], tc.n.Exprs) |
| return &ast.Call{ |
| Func: mathFunc, |
| Args: args, |
| Posx: tc.n.Pos(), |
| }, nil |
| } |
| |
| func (tc *typeCheckArithmetic) checkComparison(v *TypeCheck, exprs []ast.Type) (ast.Node, error) { |
| if len(exprs) != 2 { |
| // This should never happen, because the parser never produces |
| // nodes that violate this. |
| return nil, fmt.Errorf( |
| "comparison operators must have exactly two operands", |
| ) |
| } |
| |
| // The first operand always dictates the type for a comparison. |
| compareFunc := "" |
| compareType := exprs[0] |
| switch compareType { |
| case ast.TypeBool: |
| compareFunc = "__builtin_BoolCompare" |
| case ast.TypeFloat: |
| compareFunc = "__builtin_FloatCompare" |
| case ast.TypeInt: |
| compareFunc = "__builtin_IntCompare" |
| case ast.TypeString: |
| compareFunc = "__builtin_StringCompare" |
| default: |
| return nil, fmt.Errorf( |
| "comparison operators apply only to bool, float, int, and string", |
| ) |
| } |
| |
| // For non-equality comparisons, we will do implicit conversions to |
| // integer types if possible. In this case, we need to go through and |
| // determine the type of comparison we're doing to enable the implicit |
| // conversion. |
| if tc.n.Op != ast.ArithmeticOpEqual && tc.n.Op != ast.ArithmeticOpNotEqual { |
| compareFunc = "__builtin_IntCompare" |
| compareType = ast.TypeInt |
| for _, expr := range exprs { |
| if expr == ast.TypeFloat { |
| compareFunc = "__builtin_FloatCompare" |
| compareType = ast.TypeFloat |
| break |
| } |
| } |
| } |
| |
| // Verify (and possibly, convert) the args |
| for i, arg := range exprs { |
| if arg != compareType { |
| cn := v.ImplicitConversion(exprs[i], compareType, tc.n.Exprs[i]) |
| if cn != nil { |
| tc.n.Exprs[i] = cn |
| continue |
| } |
| |
| return nil, fmt.Errorf( |
| "operand %d should be %s, got %s", |
| i+1, compareType, arg, |
| ) |
| } |
| } |
| |
| // Only ints and floats can have the <, >, <= and >= operators applied |
| switch tc.n.Op { |
| case ast.ArithmeticOpEqual, ast.ArithmeticOpNotEqual: |
| // anything goes |
| default: |
| switch compareType { |
| case ast.TypeFloat, ast.TypeInt: |
| // fine |
| default: |
| return nil, fmt.Errorf( |
| "<, >, <= and >= may apply only to int and float values", |
| ) |
| } |
| } |
| |
| // Comparison operators always return bool |
| v.StackPush(ast.TypeBool) |
| |
| // Replace our node with a call to the proper function. This isn't |
| // type checked but we already verified types. |
| args := make([]ast.Node, len(tc.n.Exprs)+1) |
| args[0] = &ast.LiteralNode{ |
| Value: tc.n.Op, |
| Typex: ast.TypeInt, |
| Posx: tc.n.Pos(), |
| } |
| copy(args[1:], tc.n.Exprs) |
| return &ast.Call{ |
| Func: compareFunc, |
| Args: args, |
| Posx: tc.n.Pos(), |
| }, nil |
| } |
| |
| func (tc *typeCheckArithmetic) checkLogical(v *TypeCheck, exprs []ast.Type) (ast.Node, error) { |
| for i, t := range exprs { |
| if t != ast.TypeBool { |
| cn := v.ImplicitConversion(t, ast.TypeBool, tc.n.Exprs[i]) |
| if cn == nil { |
| return nil, fmt.Errorf( |
| "logical operators require boolean operands, not %s", |
| t, |
| ) |
| } |
| tc.n.Exprs[i] = cn |
| } |
| } |
| |
| // Return type is always boolean |
| v.StackPush(ast.TypeBool) |
| |
| // Arithmetic nodes are replaced with a call to a built-in function |
| args := make([]ast.Node, len(tc.n.Exprs)+1) |
| args[0] = &ast.LiteralNode{ |
| Value: tc.n.Op, |
| Typex: ast.TypeInt, |
| Posx: tc.n.Pos(), |
| } |
| copy(args[1:], tc.n.Exprs) |
| return &ast.Call{ |
| Func: "__builtin_Logical", |
| Args: args, |
| Posx: tc.n.Pos(), |
| }, nil |
| } |
| |
| type typeCheckCall struct { |
| n *ast.Call |
| } |
| |
| func (tc *typeCheckCall) TypeCheck(v *TypeCheck) (ast.Node, error) { |
| // Look up the function in the map |
| function, ok := v.Scope.LookupFunc(tc.n.Func) |
| if !ok { |
| return nil, fmt.Errorf("unknown function called: %s", tc.n.Func) |
| } |
| |
| // The arguments are on the stack in reverse order, so pop them off. |
| args := make([]ast.Type, len(tc.n.Args)) |
| for i, _ := range tc.n.Args { |
| args[len(tc.n.Args)-1-i] = v.StackPop() |
| } |
| |
| // Verify the args |
| for i, expected := range function.ArgTypes { |
| if expected == ast.TypeAny { |
| continue |
| } |
| |
| if args[i] == ast.TypeUnknown { |
| v.StackPush(ast.TypeUnknown) |
| return tc.n, nil |
| } |
| |
| if args[i] != expected { |
| cn := v.ImplicitConversion(args[i], expected, tc.n.Args[i]) |
| if cn != nil { |
| tc.n.Args[i] = cn |
| continue |
| } |
| |
| return nil, fmt.Errorf( |
| "%s: argument %d should be %s, got %s", |
| tc.n.Func, i+1, expected.Printable(), args[i].Printable()) |
| } |
| } |
| |
| // If we're variadic, then verify the types there |
| if function.Variadic && function.VariadicType != ast.TypeAny { |
| args = args[len(function.ArgTypes):] |
| for i, t := range args { |
| if t == ast.TypeUnknown { |
| v.StackPush(ast.TypeUnknown) |
| return tc.n, nil |
| } |
| |
| if t != function.VariadicType { |
| realI := i + len(function.ArgTypes) |
| cn := v.ImplicitConversion( |
| t, function.VariadicType, tc.n.Args[realI]) |
| if cn != nil { |
| tc.n.Args[realI] = cn |
| continue |
| } |
| |
| return nil, fmt.Errorf( |
| "%s: argument %d should be %s, got %s", |
| tc.n.Func, realI, |
| function.VariadicType.Printable(), t.Printable()) |
| } |
| } |
| } |
| |
| // Return type |
| v.StackPush(function.ReturnType) |
| |
| return tc.n, nil |
| } |
| |
| type typeCheckConditional struct { |
| n *ast.Conditional |
| } |
| |
| func (tc *typeCheckConditional) TypeCheck(v *TypeCheck) (ast.Node, error) { |
| // On the stack we have the types of the condition, true and false |
| // expressions, but they are in reverse order. |
| falseType := v.StackPop() |
| trueType := v.StackPop() |
| condType := v.StackPop() |
| |
| if condType == ast.TypeUnknown { |
| v.StackPush(ast.TypeUnknown) |
| return tc.n, nil |
| } |
| |
| if condType != ast.TypeBool { |
| cn := v.ImplicitConversion(condType, ast.TypeBool, tc.n.CondExpr) |
| if cn == nil { |
| return nil, fmt.Errorf( |
| "condition must be type bool, not %s", condType.Printable(), |
| ) |
| } |
| tc.n.CondExpr = cn |
| } |
| |
| // The types of the true and false expression must match |
| if trueType != falseType && trueType != ast.TypeUnknown && falseType != ast.TypeUnknown { |
| |
| // Since passing around stringified versions of other types is |
| // common, we pragmatically allow the false expression to dictate |
| // the result type when the true expression is a string. |
| if trueType == ast.TypeString { |
| cn := v.ImplicitConversion(trueType, falseType, tc.n.TrueExpr) |
| if cn == nil { |
| return nil, fmt.Errorf( |
| "true and false expression types must match; have %s and %s", |
| trueType.Printable(), falseType.Printable(), |
| ) |
| } |
| tc.n.TrueExpr = cn |
| trueType = falseType |
| } else { |
| cn := v.ImplicitConversion(falseType, trueType, tc.n.FalseExpr) |
| if cn == nil { |
| return nil, fmt.Errorf( |
| "true and false expression types must match; have %s and %s", |
| trueType.Printable(), falseType.Printable(), |
| ) |
| } |
| tc.n.FalseExpr = cn |
| falseType = trueType |
| } |
| } |
| |
| // Currently list and map types cannot be used, because we cannot |
| // generally assert that their element types are consistent. |
| // Such support might be added later, either by improving the type |
| // system or restricting usage to only variable and literal expressions, |
| // but for now this is simply prohibited because it doesn't seem to |
| // be a common enough case to be worth the complexity. |
| switch trueType { |
| case ast.TypeList: |
| return nil, fmt.Errorf( |
| "conditional operator cannot be used with list values", |
| ) |
| case ast.TypeMap: |
| return nil, fmt.Errorf( |
| "conditional operator cannot be used with map values", |
| ) |
| } |
| |
| // Result type (guaranteed to also match falseType due to the above) |
| if trueType == ast.TypeUnknown { |
| // falseType may also be unknown, but that's okay because two |
| // unknowns means our result is unknown anyway. |
| v.StackPush(falseType) |
| } else { |
| v.StackPush(trueType) |
| } |
| |
| return tc.n, nil |
| } |
| |
| type typeCheckOutput struct { |
| n *ast.Output |
| } |
| |
| func (tc *typeCheckOutput) TypeCheck(v *TypeCheck) (ast.Node, error) { |
| n := tc.n |
| types := make([]ast.Type, len(n.Exprs)) |
| for i, _ := range n.Exprs { |
| types[len(n.Exprs)-1-i] = v.StackPop() |
| } |
| |
| for _, ty := range types { |
| if ty == ast.TypeUnknown { |
| v.StackPush(ast.TypeUnknown) |
| return tc.n, nil |
| } |
| } |
| |
| // If there is only one argument and it is a list, we evaluate to a list |
| if len(types) == 1 { |
| switch t := types[0]; t { |
| case ast.TypeList: |
| fallthrough |
| case ast.TypeMap: |
| v.StackPush(t) |
| return n, nil |
| } |
| } |
| |
| // Otherwise, all concat args must be strings, so validate that |
| resultType := ast.TypeString |
| for i, t := range types { |
| |
| if t == ast.TypeUnknown { |
| resultType = ast.TypeUnknown |
| continue |
| } |
| |
| if t != ast.TypeString { |
| cn := v.ImplicitConversion(t, ast.TypeString, n.Exprs[i]) |
| if cn != nil { |
| n.Exprs[i] = cn |
| continue |
| } |
| |
| return nil, fmt.Errorf( |
| "output of an HIL expression must be a string, or a single list (argument %d is %s)", i+1, t) |
| } |
| } |
| |
| // This always results in type string, unless there are unknowns |
| v.StackPush(resultType) |
| |
| return n, nil |
| } |
| |
| type typeCheckLiteral struct { |
| n *ast.LiteralNode |
| } |
| |
| func (tc *typeCheckLiteral) TypeCheck(v *TypeCheck) (ast.Node, error) { |
| v.StackPush(tc.n.Typex) |
| return tc.n, nil |
| } |
| |
| type typeCheckVariableAccess struct { |
| n *ast.VariableAccess |
| } |
| |
| func (tc *typeCheckVariableAccess) TypeCheck(v *TypeCheck) (ast.Node, error) { |
| // Look up the variable in the map |
| variable, ok := v.Scope.LookupVar(tc.n.Name) |
| if !ok { |
| return nil, fmt.Errorf( |
| "unknown variable accessed: %s", tc.n.Name) |
| } |
| |
| // Add the type to the stack |
| v.StackPush(variable.Type) |
| |
| return tc.n, nil |
| } |
| |
| type typeCheckIndex struct { |
| n *ast.Index |
| } |
| |
| func (tc *typeCheckIndex) TypeCheck(v *TypeCheck) (ast.Node, error) { |
| keyType := v.StackPop() |
| targetType := v.StackPop() |
| |
| if keyType == ast.TypeUnknown || targetType == ast.TypeUnknown { |
| v.StackPush(ast.TypeUnknown) |
| return tc.n, nil |
| } |
| |
| // Ensure we have a VariableAccess as the target |
| varAccessNode, ok := tc.n.Target.(*ast.VariableAccess) |
| if !ok { |
| return nil, fmt.Errorf( |
| "target of an index must be a VariableAccess node, was %T", tc.n.Target) |
| } |
| |
| // Get the variable |
| variable, ok := v.Scope.LookupVar(varAccessNode.Name) |
| if !ok { |
| return nil, fmt.Errorf( |
| "unknown variable accessed: %s", varAccessNode.Name) |
| } |
| |
| switch targetType { |
| case ast.TypeList: |
| if keyType != ast.TypeInt { |
| tc.n.Key = v.ImplicitConversion(keyType, ast.TypeInt, tc.n.Key) |
| if tc.n.Key == nil { |
| return nil, fmt.Errorf( |
| "key of an index must be an int, was %s", keyType) |
| } |
| } |
| |
| valType, err := ast.VariableListElementTypesAreHomogenous( |
| varAccessNode.Name, variable.Value.([]ast.Variable)) |
| if err != nil { |
| return tc.n, err |
| } |
| |
| v.StackPush(valType) |
| return tc.n, nil |
| case ast.TypeMap: |
| if keyType != ast.TypeString { |
| tc.n.Key = v.ImplicitConversion(keyType, ast.TypeString, tc.n.Key) |
| if tc.n.Key == nil { |
| return nil, fmt.Errorf( |
| "key of an index must be a string, was %s", keyType) |
| } |
| } |
| |
| valType, err := ast.VariableMapValueTypesAreHomogenous( |
| varAccessNode.Name, variable.Value.(map[string]ast.Variable)) |
| if err != nil { |
| return tc.n, err |
| } |
| |
| v.StackPush(valType) |
| return tc.n, nil |
| default: |
| return nil, fmt.Errorf("invalid index operation into non-indexable type: %s", variable.Type) |
| } |
| } |
| |
| func (v *TypeCheck) ImplicitConversion( |
| actual ast.Type, expected ast.Type, n ast.Node) ast.Node { |
| if v.Implicit == nil { |
| return nil |
| } |
| |
| fromMap, ok := v.Implicit[actual] |
| if !ok { |
| return nil |
| } |
| |
| toFunc, ok := fromMap[expected] |
| if !ok { |
| return nil |
| } |
| |
| return &ast.Call{ |
| Func: toFunc, |
| Args: []ast.Node{n}, |
| Posx: n.Pos(), |
| } |
| } |
| |
| func (v *TypeCheck) reset() { |
| v.Stack = nil |
| v.err = nil |
| } |
| |
| func (v *TypeCheck) StackPush(t ast.Type) { |
| v.Stack = append(v.Stack, t) |
| } |
| |
| func (v *TypeCheck) StackPop() ast.Type { |
| var x ast.Type |
| x, v.Stack = v.Stack[len(v.Stack)-1], v.Stack[:len(v.Stack)-1] |
| return x |
| } |
| |
| func (v *TypeCheck) StackPeek() ast.Type { |
| if len(v.Stack) == 0 { |
| return ast.TypeInvalid |
| } |
| |
| return v.Stack[len(v.Stack)-1] |
| } |