diff --git a/gee-rpc/day2-client/client.go b/gee-rpc/day2-client/client.go index 1c8a3cb..4af2ebf 100644 --- a/gee-rpc/day2-client/client.go +++ b/gee-rpc/day2-client/client.go @@ -226,19 +226,21 @@ func newClientCodec(cc codec.Codec, opt *Option) *Client { return client } -func dial(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) if err != nil { return nil, err } - return NewClient(conn, opt) -} - -// Dial connects to an RPC server at the specified network address -func Dial(network, address string, opts ...*Option) (*Client, error) { - opt, err := parseOptions(opts...) + conn, err := net.Dial(network, address) if err != nil { return nil, err } - return dial(network, address, opt) + // close the connection if client is nil + defer func() { + if client == nil { + _ = conn.Close() + } + }() + return NewClient(conn, opt) } diff --git a/gee-rpc/day3-service/client.go b/gee-rpc/day3-service/client.go index 1c8a3cb..77c9436 100644 --- a/gee-rpc/day3-service/client.go +++ b/gee-rpc/day3-service/client.go @@ -209,7 +209,6 @@ func NewClient(conn net.Conn, opt *Option) (*Client, error) { // send options with server if err := json.NewEncoder(conn).Encode(opt); err != nil { log.Println("rpc client: options error: ", err) - _ = conn.Close() return nil, err } return newClientCodec(f(conn), opt), nil @@ -226,19 +225,21 @@ func newClientCodec(cc codec.Codec, opt *Option) *Client { return client } -func dial(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) if err != nil { return nil, err } - return NewClient(conn, opt) -} - -// Dial connects to an RPC server at the specified network address -func Dial(network, address string, opts ...*Option) (*Client, error) { - opt, err := parseOptions(opts...) + conn, err := net.Dial(network, address) if err != nil { return nil, err } - return dial(network, address, opt) + // close the connection if client is nil + defer func() { + if client == nil { + _ = conn.Close() + } + }() + return NewClient(conn, opt) } diff --git a/gee-rpc/day4-timeout/client.go b/gee-rpc/day4-timeout/client.go index 45fc9e6..29af7aa 100644 --- a/gee-rpc/day4-timeout/client.go +++ b/gee-rpc/day4-timeout/client.go @@ -207,18 +207,17 @@ func parseOptions(opts ...*Option) (*Option, error) { return opt, nil } -func NewClient(conn net.Conn, opt *Option) (*Client, error) { +func NewClient(conn net.Conn, opt *Option) (client *Client, err error) { f := codec.NewCodecFuncMap[opt.CodecType] if f == nil { - err := fmt.Errorf("invalid codec type %s", opt.CodecType) + err = fmt.Errorf("invalid codec type %s", opt.CodecType) log.Println("rpc client: codec error:", err) - return nil, err + return } // send options with server - if err := json.NewEncoder(conn).Encode(opt); err != nil { + if err = json.NewEncoder(conn).Encode(opt); err != nil { log.Println("rpc client: options error: ", err) - _ = conn.Close() - return nil, err + return } return newClientCodec(f(conn), opt), nil } @@ -239,16 +238,26 @@ type clientResult struct { err error } -type dialFunc func(network, address string, opt *Option) (client *Client, err error) +type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error) -func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) { +func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) { opt, err := parseOptions(opts...) if err != nil { return nil, err } + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if client == nil { + _ = conn.Close() + } + }() ch := make(chan clientResult) go func() { - client, err := f(network, address, opt) + client, err := f(conn, opt) ch <- clientResult{client: client, err: err} }() if opt.ConnectTimeout == 0 { @@ -257,21 +266,13 @@ func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, } select { case <-time.After(opt.ConnectTimeout): - return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout) + return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout) case result := <-ch: return result.client, result.err } } -func dial(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) - if err != nil { - return nil, err - } - return NewClient(conn, opt) -} - // Dial connects to an RPC server at the specified network address func Dial(network, address string, opts ...*Option) (*Client, error) { - return dialTimeout(dial, network, address, opts...) + return dialTimeout(NewClient, network, address, opts...) } diff --git a/gee-rpc/day4-timeout/client_test.go b/gee-rpc/day4-timeout/client_test.go index c5c1d1f..4488455 100644 --- a/gee-rpc/day4-timeout/client_test.go +++ b/gee-rpc/day4-timeout/client_test.go @@ -26,16 +26,19 @@ func startServer(addr chan string) { func TestClient_dialTimeout(t *testing.T) { t.Parallel() - f := func(network, address string, opt *Option) (client *Client, err error) { + l, _ := net.Listen("tcp", ":0") + + f := func(conn net.Conn, opt *Option) (client *Client, err error) { + _ = conn.Close() time.Sleep(time.Second * 2) return nil, nil } t.Run("timeout", func(t *testing.T) { - _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second}) - _assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error") + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error") }) t.Run("0", func(t *testing.T) { - _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0}) + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0}) _assert(err == nil, "0 means no limit") }) } diff --git a/gee-rpc/day5-http-debug/client.go b/gee-rpc/day5-http-debug/client.go index bdc2c0f..b8b57f7 100644 --- a/gee-rpc/day5-http-debug/client.go +++ b/gee-rpc/day5-http-debug/client.go @@ -242,16 +242,26 @@ type clientResult struct { err error } -type dialFunc func(network, address string, opt *Option) (client *Client, err error) +type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error) -func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) { +func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) { opt, err := parseOptions(opts...) if err != nil { return nil, err } + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if client == nil { + _ = conn.Close() + } + }() ch := make(chan clientResult) go func() { - client, err := f(network, address, opt) + client, err := f(conn, opt) ch <- clientResult{client: client, err: err} }() if opt.ConnectTimeout == 0 { @@ -260,30 +270,19 @@ func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, } select { case <-time.After(opt.ConnectTimeout): - return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout) + return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout) case result := <-ch: return result.client, result.err } } -func dial(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) - if err != nil { - return nil, err - } - return NewClient(conn, opt) -} - // Dial connects to an RPC server at the specified network address func Dial(network, address string, opts ...*Option) (*Client, error) { - return dialTimeout(dial, network, address, opts...) + return dialTimeout(NewClient, network, address, opts...) } -func dialHTTP(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) - if err != nil { - return nil, err - } +// NewHTTPClient new a Client instance via HTTP as transport protocol +func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) { _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)) // Require successful HTTP response @@ -295,14 +294,13 @@ func dialHTTP(network, address string, opt *Option) (*Client, error) { if err == nil { err = errors.New("unexpected HTTP response: " + resp.Status) } - _ = conn.Close() return nil, err } // DialHTTP connects to an HTTP RPC server at the specified network address // listening on the default HTTP RPC path. func DialHTTP(network, address string, opts ...*Option) (*Client, error) { - return dialTimeout(dialHTTP, network, address, opts...) + return dialTimeout(NewHTTPClient, network, address, opts...) } // XDial calls different functions to connect to a RPC server diff --git a/gee-rpc/day5-http-debug/client_test.go b/gee-rpc/day5-http-debug/client_test.go index bd46675..3b13cb0 100644 --- a/gee-rpc/day5-http-debug/client_test.go +++ b/gee-rpc/day5-http-debug/client_test.go @@ -28,16 +28,19 @@ func startServer(addr chan string) { func TestClient_dialTimeout(t *testing.T) { t.Parallel() - f := func(network, address string, opt *Option) (client *Client, err error) { + l, _ := net.Listen("tcp", ":0") + + f := func(conn net.Conn, opt *Option) (client *Client, err error) { + _ = conn.Close() time.Sleep(time.Second * 2) return nil, nil } t.Run("timeout", func(t *testing.T) { - _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second}) - _assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error") + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error") }) t.Run("0", func(t *testing.T) { - _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0}) + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0}) _assert(err == nil, "0 means no limit") }) } diff --git a/gee-rpc/day6-load-balance/client.go b/gee-rpc/day6-load-balance/client.go index bdc2c0f..b8b57f7 100644 --- a/gee-rpc/day6-load-balance/client.go +++ b/gee-rpc/day6-load-balance/client.go @@ -242,16 +242,26 @@ type clientResult struct { err error } -type dialFunc func(network, address string, opt *Option) (client *Client, err error) +type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error) -func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) { +func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) { opt, err := parseOptions(opts...) if err != nil { return nil, err } + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if client == nil { + _ = conn.Close() + } + }() ch := make(chan clientResult) go func() { - client, err := f(network, address, opt) + client, err := f(conn, opt) ch <- clientResult{client: client, err: err} }() if opt.ConnectTimeout == 0 { @@ -260,30 +270,19 @@ func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, } select { case <-time.After(opt.ConnectTimeout): - return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout) + return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout) case result := <-ch: return result.client, result.err } } -func dial(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) - if err != nil { - return nil, err - } - return NewClient(conn, opt) -} - // Dial connects to an RPC server at the specified network address func Dial(network, address string, opts ...*Option) (*Client, error) { - return dialTimeout(dial, network, address, opts...) + return dialTimeout(NewClient, network, address, opts...) } -func dialHTTP(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) - if err != nil { - return nil, err - } +// NewHTTPClient new a Client instance via HTTP as transport protocol +func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) { _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)) // Require successful HTTP response @@ -295,14 +294,13 @@ func dialHTTP(network, address string, opt *Option) (*Client, error) { if err == nil { err = errors.New("unexpected HTTP response: " + resp.Status) } - _ = conn.Close() return nil, err } // DialHTTP connects to an HTTP RPC server at the specified network address // listening on the default HTTP RPC path. func DialHTTP(network, address string, opts ...*Option) (*Client, error) { - return dialTimeout(dialHTTP, network, address, opts...) + return dialTimeout(NewHTTPClient, network, address, opts...) } // XDial calls different functions to connect to a RPC server diff --git a/gee-rpc/day6-load-balance/client_test.go b/gee-rpc/day6-load-balance/client_test.go index bd46675..3b13cb0 100644 --- a/gee-rpc/day6-load-balance/client_test.go +++ b/gee-rpc/day6-load-balance/client_test.go @@ -28,16 +28,19 @@ func startServer(addr chan string) { func TestClient_dialTimeout(t *testing.T) { t.Parallel() - f := func(network, address string, opt *Option) (client *Client, err error) { + l, _ := net.Listen("tcp", ":0") + + f := func(conn net.Conn, opt *Option) (client *Client, err error) { + _ = conn.Close() time.Sleep(time.Second * 2) return nil, nil } t.Run("timeout", func(t *testing.T) { - _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second}) - _assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error") + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error") }) t.Run("0", func(t *testing.T) { - _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0}) + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0}) _assert(err == nil, "0 means no limit") }) } diff --git a/gee-rpc/day7-registry/client.go b/gee-rpc/day7-registry/client.go index bdc2c0f..b8b57f7 100644 --- a/gee-rpc/day7-registry/client.go +++ b/gee-rpc/day7-registry/client.go @@ -242,16 +242,26 @@ type clientResult struct { err error } -type dialFunc func(network, address string, opt *Option) (client *Client, err error) +type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error) -func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, error) { +func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) { opt, err := parseOptions(opts...) if err != nil { return nil, err } + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + // close the connection if client is nil + defer func() { + if client == nil { + _ = conn.Close() + } + }() ch := make(chan clientResult) go func() { - client, err := f(network, address, opt) + client, err := f(conn, opt) ch <- clientResult{client: client, err: err} }() if opt.ConnectTimeout == 0 { @@ -260,30 +270,19 @@ func dialTimeout(f dialFunc, network, address string, opts ...*Option) (*Client, } select { case <-time.After(opt.ConnectTimeout): - return nil, fmt.Errorf("rpc client: dial timeout: expect within %s", opt.ConnectTimeout) + return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout) case result := <-ch: return result.client, result.err } } -func dial(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) - if err != nil { - return nil, err - } - return NewClient(conn, opt) -} - // Dial connects to an RPC server at the specified network address func Dial(network, address string, opts ...*Option) (*Client, error) { - return dialTimeout(dial, network, address, opts...) + return dialTimeout(NewClient, network, address, opts...) } -func dialHTTP(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) - if err != nil { - return nil, err - } +// NewHTTPClient new a Client instance via HTTP as transport protocol +func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) { _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath)) // Require successful HTTP response @@ -295,14 +294,13 @@ func dialHTTP(network, address string, opt *Option) (*Client, error) { if err == nil { err = errors.New("unexpected HTTP response: " + resp.Status) } - _ = conn.Close() return nil, err } // DialHTTP connects to an HTTP RPC server at the specified network address // listening on the default HTTP RPC path. func DialHTTP(network, address string, opts ...*Option) (*Client, error) { - return dialTimeout(dialHTTP, network, address, opts...) + return dialTimeout(NewHTTPClient, network, address, opts...) } // XDial calls different functions to connect to a RPC server diff --git a/gee-rpc/day7-registry/client_test.go b/gee-rpc/day7-registry/client_test.go index bd46675..3b13cb0 100644 --- a/gee-rpc/day7-registry/client_test.go +++ b/gee-rpc/day7-registry/client_test.go @@ -28,16 +28,19 @@ func startServer(addr chan string) { func TestClient_dialTimeout(t *testing.T) { t.Parallel() - f := func(network, address string, opt *Option) (client *Client, err error) { + l, _ := net.Listen("tcp", ":0") + + f := func(conn net.Conn, opt *Option) (client *Client, err error) { + _ = conn.Close() time.Sleep(time.Second * 2) return nil, nil } t.Run("timeout", func(t *testing.T) { - _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: time.Second}) - _assert(err != nil && strings.Contains(err.Error(), "dial timeout"), "expect a timeout error") + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second}) + _assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error") }) t.Run("0", func(t *testing.T) { - _, err := dialTimeout(f, "", "", &Option{ConnectTimeout: 0}) + _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0}) _assert(err == nil, "0 means no limit") }) } diff --git a/gee-rpc/doc/geerpc-day2.md b/gee-rpc/doc/geerpc-day2.md index 2d356e3..bf4db48 100644 --- a/gee-rpc/doc/geerpc-day2.md +++ b/gee-rpc/doc/geerpc-day2.md @@ -243,21 +243,23 @@ func parseOptions(opts ...*Option) (*Option, error) { return opt, nil } -func dial(network, address string, opt *Option) (*Client, error) { - conn, err := net.Dial(network, address) +// Dial connects to an RPC server at the specified network address +func Dial(network, address string, opts ...*Option) (client *Client, err error) { + opt, err := parseOptions(opts...) if err != nil { return nil, err } - return NewClient(conn, opt) -} - -// Dial connects to an RPC server at the specified network address -func Dial(network, address string, opts ...*Option) (*Client, error) { - opt, err := parseOptions(opts...) + conn, err := net.Dial(network, address) if err != nil { return nil, err } - return dial(network, address, opt) + // close the connection if client is nil + defer func() { + if client == nil { + _ = conn.Close() + } + }() + return NewClient(conn, opt) } ```