Skip to content

Commit

Permalink
slices: add sorting and comparison functions
Browse files Browse the repository at this point in the history
Now that the `cmp` package exists, sorting and comparison functions from
`x/exp/slices` can be ported to the standard library, using the
`cmp.Ordered` type and the `cmp.Less` and `cmp.Compare` functions.

This move also includes adjustments to the discussions in #60091 w.r.t.
NaN handling and cmp vs. less functions, and adds Min/Max functions.
The final API is taken from
golang/go#60091 (comment)

Updates #60091

Change-Id: Id7e6c88035b60d4ddd0c48dd82add8e8bc4e22d3
Reviewed-on: https://go-review.googlesource.com/c/go/+/496078
Reviewed-by: Ian Lance Taylor <[email protected]>
Reviewed-by: Eli Bendersky <[email protected]>
Run-TryBot: Eli Bendersky‎ <[email protected]>
TryBot-Result: Gopher Robot <[email protected]>
  • Loading branch information
eliben committed May 23, 2023
1 parent 6b7aab7 commit 0df6812
Show file tree
Hide file tree
Showing 10 changed files with 2,080 additions and 11 deletions.
13 changes: 13 additions & 0 deletions api/next/60091.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
pkg slices, func BinarySearchFunc[$0 interface{}, $1 interface{}]([]$0, $1, func($0, $1) int) (int, bool) #60091
pkg slices, func BinarySearch[$0 cmp.Ordered]([]$0, $0) (int, bool) #60091
pkg slices, func CompareFunc[$0 interface{}, $1 interface{}]([]$0, []$1, func($0, $1) int) int #60091
pkg slices, func Compare[$0 cmp.Ordered]([]$0, []$0) int #60091
pkg slices, func IsSortedFunc[$0 interface{}]([]$0, func($0, $0) int) bool #60091
pkg slices, func IsSorted[$0 cmp.Ordered]([]$0) bool #60091
pkg slices, func MaxFunc[$0 interface{}]([]$0, func($0, $0) int) $0 #60091
pkg slices, func Max[$0 cmp.Ordered]([]$0) $0 #60091
pkg slices, func MinFunc[$0 interface{}]([]$0, func($0, $0) int) $0 #60091
pkg slices, func Min[$0 cmp.Ordered]([]$0) $0 #60091
pkg slices, func SortFunc[$0 interface{}]([]$0, func($0, $0) int) #60091
pkg slices, func SortStableFunc[$0 interface{}]([]$0, func($0, $0) int) #60091
pkg slices, func Sort[$0 cmp.Ordered]([]$0) #60091
9 changes: 5 additions & 4 deletions src/go/build/deps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ var depsRules = `
unicode/utf8, unicode/utf16, unicode,
unsafe;
# slices depends on unsafe for overlapping check.
unsafe
< slices;
# These packages depend only on internal/goarch and unsafe.
internal/goarch, unsafe
< internal/abi;
Expand Down Expand Up @@ -227,6 +223,11 @@ var depsRules = `
< hash
< hash/adler32, hash/crc32, hash/crc64, hash/fnv;
# slices depends on unsafe for overlapping check, cmp for comparison
# semantics, and math/bits for # calculating bitlength of numbers.
unsafe, cmp, math/bits
< slices;
# math/big
FMT, encoding/binary, math/rand
< math/big;
Expand Down
45 changes: 45 additions & 0 deletions src/slices/slices.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package slices

import (
"cmp"
"unsafe"
)

Expand Down Expand Up @@ -44,6 +45,50 @@ func EqualFunc[E1, E2 any](s1 []E1, s2 []E2, eq func(E1, E2) bool) bool {
return true
}

// Compare compares the elements of s1 and s2, using [cmp.Compare] on each pair
// of elements. The elements are compared sequentially, starting at index 0,
// until one element is not equal to the other.
// The result of comparing the first non-matching elements is returned.
// If both slices are equal until one of them ends, the shorter slice is
// considered less than the longer one.
// The result is 0 if s1 == s2, -1 if s1 < s2, and +1 if s1 > s2.
func Compare[E cmp.Ordered](s1, s2 []E) int {
for i, v1 := range s1 {
if i >= len(s2) {
return +1
}
v2 := s2[i]
if c := cmp.Compare(v1, v2); c != 0 {
return c
}
}
if len(s1) < len(s2) {
return -1
}
return 0
}

// CompareFunc is like Compare but uses a custom comparison function on each
// pair of elements.
// The result is the first non-zero result of cmp; if cmp always
// returns 0 the result is 0 if len(s1) == len(s2), -1 if len(s1) < len(s2),
// and +1 if len(s1) > len(s2).
func CompareFunc[E1, E2 any](s1 []E1, s2 []E2, cmp func(E1, E2) int) int {
for i, v1 := range s1 {
if i >= len(s2) {
return +1
}
v2 := s2[i]
if c := cmp(v1, v2); c != 0 {
return c
}
}
if len(s1) < len(s2) {
return -1
}
return 0
}

