Skip to content

Commit

Permalink
Additional features for types and decls (#756)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Jun 27, 2023
1 parent 72c7ca1 commit b7f8fe1
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 88 deletions.
10 changes: 8 additions & 2 deletions cel/decls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,14 @@ func TestFunctionMerge(t *testing.T) {
t.Errorf("prg.Eval() got %v, wanted %v", out, want)
}

_, err = NewCustomEnv(size, size)
if err == nil || !strings.Contains(err.Error(), "already has singleton binding") {
sizeBad := Function("size",
Overload("size_vector", []*Type{OpaqueType("vector", TypeParamType("V"))}, IntType),
MemberOverload("vector_size", []*Type{OpaqueType("vector", TypeParamType("V"))}, IntType),
SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return nil
}))
_, err = NewCustomEnv(size, sizeBad)
if err == nil || !strings.Contains(err.Error(), "already has a singleton binding") {
t.Errorf("NewCustomEnv(size, size) did not produce the expected error: %v", err)
}
e, err = NewCustomEnv(size,
Expand Down
79 changes: 54 additions & 25 deletions common/decls/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
func NewFunction(name string, opts ...FunctionOpt) (*FunctionDecl, error) {
fn := &FunctionDecl{
Name: name,
Overloads: map[string]*OverloadDecl{},
overloads: map[string]*OverloadDecl{},
overloadOrdinals: []string{},
}
var err error
Expand All @@ -44,7 +44,7 @@ func NewFunction(name string, opts ...FunctionOpt) (*FunctionDecl, error) {
return nil, err
}
}
if len(fn.Overloads) == 0 {
if len(fn.overloads) == 0 {
return nil, fmt.Errorf("function %s must have at least one overload", name)
}
return fn, nil
Expand All @@ -56,8 +56,8 @@ type FunctionDecl struct {
// Name of the function in human-readable terms, e.g. 'contains' of 'math.least'
Name string

// Overloads associated with the function name.
Overloads map[string]*OverloadDecl
// overloads associated with the function name.
overloads map[string]*OverloadDecl

// Singleton implementation of the function for all overloads.
//
Expand Down Expand Up @@ -105,9 +105,9 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) {
}
merged := &FunctionDecl{
Name: f.Name,
Overloads: make(map[string]*OverloadDecl, len(f.Overloads)),
overloads: make(map[string]*OverloadDecl, len(f.overloads)),
Singleton: f.Singleton,
overloadOrdinals: make([]string, len(f.Overloads)),
overloadOrdinals: make([]string, len(f.overloads)),
// if one function is expecting type-guards and the other is not, then they
// must not be disabled.
disableTypeGuards: f.disableTypeGuards && other.disableTypeGuards,
Expand All @@ -121,20 +121,20 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) {
}
// baseline copy of the overloads and their ordinals
copy(merged.overloadOrdinals, f.overloadOrdinals)
for oID, o := range f.Overloads {
merged.Overloads[oID] = o
for oID, o := range f.overloads {
merged.overloads[oID] = o
}
// overloads and their ordinals are added from the left
for _, oID := range other.overloadOrdinals {
o := other.Overloads[oID]
o := other.overloads[oID]
err := merged.AddOverload(o)
if err != nil {
return nil, fmt.Errorf("function declaration merge failed: %v", err)
}
}
if other.Singleton != nil {
if merged.Singleton != nil {
return nil, fmt.Errorf("function already has singleton binding: %s", f.Name)
if merged.Singleton != nil && merged.Singleton != other.Singleton {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name)
}
merged.Singleton = other.Singleton
}
Expand All @@ -145,30 +145,39 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) {
// however, if the function signatures are identical, the implementation may be rewritten as its
// difficult to compare functions by object identity.
func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
for oID, o := range f.Overloads {
for oID, o := range f.overloads {
if oID != overload.ID && o.SignatureOverlaps(overload) {
return fmt.Errorf("overload signature collision in function %s: %s collides with %s", f.Name, oID, overload.ID)
}
if oID == overload.ID {
if o.SignatureEquals(overload) && o.NonStrict == overload.NonStrict {
// Allow redefinition of an overload implementation so long as the signatures match.
f.Overloads[oID] = overload
f.overloads[oID] = overload
return nil
}
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name, oID)
}
}
f.overloadOrdinals = append(f.overloadOrdinals, overload.ID)
f.Overloads[overload.ID] = overload
f.overloads[overload.ID] = overload
return nil
}

// OverloadDecls returns the overload declarations in the order in which they were declared.
func (f *FunctionDecl) OverloadDecls() []*OverloadDecl {
overloads := make([]*OverloadDecl, 0, len(f.overloads))
for _, oID := range f.overloadOrdinals {
overloads = append(overloads, f.overloads[oID])
}
return overloads
}

// Bindings produces a set of function bindings, if any are defined.
func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
overloads := []*functions.Overload{}
nonStrict := false
for _, oID := range f.overloadOrdinals {
o := f.Overloads[oID]
o := f.overloads[oID]
if o.hasBinding() {
overload := &functions.Overload{
Operator: o.ID,
Expand Down Expand Up @@ -218,7 +227,7 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
// performs dynamic dispatch to the proper overload based on the argument types.
bindings := append([]*functions.Overload{}, overloads...)
funcDispatch := func(args ...ref.Val) ref.Val {
for _, o := range f.Overloads {
for _, o := range f.overloads {
// During dynamic dispatch over multiple functions, signature agreement checks
// are preserved in order to assist with the function resolution step.
switch len(args) {
Expand Down Expand Up @@ -457,6 +466,20 @@ type OverloadDecl struct {
OperandTrait int
}

// GetTypeParams returns the type parameter names associated with the overload.
func (o *OverloadDecl) GetTypeParams() []string {
typeParams := map[string]struct{}{}
collectParamNames(typeParams, o.ResultType)
for _, arg := range o.ArgTypes {
collectParamNames(typeParams, arg)
}
params := make([]string, 0, len(typeParams))
for param := range typeParams {
params = append(params, param)
}
return params
}

// SignatureEquals determines whether the incoming overload declaration signature is equal to the current signature.
//
// Result type, operand trait, and strict-ness are not considered as part of signature equality.
Expand All @@ -469,11 +492,11 @@ func (o *OverloadDecl) SignatureEquals(other *OverloadDecl) bool {
}
for i, at := range o.ArgTypes {
oat := other.ArgTypes[i]
if !at.IsType(oat) {
if !at.IsEquivalentType(oat) {
return false
}
}
return o.ResultType.IsType(other.ResultType)
return o.ResultType.IsEquivalentType(other.ResultType)
}

// SignatureOverlaps indicates whether two functions have non-equal, but overloapping function signatures.
Expand Down Expand Up @@ -644,23 +667,29 @@ func OverloadOperandTrait(trait int) OverloadOpt {
}
}

// NewConstant creates a new constant declaration.
func NewConstant(name string, t *types.Type, v ref.Val) *VariableDecl {
return &VariableDecl{Name: name, Type: t, Value: v}
}

// NewVariable creates a new variable declaration.
func NewVariable(name string, t *types.Type) *VariableDecl {
return &VariableDecl{Name: name, Type: t}
}

// VariableDecl defines a variable declaration which may optionally have a constant value.
type VariableDecl struct {
Name string
Type *types.Type
Name string
Type *types.Type
Value ref.Val
}

// DeclarationEquals returns true if one variable declaration has the same name and same type as the input.
func (v *VariableDecl) DeclarationEquals(other *VariableDecl) bool {
// DeclarationIsEquivalent returns true if one variable declaration has the same name and same type as the input.
func (v *VariableDecl) DeclarationIsEquivalent(other *VariableDecl) bool {
if v == other {
return true
}
return v.Name == other.Name && v.Type.IsType(other.Type)
return v.Name == other.Name && v.Type.IsEquivalentType(other.Type)
}

// VariableDeclToExprDecl converts a go-native variable declaration into a protobuf-type variable declaration.
Expand All @@ -679,9 +708,9 @@ func TypeVariable(t *types.Type) *VariableDecl {

// FunctionDeclToExprDecl converts a go-native function declaration into a protobuf-typed function declaration.
func FunctionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) {
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(f.Overloads))
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(f.overloads))
for i, oID := range f.overloadOrdinals {
o := f.Overloads[oID]
o := f.overloads[oID]
paramNames := map[string]struct{}{}
argTypes := make([]*exprpb.Type, len(o.ArgTypes))
for j, a := range o.ArgTypes {
Expand Down
57 changes: 49 additions & 8 deletions common/decls/decls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package decls

import (
"reflect"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -275,15 +276,15 @@ func TestFunctionMerge(t *testing.T) {
if (sizeMerged.Name) != "size" {
t.Errorf("Merge() produced a function with name %v, wanted 'size'", sizeMerged.Name)
}
if len(sizeMerged.Overloads) != 3 {
t.Errorf("Merge() produced %d overloads, wanted 3", len(sizeFunc.Overloads))
if len(sizeMerged.overloads) != 3 {
t.Errorf("Merge() produced %d overloads, wanted 3", len(sizeFunc.overloads))
}
overloads := map[string]bool{
"list_size": true,
"map_size": true,
"vector_size": true,
}
for _, o := range sizeMerged.Overloads {
for _, o := range sizeMerged.overloads {
delete(overloads, o.ID)
}
if len(overloads) != 0 {
Expand Down Expand Up @@ -397,7 +398,7 @@ func TestFunctionMergeSingletonRedefinition(t *testing.T) {
t.Fatalf("NewFunction() failed: %v", err)
}
_, err = sizeFunc.Merge(sizeVecFunc)
if err == nil || !strings.Contains(err.Error(), "already has singleton") {
if err == nil || !strings.Contains(err.Error(), "already has a singleton") {
t.Fatalf("Merge() expected to fail, got: %v", err)
}
}
Expand Down Expand Up @@ -661,6 +662,36 @@ func TestOverloadOperandTrait(t *testing.T) {
}
}

func TestFunctionGetTypeParams(t *testing.T) {
fn, err := NewFunction("deep_type_params",
Overload("no_type_params", []*types.Type{}, types.DynType),
Overload("one_type_param", []*types.Type{types.BoolType}, types.NewTypeParamType("K")),
Overload("deep_type_params",
[]*types.Type{types.NewTypeParamType("E1"),
types.NewMapType(types.NewTypeParamType("K"), types.NewTypeParamType("V"))},
types.NewTypeParamType("V"),
),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
}
if len(fn.OverloadDecls()) != 3 {
t.Fatal("fn.OverloadDecls() not equal to 3")
}
o1 := fn.OverloadDecls()[0]
o2 := fn.OverloadDecls()[1]
o3 := fn.OverloadDecls()[2]
if len(o1.GetTypeParams()) != 0 {
t.Errorf("overload %v did not have zero type-params", o1)
}
if len(o2.GetTypeParams()) != 1 && !reflect.DeepEqual(o2.GetTypeParams(), []string{"K"}) {
t.Errorf("overload %v did not have a single type param", o2)
}
if len(o3.GetTypeParams()) != 3 {
t.Errorf("overload %v did not have three type params", o3)
}
}

func TestFunctionDisableDeclaration(t *testing.T) {
fn, err := NewFunction("in",
DisableDeclaration(true),
Expand Down Expand Up @@ -959,18 +990,28 @@ func TestFunctionDeclToExprDeclInvalid(t *testing.T) {

func TestNewVariable(t *testing.T) {
a := NewVariable("a", types.BoolType)
if !a.DeclarationEquals(a) {
if !a.DeclarationIsEquivalent(a) {
t.Error("NewVariable(a, bool) does not equal itself")
}
if !a.DeclarationEquals(NewVariable("a", types.BoolType)) {
if !a.DeclarationIsEquivalent(NewVariable("a", types.BoolType)) {
t.Error("NewVariable(a, bool) does not equal itself")
}
a1 := NewVariable("a", types.IntType)
if a.DeclarationEquals(a1) {
if a.DeclarationIsEquivalent(a1) {
t.Error("NewVariable(a, int).DeclarationEquals(NewVariable(a, bool))")
}
}

func TestNewConstant(t *testing.T) {
a := NewConstant("a", types.IntType, types.Int(42))
if !a.DeclarationIsEquivalent(a) {
t.Error("NewConstant(a, int) does not equal itself")
}
if !a.DeclarationIsEquivalent(NewVariable("a", types.IntType)) {
t.Error("NewConstant(a, int) is not declaration equivalent to int variable")
}
}

func TestTypeVariable(t *testing.T) {
tests := []struct {
t *types.Type
Expand All @@ -994,7 +1035,7 @@ func TestTypeVariable(t *testing.T) {
},
}
for _, tst := range tests {
if !TypeVariable(tst.t).DeclarationEquals(tst.v) {
if !TypeVariable(tst.t).DeclarationIsEquivalent(tst.v) {
t.Errorf("got not equal %v.Equals(%v)", TypeVariable(tst.t), tst.v)
}
}
Expand Down
23 changes: 19 additions & 4 deletions common/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,16 +262,29 @@ func (t *Type) HasTrait(trait int) bool {
return trait&t.traitMask == trait
}

// IsType indicates whether two types have the same kind, type name, and parameters.
func (t *Type) IsType(other *Type) bool {
// IsExactType indicates whether the two types are exactly the same. This check also verifies type parameter type names.
func (t *Type) IsExactType(other *Type) bool {
return t.isTypeInternal(other, true)
}

// IsEquivalentType indicates whether two types are equivalent. This check ignores type parameter type names.
func (t *Type) IsEquivalentType(other *Type) bool {
return t.isTypeInternal(other, false)
}

// isTypeInternal checks whether the two types are equivalent or exactly the same based on the checkTypeParamName flag.
func (t *Type) isTypeInternal(other *Type, checkTypeParamName bool) bool {
if t == other {
return true
}
if t.Kind != other.Kind || len(t.Parameters) != len(other.Parameters) {
return false
}
if t.Kind != TypeParamKind && t.DeclaredTypeName() != other.DeclaredTypeName() {
if (checkTypeParamName || t.Kind != TypeParamKind) && t.TypeName() != other.TypeName() {
return false
}
for i, p := range t.Parameters {
if !p.IsType(other.Parameters[i]) {
if !p.isTypeInternal(other.Parameters[i], checkTypeParamName) {
return false
}
}
Expand Down Expand Up @@ -536,6 +549,8 @@ func TypeToExprType(t *Type) (*exprpb.Type, error) {
return chkdecls.Duration, nil
case DynKind:
return chkdecls.Dyn, nil
case ErrorKind:
return chkdecls.Error, nil
case IntKind:
return maybeWrapper(t, chkdecls.Int), nil
case ListKind:
Expand Down
Loading

0 comments on commit b7f8fe1

Please sign in to comment.