diff --git a/internal/grpcutil/target.go b/internal/grpcutil/target.go index 3e1b22f5a8c0..baa3479f12ca 100644 --- a/internal/grpcutil/target.go +++ b/internal/grpcutil/target.go @@ -48,9 +48,10 @@ func ParseTarget(target string, skipUnixColonParsing bool) (ret resolver.Target) ret.Scheme, ret.Endpoint, ok = split2(target, "://") if !ok { if strings.HasPrefix(target, "unix:") && !skipUnixColonParsing { - // Handle the "unix:[path]" case, because splitting on :// only - // handles the "unix://[/absolute/path]" case. Only handle if the - // dialer is nil, to avoid a behavior change with custom dialers. + // Handle the "unix:[local/path]" and "unix:[/absolute/path]" cases, + // because splitting on :// only handles the + // "unix://[/absolute/path]" case. Only handle if the dialer is nil, + // to avoid a behavior change with custom dialers. return resolver.Target{Scheme: "unix", Endpoint: target[len("unix:"):]} } return resolver.Target{Endpoint: target} @@ -61,7 +62,7 @@ func ParseTarget(target string, skipUnixColonParsing bool) (ret resolver.Target) } if ret.Scheme == "unix" { // Add the "/" back in the unix case, so the unix resolver receives the - // actual endpoint. + // actual endpoint in the "unix://[/absolute/path]" case. ret.Endpoint = "/" + ret.Endpoint } return ret diff --git a/internal/grpcutil/target_test.go b/internal/grpcutil/target_test.go index 562390bfe381..6083e84d75ae 100644 --- a/internal/grpcutil/target_test.go +++ b/internal/grpcutil/target_test.go @@ -70,6 +70,7 @@ func TestParseTargetString(t *testing.T) { // If we can only parse part of the target. {targetStr: "://", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "://"}}, {targetStr: "unix://domain", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix://domain"}}, + {targetStr: "unix://a/b/c", want: resolver.Target{Scheme: "unix", Authority: "a", Endpoint: "/b/c"}}, {targetStr: "a:b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:b"}}, {targetStr: "a/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a/b"}}, {targetStr: "a:/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:/b"}}, @@ -77,10 +78,13 @@ func TestParseTargetString(t *testing.T) { {targetStr: "a://b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a://b"}}, // Unix cases without custom dialer. - // unix:[local_path] and unix:[/absolute] have different behaviors with - // a custom dialer, to prevent behavior changes with custom dialers. - {targetStr: "unix:domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:domain"}}, - {targetStr: "unix:/domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/domain"}}, + // unix:[local_path], unix:[/absolute], and unix://[/absolute] have different + // behaviors with a custom dialer, to prevent behavior changes with custom dialers. + {targetStr: "unix:a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:a/b/c"}}, + {targetStr: "unix:/a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/a/b/c"}}, + {targetStr: "unix:///a/b/c", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/a/b/c"}}, + + {targetStr: "passthrough:///unix:///a/b/c", want: resolver.Target{Scheme: "passthrough", Authority: "", Endpoint: "unix:///a/b/c"}}, } { got := ParseTarget(test.targetStr, false) if got != test.want { diff --git a/internal/resolver/unix/unix.go b/internal/resolver/unix/unix.go index d046e50613d0..8e78d47197c8 100644 --- a/internal/resolver/unix/unix.go +++ b/internal/resolver/unix/unix.go @@ -20,6 +20,8 @@ package unix import ( + "fmt" + "google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/resolver" ) @@ -29,6 +31,9 @@ const scheme = "unix" type builder struct{} func (*builder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + if target.Authority != "" { + return nil, fmt.Errorf("invalid (non-empty) authority: %v", target.Authority) + } cc.UpdateState(resolver.State{Addresses: []resolver.Address{networktype.Set(resolver.Address{Addr: target.Endpoint}, "unix")}}) return &nopResolver{}, nil } diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 4778ed16252b..4c9f9740ec71 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -140,17 +140,26 @@ type http2Client struct { } func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) { + address := addr.Addr + networkType, ok := networktype.Get(addr) if fn != nil { - return fn(ctx, addr.Addr) + if networkType == "unix" { + // For backward compatibility, if the user dialed "unix:///path", + // the passthrough resolver would be used and the user's custom + // dialer would see "unix:///path". Since the unix resolver is used + // and the address is now "/path", prepend "unix://" so the user's + // custom dialer sees the same address. + return fn(ctx, "unix://"+address) + } + return fn(ctx, address) } - networkType := "tcp" - if n, ok := networktype.Get(addr); ok { - networkType = n + if !ok { + networkType, address = parseDialTarget(address) } if networkType == "tcp" && useProxy { - return proxyDial(ctx, addr.Addr, grpcUA) + return proxyDial(ctx, address, grpcUA) } - return (&net.Dialer{}).DialContext(ctx, networkType, addr.Addr) + return (&net.Dialer{}).DialContext(ctx, networkType, address) } func isTemporary(err error) bool { diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index 4d15afbf73f1..7e41d1183f93 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -27,6 +27,7 @@ import ( "math" "net" "net/http" + "net/url" "strconv" "strings" "time" @@ -598,3 +599,31 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderList f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil) return f } + +// parseDialTarget returns the network and address to pass to dialer. +func parseDialTarget(target string) (string, string) { + net := "tcp" + m1 := strings.Index(target, ":") + m2 := strings.Index(target, ":/") + // handle unix:addr which will fail with url.Parse + if m1 >= 0 && m2 < 0 { + if n := target[0:m1]; n == "unix" { + return n, target[m1+1:] + } + } + if m2 >= 0 { + t, err := url.Parse(target) + if err != nil { + return net, target + } + scheme := t.Scheme + addr := t.Path + if scheme == "unix" { + if addr == "" { + addr = t.Host + } + return scheme, addr + } + } + return net, target +} diff --git a/internal/transport/http_util_test.go b/internal/transport/http_util_test.go index 85a083f6c8a8..2205050acea0 100644 --- a/internal/transport/http_util_test.go +++ b/internal/transport/http_util_test.go @@ -250,3 +250,32 @@ func (s) TestDecodeHeaderH2ErrCode(t *testing.T) { }) } } + +func (s) TestParseDialTarget(t *testing.T) { + for _, test := range []struct { + target, wantNet, wantAddr string + }{ + {"unix:a", "unix", "a"}, + {"unix:a/b/c", "unix", "a/b/c"}, + {"unix:/a", "unix", "/a"}, + {"unix:/a/b/c", "unix", "/a/b/c"}, + {"unix://a", "unix", "a"}, + {"unix://a/b/c", "unix", "/b/c"}, + {"unix:///a", "unix", "/a"}, + {"unix:///a/b/c", "unix", "/a/b/c"}, + {"unix:etcd:0", "unix", "etcd:0"}, + {"unix:///tmp/unix-3", "unix", "/tmp/unix-3"}, + {"unix://domain", "unix", "domain"}, + {"unix://etcd:0", "unix", "etcd:0"}, + {"unix:///etcd:0", "unix", "/etcd:0"}, + {"passthrough://unix://domain", "tcp", "passthrough://unix://domain"}, + {"https://google.com:443", "tcp", "https://google.com:443"}, + {"dns:///google.com", "tcp", "dns:///google.com"}, + {"/unix/socket/address", "tcp", "/unix/socket/address"}, + } { + gotNet, gotAddr := parseDialTarget(test.target) + if gotNet != test.wantNet || gotAddr != test.wantAddr { + t.Errorf("parseDialTarget(%q) = %s, %s want %s, %s", test.target, gotNet, gotAddr, test.wantNet, test.wantAddr) + } + } +} diff --git a/test/authority_test.go b/test/authority_test.go index e8307fee225d..e6599b2fde04 100644 --- a/test/authority_test.go +++ b/test/authority_test.go @@ -80,35 +80,46 @@ func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer } } +type authorityTest struct { + name string + address string + target string + authority string + dialTargetWant string +} + +var authorityTests = []authorityTest{ + { + name: "UnixRelative", + address: "sock.sock", + target: "unix:sock.sock", + authority: "localhost", + }, + { + name: "UnixAbsolute", + address: "/tmp/sock.sock", + target: "unix:/tmp/sock.sock", + authority: "localhost", + }, + { + name: "UnixAbsoluteAlternate", + address: "/tmp/sock.sock", + target: "unix:///tmp/sock.sock", + authority: "localhost", + }, + { + name: "UnixPassthrough", + address: "/tmp/sock.sock", + target: "passthrough:///unix:///tmp/sock.sock", + authority: "unix:///tmp/sock.sock", + dialTargetWant: "unix:///tmp/sock.sock", + }, +} + // TestUnix does end to end tests with the various supported unix target -// formats, ensuring that the authority is set to localhost in every case. +// formats, ensuring that the authority is set as expected. func (s) TestUnix(t *testing.T) { - tests := []struct { - name string - address string - target string - authority string - }{ - { - name: "UnixRelative", - address: "sock.sock", - target: "unix:sock.sock", - authority: "localhost", - }, - { - name: "UnixAbsolute", - address: "/tmp/sock.sock", - target: "unix:/tmp/sock.sock", - authority: "localhost", - }, - { - name: "UnixAbsoluteAlternate", - address: "/tmp/sock.sock", - target: "unix:///tmp/sock.sock", - authority: "localhost", - }, - } - for _, test := range tests { + for _, test := range authorityTests { t.Run(test.name, func(t *testing.T) { runUnixTest(t, test.address, test.target, test.authority, nil) }) @@ -119,30 +130,14 @@ func (s) TestUnix(t *testing.T) { // formats, ensuring that the target sent to the dialer does NOT have the // "unix:" prefix stripped. func (s) TestUnixCustomDialer(t *testing.T) { - tests := []struct { - name string - address string - target string - authority string - }{ - { - name: "UnixRelative", - address: "sock.sock", - target: "unix:sock.sock", - authority: "localhost", - }, - { - name: "UnixAbsolute", - address: "/tmp/sock.sock", - target: "unix:/tmp/sock.sock", - authority: "localhost", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { + for _, test := range authorityTests { + t.Run(test.name+"WithDialer", func(t *testing.T) { + if test.dialTargetWant == "" { + test.dialTargetWant = test.target + } dialer := func(ctx context.Context, address string) (net.Conn, error) { - if address != test.target { - return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.target, address) + if address != test.dialTargetWant { + return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.dialTargetWant, address) } address = address[len("unix:"):] return (&net.Dialer{}).DialContext(ctx, "unix", address) @@ -152,6 +147,8 @@ func (s) TestUnixCustomDialer(t *testing.T) { } } +// TestColonPortAuthority does an end to end test with the target for grpc.Dial +// being ":[port]". Ensures authority is "localhost:[port]". func (s) TestColonPortAuthority(t *testing.T) { expectedAuthority := "" var authorityMu sync.Mutex