Skip to content

Commit

Permalink
gee-rpc: if client is nil, close the connection
Browse files Browse the repository at this point in the history
  • Loading branch information
geektutu committed Oct 7, 2020
1 parent 804fde7 commit f0af6b8
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 123 deletions.
20 changes: 11 additions & 9 deletions gee-rpc/day2-client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,19 +226,21 @@ func newClientCodec(cc codec.Codec, opt *Option) *Client {
return client
}

func dial(network, address string, opt *Option) (*Client, error) {
conn, err := net.Dial(network, address)
// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
return NewClient(conn, opt)
}

// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (*Client, error) {
opt, err := parseOptions(opts...)
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return dial(network, address, opt)
// close the connection if client is nil
defer func() {
if client == nil {
_ = conn.Close()
}
}()
return NewClient(conn, opt)
}
21 changes: 11 additions & 10 deletions gee-rpc/day3-service/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ func NewClient(conn net.Conn, opt *Option) (*Client, error) {
// send options with server
if err := json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
_ = conn.Close()
return nil, err
}
return newClientCodec(f(conn), opt), nil
Expand All @@ -226,19 +225,21 @@ func newClientCodec(cc codec.Codec, opt *Option) *Client {
return client
}

func dial(network, address string, opt *Option) (*Client, error) {
conn, err := net.Dial(network, address)
// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
return NewClient(conn, opt)
}

// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (*Client, error) {
opt, err := parseOptions(opts...)
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return dial(network, address, opt)
// close the connection if client is nil
defer func() {
if client == nil {
_ = conn.Close()
}
}()
return NewClient(conn, opt)
}
39 changes: 20 additions & 19 deletions gee-rpc/day4-timeout/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,18 +207,17 @@ func parseOptions(opts ...*Option) (*Option, error) {
return opt, nil
}

func NewClient(conn net.Conn, opt *Option) (*Client, error) {
func NewClient(conn net.Conn, opt *Option) (client *Client, err error) {
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
err := fmt.Errorf("invalid codec type %s", opt.CodecType)
err = fmt.Errorf("invalid codec type %s", opt.CodecType)
log.Println("rpc client: codec error:", err)
return nil, err
return
}
// send options with server
if err := json.NewEncoder(conn).Encode(opt); err != nil {
if err = json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
_ = conn.Close()
return nil, err
return
}
return newClientCodec(f(conn), opt), nil
}
Expand All @@ -239,16 +238,26 @@ type clientResult struct {
err error
}

type dialFunc func(network, address string, opt *Option) (client *Client, err error)
type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error)

func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) {
func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
// close the connection if client is nil
defer func() {
if client == nil {
_ = conn.Close()
}
}()
ch := make(chan clientResult)
go func() {
client, err := f(network, address, opt)
client, err := f(conn, opt)
ch <- clientResult{client: client, err: err}
}()
if opt.ConnectTimeout == 0 {
Expand All @@ -257,21 +266,13 @@ func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client,
}
select {
case <-time.After(opt.ConnectTimeout):
return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout)
return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
case result := <-ch:
return result.client, result.err
}
}

func dial(network, address string, opt *Option) (*Client, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClient(conn, opt)
}

// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(dial, network, address, opts...)
return dialTimeout(NewClient, network, address, opts...)
}
11 changes: 7 additions & 4 deletions gee-rpc/day4-timeout/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@ func startServer(addr chan string) {

func TestClient_dialTimeout(t *testing.T) {
t.Parallel()
f := func(network, address string, opt *Option) (client *Client, err error) {
l, _ := net.Listen("tcp", ":0")

f := func(conn net.Conn, opt *Option) (client *Client, err error) {
_ = conn.Close()
time.Sleep(time.Second * 2)
return nil, nil
}
t.Run("timeout", func(t *testing.T) {
_, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second})
_assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error")
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second})
_assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error")
})
t.Run("0", func(t *testing.T) {
_, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0})
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0})
_assert(err == nil, "0 means no limit")
})
}
Expand Down
38 changes: 18 additions & 20 deletions gee-rpc/day5-http-debug/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,26 @@ type clientResult struct {
err error
}

type dialFunc func(network, address string, opt *Option) (client *Client, err error)
type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error)