// Index returns the index of the first occurrence of v in s,
// or -1 if not present.
func Index[E comparable](s []E, v E) int {
Expand Down
208 changes: 208 additions & 0 deletions src/slices/slices_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package slices

import (
"cmp"
"internal/race"
"internal/testenv"
"math"
Expand Down Expand Up @@ -134,6 +135,213 @@ func BenchmarkEqualFunc_Large(b *testing.B) {
}
}

var compareIntTests = []struct {
s1, s2 []int
want int
}{
{
[]int{1},
[]int{1},
0,
},
{
[]int{1},
[]int{},
1,
},
{
[]int{},
[]int{1},
-1,
},
{
[]int{},
[]int{},
0,
},
{
[]int{1, 2, 3},
[]int{1, 2, 3},
0,
},
{
[]int{1, 2, 3},
[]int{1, 2, 3, 4},
-1,
},
{
[]int{1, 2, 3, 4},
[]int{1, 2, 3},
+1,
},
{
[]int{1, 2, 3},
[]int{1, 4, 3},
-1,
},
{
[]int{1, 4, 3},
[]int{1, 2, 3},
+1,
},
{
[]int{1, 4, 3},
[]int{1, 2, 3, 8, 9},
+1,
},
}

var compareFloatTests = []struct {
s1, s2 []float64
want int
}{
{
[]float64{},
[]float64{},
0,
},
{
[]float64{1},
[]float64{1},
0,
},
{
[]float64{math.NaN()},
[]float64{math.NaN()},
0,
},
{
[]float64{1, 2, math.NaN()},
[]float64{1, 2, math.NaN()},
0,
},
{
[]float64{1, math.NaN(), 3},
[]float64{1, math.NaN(), 4},
-1,
},
{
[]float64{1, math.NaN(), 3},
[]float64{1, 2, 4},
-1,
},
{
[]float64{1, math.NaN(), 3},
[]float64{1, 2, math.NaN()},
-1,
},
{
[]float64{1, 2, 3},
[]float64{1, 2, math.NaN()},
+1,
},
{
[]float64{1, 2, 3},
[]float64{1, math.NaN(), 3},
+1,
},
{
[]float64{1, math.NaN(), 3, 4},
[]float64{1, 2, math.NaN()},
-1,
},
}

func TestCompare(t *testing.T) {
intWant := func(want bool) string {
if want {
return "0"
}
return "!= 0"
}
for _, test := range equalIntTests {
if got := Compare(test.s1, test.s2); (got == 0) != test.want {
t.Errorf("Compare(%v, %v) = %d, want %s", test.s1, test.s2, got, intWant(test.want))
}
}
for _, test := range equalFloatTests {
if got := Compare(test.s1, test.s2); (got == 0) != test.wantEqualNaN {
t.Errorf("Compare(%v, %v) = %d, want %s", test.s1, test.s2, got, intWant(test.wantEqualNaN))
}
}

for _, test := range compareIntTests {
if got := Compare(test.s1, test.s2); got != test.want {
t.Errorf("Compare(%v, %v) = %d, want %d", test.s1, test.s2, got, test.want)
}
}
for _, test := range compareFloatTests {
if got := Compare(test.s1, test.s2); got != test.want {
t.Errorf("Compare(%v, %v) = %d, want %d", test.s1, test.s2, got, test.want)
}
}
}

func equalToCmp[T comparable](eq func(T, T) bool) func(T, T) int {
return func(v1, v2 T) int {
if eq(v1, v2) {
return 0
}
return 1
}
}

func TestCompareFunc(t *testing.T) {
intWant := func(want bool) string {
if want {
return "0"
}
return "!= 0"
}
for _, test := range equalIntTests {
if got := CompareFunc(test.s1, test.s2, equalToCmp(equal[int])); (got == 0) != test.want {
t.Errorf("CompareFunc(%v, %v, equalToCmp(equal[int])) = %d, want %s", test.s1, test.s2, got, intWant(test.want))
}
}
for _, test := range equalFloatTests {
if got := CompareFunc(test.s1, test.s2, equalToCmp(equal[float64])); (got == 0) != test.wantEqual {
t.Errorf("CompareFunc(%v, %v, equalToCmp(equal[float64])) = %d, want %s", test.s1, test.s2, got, intWant(test.wantEqual))
}
}

for _, test := range compareIntTests {
if got := CompareFunc(test.s1, test.s2, cmp.Compare[int]); got != test.want {
t.Errorf("CompareFunc(%v, %v, cmp[int]) = %d, want %d", test.s1, test.s2, got, test.want)
}
}
for _, test := range compareFloatTests {
if got := CompareFunc(test.s1, test.s2, cmp.Compare[float64]); got != test.want {
t.Errorf("CompareFunc(%v, %v, cmp[float64]) = %d, want %d", test.s1, test.s2, got, test.want)
}
}

s1 := []int{1, 2, 3}
s2 := []int{2, 3, 4}
if got := CompareFunc(s1, s2, equalToCmp(offByOne)); got != 0 {
t.Errorf("CompareFunc(%v, %v, offByOne) = %d, want 0", s1, s2, got)
}

s3 := []string{"a", "b", "c"}
s4 := []string{"A", "B", "C"}
if got := CompareFunc(s3, s4, strings.Compare); got != 1 {
t.Errorf("CompareFunc(%v, %v, strings.Compare) = %d, want 1", s3, s4, got)
}

compareLower := func(v1, v2 string) int {
return strings.Compare(strings.ToLower(v1), strings.ToLower(v2))
}
if got := CompareFunc(s3, s4, compareLower); got != 0 {
t.Errorf("CompareFunc(%v, %v, compareLower) = %d, want 0", s3, s4, got)
}

cmpIntString := func(v1 int, v2 string) int {
return strings.Compare(string(rune(v1)-1+'a'), v2)
}
if got := CompareFunc(s1, s3, cmpIntString); got != 0 {
t.Errorf("CompareFunc(%v, %v, cmpIntString) = %d, want 0", s1, s3, got)
}
}

var indexTests = []struct {
s []int
v int
Expand Down
Loading

0 comments on commit 0df6812

Please sign in to comment.