Skip to content

Commit

Permalink
Fix operator overloading recursive
Browse files Browse the repository at this point in the history
Fixes #548
  • Loading branch information
antonmedv committed Feb 17, 2024
1 parent 62cdd42 commit 3da8527
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 5 deletions.
23 changes: 18 additions & 5 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,24 @@ func Compile(input string, ops ...Option) (*vm.Program, error) {
}

if len(config.Visitors) > 0 {
for _, v := range config.Visitors {
// We need to perform types check, because some visitors may rely on
// types information available in the tree.
_, _ = checker.Check(tree, config)
ast.Walk(&tree.Node, v)
for i := 0; i < 1000; i++ {
more := false
for _, v := range config.Visitors {
// We need to perform types check, because some visitors may rely on
// types information available in the tree.
_, _ = checker.Check(tree, config)

ast.Walk(&tree.Node, v)

if v, ok := v.(interface {
ShouldRepeat() bool
}); ok {
more = more || v.ShouldRepeat()
}
}
if !more {
break
}
}
}
_, err = checker.Check(tree, config)
Expand Down
6 changes: 6 additions & 0 deletions patcher/operator_override.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type OperatorOverride struct {
Overrides []string // List of function names to override operator with.
Types conf.TypesTable // Env types.
Functions conf.FunctionsTable // Env functions.
applied bool // Flag to indicate if any override was applied.
}

func (p *OperatorOverride) Visit(node *ast.Node) {
Expand All @@ -37,9 +38,14 @@ func (p *OperatorOverride) Visit(node *ast.Node) {
}
newNode.SetType(ret)
ast.Patch(node, newNode)
p.applied = true
}
}

func (p *OperatorOverride) ShouldRepeat() bool {
return p.applied
}

func (p *OperatorOverride) FindSuitableOperatorOverload(l, r reflect.Type) (reflect.Type, string, bool) {
t, fn, ok := p.findSuitableOperatorOverloadInFunctions(l, r)
if !ok {
Expand Down
73 changes: 73 additions & 0 deletions test/operator/issues584/issues584_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package issues584_test

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/expr-lang/expr"
)

type Env struct{}

type Program struct {
}

func (p *Program) Foo() Value {
return func(e *Env) float64 {
return 5
}
}

func (p *Program) Bar() Value {
return func(e *Env) float64 {
return 100
}
}

func (p *Program) AndCondition(a, b Condition) Conditions {
return Conditions{a, b}
}

func (p *Program) AndConditions(a Conditions, b Condition) Conditions {
return append(a, b)
}

func (p *Program) ValueGreaterThan_float(v Value, i float64) Condition {
return func(e *Env) bool {
realized := v(e)
return realized > i
}
}

func (p *Program) ValueLessThan_float(v Value, i float64) Condition {
return func(e *Env) bool {
realized := v(e)
return realized < i
}
}

type Condition func(e *Env) bool
type Conditions []Condition

type Value func(e *Env) float64

func TestIssue584(t *testing.T) {
code := `Foo() > 1.5 and Bar() < 200.0`

p := &Program{}

opt := []expr.Option{
expr.Env(p),
expr.Operator("and", "AndCondition", "AndConditions"),
expr.Operator(">", "ValueGreaterThan_float"),
expr.Operator("<", "ValueLessThan_float"),
}

program, err := expr.Compile(code, opt...)
assert.Nil(t, err)

state, err := expr.Run(program, p)
assert.Nil(t, err)
assert.NotNil(t, state)
}
37 changes: 37 additions & 0 deletions test/operator/operator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,40 @@ func TestOperator_Polymorphic(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 6, output)
}

func TestOperator_recursive_apply(t *testing.T) {
type Decimal struct {
Int int
}

env := map[string]any{
"add": func(a, b Decimal) Decimal {
return Decimal{
Int: a.Int + b.Int,
}
},
"addInt": func(a Decimal, b int) Decimal {
return Decimal{
Int: a.Int + b,
}
},
"a": Decimal{1},
"b": Decimal{2},
"c": Decimal{3},
"d": Decimal{4},
"e": Decimal{5},
}

program, err := expr.Compile(
`a + b + 100 + c + d + e`,
expr.Env(env),
expr.Operator("+", "add"),
expr.Operator("+", "addInt"),
)
require.NoError(t, err)
require.Equal(t, `add(add(add(addInt(add(a, b), 100), c), d), e)`, program.Node().String())

output, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, 115, output.(Decimal).Int)
}

0 comments on commit 3da8527

Please sign in to comment.