func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) {
func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
// close the connection if client is nil
defer func() {
if client == nil {
_ = conn.Close()
}
}()
ch := make(chan clientResult)
go func() {
client, err := f(network, address, opt)
client, err := f(conn, opt)
ch <- clientResult{client: client, err: err}
}()
if opt.ConnectTimeout == 0 {
Expand All @@ -260,30 +270,19 @@ func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client,
}
select {
case <-time.After(opt.ConnectTimeout):
return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout)
return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
case result := <-ch:
return result.client, result.err
}
}

func dial(network, address string, opt *Option) (*Client, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClient(conn, opt)
}

// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(dial, network, address, opts...)
return dialTimeout(NewClient, network, address, opts...)
}

func dialHTTP(network, address string, opt *Option) (*Client, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
// NewHTTPClient new a Client instance via HTTP as transport protocol
func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) {
_, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath))

// Require successful HTTP response
Expand All @@ -295,14 +294,13 @@ func dialHTTP(network, address string, opt *Option) (*Client, error) {
if err == nil {
err = errors.New("unexpected HTTP response: " + resp.Status)
}
_ = conn.Close()
return nil, err
}

// DialHTTP connects to an HTTP RPC server at the specified network address
// listening on the default HTTP RPC path.
func DialHTTP(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(dialHTTP, network, address, opts...)
return dialTimeout(NewHTTPClient, network, address, opts...)
}

// XDial calls different functions to connect to a RPC server
Expand Down
11 changes: 7 additions & 4 deletions gee-rpc/day5-http-debug/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@ func startServer(addr chan string) {

func TestClient_dialTimeout(t *testing.T) {
t.Parallel()
f := func(network, address string, opt *Option) (client *Client, err error) {
l, _ := net.Listen("tcp", ":0")

f := func(conn net.Conn, opt *Option) (client *Client, err error) {
_ = conn.Close()
time.Sleep(time.Second * 2)
return nil, nil
}
t.Run("timeout", func(t *testing.T) {
_, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second})
_assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error")
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second})
_assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error")
})
t.Run("0", func(t *testing.T) {
_, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0})
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0})
_assert(err == nil, "0 means no limit")
})
}
Expand Down
38 changes: 18 additions & 20 deletions gee-rpc/day6-load-balance/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,26 @@ type clientResult struct {
err error
}

type dialFunc func(network, address string, opt *Option) (client *Client, err error)
type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error)

func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) {
func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
// close the connection if client is nil
defer func() {
if client == nil {
_ = conn.Close()
}
}()
ch := make(chan clientResult)
go func() {
client, err := f(network, address, opt)
client, err := f(conn, opt)
ch <- clientResult{client: client, err: err}
}()
if opt.ConnectTimeout == 0 {
Expand All @@ -260,30 +270,19 @@ func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client,
}
select {
case <-time.After(opt.ConnectTimeout):
return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout)
return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
case result := <-ch:
return result.client, result.err
}
}

func dial(network, address string, opt *Option) (*Client, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClient(conn, opt)
}

// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(dial, network, address, opts...)
return dialTimeout(NewClient, network, address, opts...)
}

func dialHTTP(network, address string, opt *Option) (*Client, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
// NewHTTPClient new a Client instance via HTTP as transport protocol
func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) {
_, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath))

// Require successful HTTP response
Expand All @@ -295,14 +294,13 @@ func dialHTTP(network, address string, opt *Option) (*Client, error) {
if err == nil {
err = errors.New("unexpected HTTP response: " + resp.Status)
}
_ = conn.Close()
return nil, err
}

// DialHTTP connects to an HTTP RPC server at the specified network address
// listening on the default HTTP RPC path.
func DialHTTP(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(dialHTTP, network, address, opts...)
return dialTimeout(NewHTTPClient, network, address, opts...)
}

// XDial calls different functions to connect to a RPC server
Expand Down
11 changes: 7 additions & 4 deletions gee-rpc/day6-load-balance/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@ func startServer(addr chan string) {

func TestClient_dialTimeout(t *testing.T) {
t.Parallel()
f := func(network, address string, opt *Option) (client *Client, err error) {
l, _ := net.Listen("tcp", ":0")

f := func(conn net.Conn, opt *Option) (client *Client, err error) {
_ = conn.Close()
time.Sleep(time.Second * 2)
return nil, nil
}
t.Run("timeout", func(t *testing.T) {
_, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second})
_assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error")
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second})
_assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error")
})
t.Run("0", func(t *testing.T) {
_, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0})
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0})
_assert(err == nil, "0 means no limit")
})
}
Expand Down

0 comments on commit f0af6b8

Please sign in to comment.