Skip to content

Commit

Permalink
Use the websocket protocol header, verify selected protocol
Browse files Browse the repository at this point in the history
Kubernetes-commit: b394aac4ce36457bd37459a58b4c3536d2f43d86
  • Loading branch information
liggitt authored and k8s-publishing-bot committed Feb 27, 2024
1 parent 2f00261 commit e8b5ff9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
27 changes: 23 additions & 4 deletions transport/websocket/roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ package websocket

import (
"crypto/tls"
"errors"
"fmt"
"net/http"
"net/url"

gwebsocket "github.com/gorilla/websocket"

"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
restclient "k8s.io/client-go/rest"
"k8s.io/client-go/transport"
Expand Down Expand Up @@ -88,8 +90,8 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
}()

// set the protocol version directly on the dialer from the header
protocolVersions := request.Header[httpstream.HeaderProtocolVersion]
delete(request.Header, httpstream.HeaderProtocolVersion)
protocolVersions := request.Header[wsstream.WebSocketProtocolHeader]
delete(request.Header, wsstream.WebSocketProtocolHeader)

dialer := gwebsocket.Dialer{
Proxy: rt.Proxier,
Expand All @@ -108,7 +110,23 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
}
wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header)
if err != nil {
return nil, &httpstream.UpgradeFailureError{Cause: err}
if errors.Is(err, gwebsocket.ErrBadHandshake) {
return nil, &httpstream.UpgradeFailureError{Cause: err}
}
return nil, err
}

// Ensure we got back a protocol we understand
foundProtocol := false
for _, protocolVersion := range protocolVersions {
if protocolVersion == wsConn.Subprotocol() {
foundProtocol = true
break
}
}
if !foundProtocol {
wsConn.Close() // nolint:errcheck
return nil, &httpstream.UpgradeFailureError{Cause: fmt.Errorf("invalid protocol, expected one of %q, got %q", protocolVersions, wsConn.Subprotocol())}
}

rt.Conn = wsConn
Expand Down Expand Up @@ -149,7 +167,8 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, ConnectionHo
// a WebSocket connection. Upon success, it returns the negotiated connection.
// The round tripper rt must use the WebSocket round tripper wsRt - see RoundTripperFor.
func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.Request, protocols ...string) (*gwebsocket.Conn, error) {
req.Header[httpstream.HeaderProtocolVersion] = protocols
// Plumb protocols to RoundTripper#RoundTrip
req.Header[wsstream.WebSocketProtocolHeader] = protocols
resp, err := rt.RoundTrip(req)
if err != nil {
return nil, err
Expand Down
5 changes: 3 additions & 2 deletions transport/websocket/roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
require.NoError(t, err)
requestedProtocol := remotecommand.StreamProtocolV5Name
req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol}
req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
_, err = rt.RoundTrip(req)
require.NoError(t, err)
// WebSocket Connection is stored in websocket RoundTripper.
Expand Down Expand Up @@ -83,11 +83,12 @@ func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) {
require.NoError(t, err)
// Requested subprotocol version 1 is not supported by test websocket server.
requestedProtocol := remotecommand.StreamProtocolV1Name
req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol}
req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
_, err = rt.RoundTrip(req)
// Ensure a "bad handshake" error is returned, since requested protocol is not supported.
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "bad handshake"))
assert.True(t, httpstream.IsUpgradeFailure(err))
}

func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {
Expand Down

0 comments on commit e8b5ff9

Please sign in to comment.