diff --git a/expr.go b/expr.go index 7d76a80b..d9f7c0dc 100644 --- a/expr.go +++ b/expr.go @@ -200,11 +200,24 @@ func Compile(input string, ops ...Option) (*vm.Program, error) { } if len(config.Visitors) > 0 { - for _, v := range config.Visitors { - // We need to perform types check, because some visitors may rely on - // types information available in the tree. - _, _ = checker.Check(tree, config) - ast.Walk(&tree.Node, v) + for i := 0; i < 1000; i++ { + more := false + for _, v := range config.Visitors { + // We need to perform types check, because some visitors may rely on + // types information available in the tree. + _, _ = checker.Check(tree, config) + + ast.Walk(&tree.Node, v) + + if v, ok := v.(interface { + ShouldRepeat() bool + }); ok { + more = more || v.ShouldRepeat() + } + } + if !more { + break + } } } _, err = checker.Check(tree, config) diff --git a/patcher/operator_override.go b/patcher/operator_override.go index 38b3558b..96e6894c 100644 --- a/patcher/operator_override.go +++ b/patcher/operator_override.go @@ -14,6 +14,7 @@ type OperatorOverride struct { Overrides []string // List of function names to override operator with. Types conf.TypesTable // Env types. Functions conf.FunctionsTable // Env functions. + applied bool // Flag to indicate if any override was applied. } func (p *OperatorOverride) Visit(node *ast.Node) { @@ -37,9 +38,14 @@ func (p *OperatorOverride) Visit(node *ast.Node) { } newNode.SetType(ret) ast.Patch(node, newNode) + p.applied = true } } +func (p *OperatorOverride) ShouldRepeat() bool { + return p.applied +} + func (p *OperatorOverride) FindSuitableOperatorOverload(l, r reflect.Type) (reflect.Type, string, bool) { t, fn, ok := p.findSuitableOperatorOverloadInFunctions(l, r) if !ok { diff --git a/test/operator/issues584/issues584_test.go b/test/operator/issues584/issues584_test.go new file mode 100644 index 00000000..13c9c66f --- /dev/null +++ b/test/operator/issues584/issues584_test.go @@ -0,0 +1,73 @@ +package issues584_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/expr-lang/expr" +) + +type Env struct{} + +type Program struct { +} + +func (p *Program) Foo() Value { + return func(e *Env) float64 { + return 5 + } +} + +func (p *Program) Bar() Value { + return func(e *Env) float64 { + return 100 + } +} + +func (p *Program) AndCondition(a, b Condition) Conditions { + return Conditions{a, b} +} + +func (p *Program) AndConditions(a Conditions, b Condition) Conditions { + return append(a, b) +} + +func (p *Program) ValueGreaterThan_float(v Value, i float64) Condition { + return func(e *Env) bool { + realized := v(e) + return realized > i + } +} + +func (p *Program) ValueLessThan_float(v Value, i float64) Condition { + return func(e *Env) bool { + realized := v(e) + return realized < i + } +} + +type Condition func(e *Env) bool +type Conditions []Condition + +type Value func(e *Env) float64 + +func TestIssue584(t *testing.T) { + code := `Foo() > 1.5 and Bar() < 200.0` + + p := &Program{} + + opt := []expr.Option{ + expr.Env(p), + expr.Operator("and", "AndCondition", "AndConditions"), + expr.Operator(">", "ValueGreaterThan_float"), + expr.Operator("<", "ValueLessThan_float"), + } + + program, err := expr.Compile(code, opt...) + assert.Nil(t, err) + + state, err := expr.Run(program, p) + assert.Nil(t, err) + assert.NotNil(t, state) +} diff --git a/test/operator/operator_test.go b/test/operator/operator_test.go index e17c5fd6..99817eff 100644 --- a/test/operator/operator_test.go +++ b/test/operator/operator_test.go @@ -216,3 +216,40 @@ func TestOperator_Polymorphic(t *testing.T) { require.NoError(t, err) require.Equal(t, 6, output) } + +func TestOperator_recursive_apply(t *testing.T) { + type Decimal struct { + Int int + } + + env := map[string]any{ + "add": func(a, b Decimal) Decimal { + return Decimal{ + Int: a.Int + b.Int, + } + }, + "addInt": func(a Decimal, b int) Decimal { + return Decimal{ + Int: a.Int + b, + } + }, + "a": Decimal{1}, + "b": Decimal{2}, + "c": Decimal{3}, + "d": Decimal{4}, + "e": Decimal{5}, + } + + program, err := expr.Compile( + `a + b + 100 + c + d + e`, + expr.Env(env), + expr.Operator("+", "add"), + expr.Operator("+", "addInt"), + ) + require.NoError(t, err) + require.Equal(t, `add(add(add(addInt(add(a, b), 100), c), d), e)`, program.Node().String()) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 115, output.(Decimal).Int) +}