Skip to content

Commit

Permalink
build success resp for websocket proxy (#1189)
Browse files Browse the repository at this point in the history
* build success resp for websocket proxy

* fix test
  • Loading branch information
suchen-sci committed Jan 6, 2024
1 parent 784d30b commit 415d633
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
27 changes: 22 additions & 5 deletions pkg/filters/proxies/httpproxy/wspool.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,22 @@ func (sp *WebSocketServerPool) buildFailureResponse(ctx *context.Context, status
ctx.SetOutputResponse(resp)
}

// buildSuccessResponse builds a success response for WebSocket.
// The response is from WebSocket dial process.
// Here we leave the response body alone, and only set the status code and headers.
// The response body will be handled by the WebSocket protocol.
func (sp *WebSocketServerPool) buildSuccessResponse(ctx *context.Context, resp *http.Response) {
newResp, _ := ctx.GetOutputResponse().(*httpprot.Response)
if newResp == nil {
newResp, _ = httpprot.NewResponse(nil)
}
if resp != nil {
newResp.SetStatusCode(resp.StatusCode)
newResp.Std().Header = resp.Header.Clone()
}
ctx.SetOutputResponse(newResp)
}

func buildServerURL(svr *Server, req *httpprot.Request) (string, error) {
u := *req.URL()
u1, err := url.ParseRequestURI(svr.URL)
Expand All @@ -112,10 +128,10 @@ func buildServerURL(svr *Server, req *httpprot.Request) (string, error) {
return u.String(), nil
}

func (sp *WebSocketServerPool) dialServer(svr *Server, req *httpprot.Request) (*websocket.Conn, error) {
func (sp *WebSocketServerPool) dialServer(svr *Server, req *httpprot.Request) (*websocket.Conn, *http.Response, error) {
u, err := buildServerURL(svr, req)
if err != nil {
return nil, err
return nil, nil, err
}

opts := &websocket.DialOptions{
Expand Down Expand Up @@ -161,11 +177,11 @@ func (sp *WebSocketServerPool) dialServer(svr *Server, req *httpprot.Request) (*
opts.HTTPHeader.Set(xForwardedProto, "https")
}

conn, _, err := websocket.Dial(stdctx.Background(), u, opts)
conn, resp, err := websocket.Dial(stdctx.Background(), u, opts)
if err == nil && (sp.spec.ServerMaxMsgSize > 0 || sp.spec.ServerMaxMsgSize == -1) {
conn.SetReadLimit(sp.spec.ServerMaxMsgSize)
}
return conn, err
return conn, resp, err
}

func (sp *WebSocketServerPool) handle(ctx *context.Context) (result string) {
Expand Down Expand Up @@ -214,7 +230,7 @@ func (sp *WebSocketServerPool) handle(ctx *context.Context) (result string) {
clntConn.SetReadLimit(sp.spec.ClientMaxMsgSize)
}

svrConn, err := sp.dialServer(svr, req)
svrConn, resp, err := sp.dialServer(svr, req)
if err != nil {
logger.Errorf("%s: dial to %s failed: %v", sp.Name, svr.URL, err)
clntConn.Close(websocket.StatusGoingAway, "")
Expand Down Expand Up @@ -296,6 +312,7 @@ func (sp *WebSocketServerPool) handle(ctx *context.Context) (result string) {
wg.Wait()
close(stop)

sp.buildSuccessResponse(ctx, resp)
metric.StatusCode = http.StatusSwitchingProtocols
ctx.SetData("HTTP_METRIC", metric)
return
Expand Down
16 changes: 8 additions & 8 deletions pkg/filters/proxies/httpproxy/wspool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,36 +39,36 @@ func TestDialServer(t *testing.T) {
req, _ := httpprot.NewRequest(stdr)

svr.URL = "####"
_, err := sp.dialServer(svr, req)
_, _, err := sp.dialServer(svr, req)
assert.Error(err)

svr.URL = "http:https://127.0.0.1:9999"
_, err = sp.dialServer(svr, req)
_, _, err = sp.dialServer(svr, req)
assert.Error(err)

svr.URL = "https://127.0.0.1:9999"
_, err = sp.dialServer(svr, req)
_, _, err = sp.dialServer(svr, req)
assert.Error(err)

svr.URL = "tcp:https://127.0.0.1:9999"
_, err = sp.dialServer(svr, req)
_, _, err = sp.dialServer(svr, req)
assert.Error(err)

svr.URL = "ws:https://127.0.0.1:9999"
_, err = sp.dialServer(svr, req)
_, _, err = sp.dialServer(svr, req)
assert.Error(err)

stdr.Header.Add("Origin", "$#@#@$#$#")
_, err = sp.dialServer(svr, req)
_, _, err = sp.dialServer(svr, req)
assert.Error(err)

stdr.Header.Set("Origin", "http:https://127.0.0.1/hello")
stdr.RemoteAddr = "127.0.0.1:8080"
stdr.TLS = &tls.ConnectionState{}
_, err = sp.dialServer(svr, req)
_, _, err = sp.dialServer(svr, req)
assert.Error(err)

stdr.Header.Set("X-Forwarded-For", "192.168.1.1")
_, err = sp.dialServer(svr, req)
_, _, err = sp.dialServer(svr, req)
assert.Error(err)
}

0 comments on commit 415d633

Please sign in to comment.