diff --git a/README.rst b/README.rst index ea2640d..e85b478 100644 --- a/README.rst +++ b/README.rst @@ -417,6 +417,33 @@ funk.MustSet ............ Short hand for funk.Set if struct does not contain interface{} field type to discard errors. +funk.Prune +.......... +Copy a struct with only selected fields. Slice is handled by pruning all elements. + +.. code-block:: go + bar := &Bar{ + Name: "Test", + } + + foo1 := &Foo{ + ID: 1, + FirstName: "Dark", + LastName: "Vador", + Bar: bar, + } + + pruned, _ := Prune(foo1, []string{"FirstName", "Bar.Name"}) + // *Foo{ + // ID: 0, + // FirstName: "Dark", + // LastName: "", + // Bar: &Bar{Name: "Test}, + // } + +funk.PruneByTag +.......... +Same functionality as funk.Prune, but uses struct tags instead of struct field names. funk.Keys ......... diff --git a/funk_test.go b/funk_test.go index 2a607bd..b8dcd5e 100644 --- a/funk_test.go +++ b/funk_test.go @@ -8,7 +8,7 @@ type Model interface { // Bar is type Bar struct { - Name string + Name string `tag_name:"BarName"` Bar *Bar Bars []*Bar } @@ -23,7 +23,7 @@ type Foo struct { FirstName string `tag_name:"tag 1"` LastName string `tag_name:"tag 2"` Age int `tag_name:"tag 3"` - Bar *Bar + Bar *Bar `tag_name:"tag 4"` Bars []*Bar EmptyValue sql.NullInt64 diff --git a/transform.go b/transform.go index f2429ce..ef9dc1f 100644 --- a/transform.go +++ b/transform.go @@ -4,6 +4,7 @@ import ( "fmt" "math/rand" "reflect" + "strings" ) // Chunk creates an array of elements split into groups with the length of size. @@ -394,3 +395,114 @@ func Drop(in interface{}, n int) interface{} { panic(fmt.Sprintf("Type %s is not supported by Drop", valueType.String())) } + +// Prune returns a copy of "in" that only contains fields in "paths" +// which are looked up using struct field name. +// For lookup paths by field tag instead, use funk.PruneByTag() +func Prune(in interface{}, paths []string) (interface{}, error) { + return pruneByTag(in, paths, nil /*tag*/) +} + +// pruneByTag returns a copy of "in" that only contains fields in "paths" +// which are looked up using struct field Tag "tag". +func PruneByTag(in interface{}, paths []string, tag string) (interface{}, error) { + return pruneByTag(in, paths, &tag) +} + +// pruneByTag returns a copy of "in" that only contains fields in "paths" +// which are looked up using struct field Tag "tag". If tag is nil, +// traverse paths using struct field name +func pruneByTag(in interface{}, paths []string, tag *string) (interface{}, error) { + + inValue := reflect.ValueOf(in) + + ret := reflect.New(inValue.Type()).Elem() + + for _, path := range paths { + parts := strings.Split(path, ".") + if err := prune(inValue, ret, parts, tag); err != nil { + return nil, err + } + } + return ret.Interface(), nil +} + +func prune(inValue reflect.Value, ret reflect.Value, parts []string, tag *string) error { + + if len(parts) == 0 { + // we reached the location that ret needs to hold inValue + // Note: The value at the end of the path is not copied, maybe we need to change. + // ret and the original data holds the same reference to this value + ret.Set(inValue) + return nil + } + + inKind := inValue.Kind() + + switch inKind { + case reflect.Ptr: + if inValue.IsNil() { + // TODO validate + return nil + } + if ret.IsNil() { + // init ret and go to next level + ret.Set(reflect.New(inValue.Type().Elem())) + } + return prune(inValue.Elem(), ret.Elem(), parts, tag) + case reflect.Struct: + part := parts[0] + var fValue reflect.Value + var fRet reflect.Value + if tag == nil { + // use field name + fValue = inValue.FieldByName(part) + if !fValue.IsValid() { + return fmt.Errorf("field name %v is not found in struct %v", part, inValue.Type().String()) + } + fRet = ret.FieldByName(part) + } else { + // search tag that has key equal to part + found := false + for i := 0; i < inValue.NumField(); i++ { + f := inValue.Type().Field(i) + if key, ok := f.Tag.Lookup(*tag); ok { + if key == part { + fValue = inValue.Field(i) + fRet = ret.Field(i) + found = true + break + } + } + } + if !found { + return fmt.Errorf("Struct tag %v is not found with key %v", *tag, part) + } + } + // init Ret is zero and go down one more level + if fRet.IsZero() { + fRet.Set(reflect.New(fValue.Type()).Elem()) + } + return prune(fValue, fRet, parts[1:], tag) + case reflect.Array, reflect.Slice: + // set all its elements + length := inValue.Len() + // init ret + if ret.IsZero() { + if inKind == reflect.Slice { + ret.Set(reflect.MakeSlice(inValue.Type(), length /*len*/, length /*cap*/)) + } else { // array + ret.Set(reflect.New(inValue.Type()).Elem()) + } + } + for j := 0; j < length; j++ { + if err := prune(inValue.Index(j), ret.Index(j), parts, tag); err != nil { + return err + } + } + default: + return fmt.Errorf("path %v cannot be looked up on kind of %v", strings.Join(parts, "."), inValue.Kind()) + } + + return nil +} diff --git a/transform_test.go b/transform_test.go index 2c4a0dc..d52d0d0 100644 --- a/transform_test.go +++ b/transform_test.go @@ -1,11 +1,13 @@ package funk import ( + "database/sql" "fmt" "reflect" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMap(t *testing.T) { @@ -166,3 +168,294 @@ func TestDrop(t *testing.T) { is.Equal([]int{2, 3, 0, 0, 12}, results) } + +func TestPrune(t *testing.T) { + + var testCases = []struct { + OriginalFoo *Foo + Paths []string + ExpectedFoo *Foo + }{ + { + foo, + []string{"FirstName"}, + &Foo{ + FirstName: foo.FirstName, + }, + }, + { + foo, + []string{"FirstName", "ID"}, + &Foo{ + FirstName: foo.FirstName, + ID: foo.ID, + }, + }, + { + foo, + []string{"EmptyValue.Int64"}, + &Foo{ + EmptyValue: sql.NullInt64{ + Int64: foo.EmptyValue.Int64, + }, + }, + }, + { + foo, + []string{"FirstName", "ID", "EmptyValue.Int64"}, + &Foo{ + FirstName: foo.FirstName, + ID: foo.ID, + EmptyValue: sql.NullInt64{ + Int64: foo.EmptyValue.Int64, + }, + }, + }, + { + foo, + []string{"FirstName", "ID", "EmptyValue.Int64"}, + &Foo{ + FirstName: foo.FirstName, + ID: foo.ID, + EmptyValue: sql.NullInt64{ + Int64: foo.EmptyValue.Int64, + }, + }, + }, + { + foo, + []string{"FirstName", "ID", "Bar"}, + &Foo{ + FirstName: foo.FirstName, + ID: foo.ID, + Bar: foo.Bar, + }, + }, + { + foo, + []string{"Bar", "Bars"}, + &Foo{ + Bar: foo.Bar, + Bars: foo.Bars, + }, + }, + { + foo, + []string{"FirstName", "Bars.Name"}, + &Foo{ + FirstName: foo.FirstName, + Bars: []*Bar{ + {Name: bar.Name}, + {Name: bar.Name}, + }, + }, + }, + { + foo, + []string{"Bars.Name", "Bars.Bars.Name"}, + &Foo{ + Bars: []*Bar{ + {Name: bar.Name, Bars: []*Bar{{Name: "Level1-1"}, {Name: "Level1-2"}}}, + {Name: bar.Name, Bars: []*Bar{{Name: "Level1-1"}, {Name: "Level1-2"}}}, + }, + }, + }, + { + foo, + []string{"BarInterface", "BarPointer"}, + &Foo{ + BarInterface: bar, + BarPointer: &bar, + }, + }, + } + + // pass to prune by pointer to struct + for idx, tc := range testCases { + t.Run(fmt.Sprintf("Prune pointer test case #%v", idx), func(t *testing.T) { + is := assert.New(t) + res, err := Prune(tc.OriginalFoo, tc.Paths) + require.NoError(t, err) + + fooPrune := res.(*Foo) + is.Equal(tc.ExpectedFoo, fooPrune) + }) + } + + // pass to prune by struct directly + for idx, tc := range testCases { + t.Run(fmt.Sprintf("Prune non pointer test case #%v", idx), func(t *testing.T) { + is := assert.New(t) + fooNonPtr := *tc.OriginalFoo + res, err := Prune(fooNonPtr, tc.Paths) + require.NoError(t, err) + + fooPrune := res.(Foo) + is.Equal(*tc.ExpectedFoo, fooPrune) + }) + } + + // test PruneByTag + var TagTestCases = []struct { + OriginalFoo *Foo + Paths []string + ExpectedFoo *Foo + Tag string + }{ + { + foo, + []string{"tag 1", "tag 4.BarName"}, + &Foo{ + FirstName: foo.FirstName, + Bar: &Bar{ + Name: bar.Name, + }, + }, + "tag_name", + }, + } + + for idx, tc := range TagTestCases { + t.Run(fmt.Sprintf("PruneByTag test case #%v", idx), func(t *testing.T) { + is := assert.New(t) + fooNonPtr := *tc.OriginalFoo + res, err := PruneByTag(fooNonPtr, tc.Paths, tc.Tag) + require.NoError(t, err) + + fooPrune := res.(Foo) + is.Equal(*tc.ExpectedFoo, fooPrune) + }) + } + + t.Run("Bar Slice", func(t *testing.T) { + barSlice := []*Bar{bar, bar} + barSlicePruned, err := pruneByTag(barSlice, []string{"Name"}, nil /*tag*/) + require.NoError(t, err) + assert.Equal(t, []*Bar{{Name: bar.Name}, {Name: bar.Name}}, barSlicePruned) + }) + + t.Run("Bar Array", func(t *testing.T) { + barArr := [2]*Bar{bar, bar} + barArrPruned, err := pruneByTag(barArr, []string{"Name"}, nil /*tag*/) + require.NoError(t, err) + assert.Equal(t, [2]*Bar{{Name: bar.Name}, {Name: bar.Name}}, barArrPruned) + }) + + // test values are copied and not referenced in return result + // NOTE: pointers at the end of path are referenced. Maybe we need to make a copy + t.Run("Copy Value Str", func(t *testing.T) { + is := assert.New(t) + fooTest := &Foo{ + Bar: &Bar{ + Name: "bar", + }, + } + res, err := pruneByTag(fooTest, []string{"Bar.Name"}, nil) + require.NoError(t, err) + fooTestPruned := res.(*Foo) + is.Equal(fooTest, fooTestPruned) + + // change pruned + fooTestPruned.Bar.Name = "changed bar" + // check original is unchanged + is.Equal(fooTest.Bar.Name, "bar") + }) + + // error cases + var errCases = []struct { + InputFoo *Foo + Paths []string + TagName *string + }{ + { + foo, + []string{"NotExist"}, + nil, + }, + { + foo, + []string{"FirstName.NotExist", "LastName"}, + nil, + }, + { + foo, + []string{"LastName", "FirstName.NotExist"}, + nil, + }, + { + foo, + []string{"LastName", "Bars.NotExist"}, + nil, + }, + // tags + { + foo, + []string{"tag 999"}, + &[]string{"tag_name"}[0], + }, + { + foo, + []string{"tag 1.NotExist"}, + &[]string{"tag_name"}[0], + }, + { + foo, + []string{"tag 4.NotExist"}, + &[]string{"tag_name"}[0], + }, + { + foo, + []string{"FirstName"}, + &[]string{"tag_name_not_exist"}[0], + }, + } + + for idx, errTC := range errCases { + t.Run(fmt.Sprintf("error test case #%v", idx), func(t *testing.T) { + _, err := pruneByTag(errTC.InputFoo, errTC.Paths, errTC.TagName) + assert.Error(t, err) + }) + } +} + +func ExamplePrune() { + type ExampleFoo struct { + ExampleFooPtr *ExampleFoo `json:"example_foo_ptr"` + Name string `json:"name"` + Number int `json:"number"` + } + + exampleFoo := ExampleFoo{ + ExampleFooPtr: &ExampleFoo{ + Name: "ExampleFooPtr", + Number: 2, + }, + Name: "ExampleFoo", + Number: 1, + } + + // prune using struct field name + res, _ := Prune(exampleFoo, []string{"ExampleFooPtr.Name", "Number"}) + prunedFoo := res.(ExampleFoo) + fmt.Println(prunedFoo.ExampleFooPtr.Name) + fmt.Println(prunedFoo.ExampleFooPtr.Number) + fmt.Println(prunedFoo.Name) + fmt.Println(prunedFoo.Number) + + // prune using struct json tag + res2, _ := PruneByTag(exampleFoo, []string{"example_foo_ptr.name", "number"}, "json") + prunedByTagFoo := res2.(ExampleFoo) + fmt.Println(prunedByTagFoo.ExampleFooPtr.Name) + fmt.Println(prunedByTagFoo.ExampleFooPtr.Number) + fmt.Println(prunedByTagFoo.Name) + fmt.Println(prunedByTagFoo.Number) + // output: + // ExampleFooPtr + // 0 + // + // 1 + // ExampleFooPtr + // 0 + // + // 1 +}