Skip to content

Commit

Permalink
errgroup: use WithCancelCause to cancel context
Browse files Browse the repository at this point in the history
Fixes golang/go#59355

Change-Id: Ib6a88e7e5fefe7b0d5672035af16d109aabcbf1e
Reviewed-on: https://go-review.googlesource.com/c/sync/+/481255
TryBot-Result: Gopher Robot <[email protected]>
Run-TryBot: Bryan Mills <[email protected]>
Reviewed-by: Bryan Mills <[email protected]>
Run-TryBot: Ian Lance Taylor <[email protected]>
Reviewed-by: Michael Knyszek <[email protected]>
Auto-Submit: Bryan Mills <[email protected]>
  • Loading branch information
jonjohnsonjr authored and gopherbot committed Jun 1, 2023
1 parent 4966af6 commit 93782cc
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 5 deletions.
10 changes: 5 additions & 5 deletions errgroup/errgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type token struct{}
// A zero Group is valid, has no limit on the number of active goroutines,
// and does not cancel on error.
type Group struct {
cancel func()
cancel func(error)

wg sync.WaitGroup

Expand All @@ -43,7 +43,7 @@ func (g *Group) done() {
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := withCancelCause(ctx)
return &Group{cancel: cancel}, ctx
}

Expand All @@ -52,7 +52,7 @@ func WithContext(ctx context.Context) (*Group, context.Context) {
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
g.cancel()
g.cancel(g.err)
}
return g.err
}
Expand All @@ -76,7 +76,7 @@ func (g *Group) Go(f func() error) {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel()
g.cancel(g.err)
}
})
}
Expand Down Expand Up @@ -105,7 +105,7 @@ func (g *Group) TryGo(f func() error) bool {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel()
g.cancel(g.err)
}
})
}
Expand Down
14 changes: 14 additions & 0 deletions errgroup/go120.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build go1.20
// +build go1.20

package errgroup

import "context"

func withCancelCause(parent context.Context) (context.Context, func(error)) {
return context.WithCancelCause(parent)
}
55 changes: 55 additions & 0 deletions errgroup/go120_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build go1.20
// +build go1.20

package errgroup_test

import (
"context"
"errors"
"testing"

"golang.org/x/sync/errgroup"
)

func TestCancelCause(t *testing.T) {
errDoom := errors.New("group_test: doomed")

cases := []struct {
errs []error
want error
}{
{want: nil},
{errs: []error{nil}, want: nil},
{errs: []error{errDoom}, want: errDoom},
{errs: []error{errDoom, nil}, want: errDoom},
}

for _, tc := range cases {
g, ctx := errgroup.WithContext(context.Background())

for _, err := range tc.errs {
err := err
g.TryGo(func() error { return err })
}

if err := g.Wait(); err != tc.want {
t.Errorf("after %T.TryGo(func() error { return err }) for err in %v\n"+
"g.Wait() = %v; want %v",
g, tc.errs, err, tc.want)
}

if tc.want == nil {
tc.want = context.Canceled
}

if err := context.Cause(ctx); err != tc.want {
t.Errorf("after %T.TryGo(func() error { return err }) for err in %v\n"+
"context.Cause(ctx) = %v; tc.want %v",
g, tc.errs, err, tc.want)
}
}
}
15 changes: 15 additions & 0 deletions errgroup/pre_go120.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !go1.20
// +build !go1.20

package errgroup

import "context"

func withCancelCause(parent context.Context) (context.Context, func(error)) {
ctx, cancel := context.WithCancel(parent)
return ctx, func(error) { cancel() }
}

0 comments on commit 93782cc

Please sign in to comment.