diff --git a/pkg/filters/proxies/httpproxy/wspool.go b/pkg/filters/proxies/httpproxy/wspool.go index 2e65ee5113..ca216495a6 100644 --- a/pkg/filters/proxies/httpproxy/wspool.go +++ b/pkg/filters/proxies/httpproxy/wspool.go @@ -89,6 +89,22 @@ func (sp *WebSocketServerPool) buildFailureResponse(ctx *context.Context, status ctx.SetOutputResponse(resp) } +// buildSuccessResponse builds a success response for WebSocket. +// The response is from WebSocket dial process. +// Here we leave the response body alone, and only set the status code and headers. +// The response body will be handled by the WebSocket protocol. +func (sp *WebSocketServerPool) buildSuccessResponse(ctx *context.Context, resp *http.Response) { + newResp, _ := ctx.GetOutputResponse().(*httpprot.Response) + if newResp == nil { + newResp, _ = httpprot.NewResponse(nil) + } + if resp != nil { + newResp.SetStatusCode(resp.StatusCode) + newResp.Std().Header = resp.Header.Clone() + } + ctx.SetOutputResponse(newResp) +} + func buildServerURL(svr *Server, req *httpprot.Request) (string, error) { u := *req.URL() u1, err := url.ParseRequestURI(svr.URL) @@ -112,10 +128,10 @@ func buildServerURL(svr *Server, req *httpprot.Request) (string, error) { return u.String(), nil } -func (sp *WebSocketServerPool) dialServer(svr *Server, req *httpprot.Request) (*websocket.Conn, error) { +func (sp *WebSocketServerPool) dialServer(svr *Server, req *httpprot.Request) (*websocket.Conn, *http.Response, error) { u, err := buildServerURL(svr, req) if err != nil { - return nil, err + return nil, nil, err } opts := &websocket.DialOptions{ @@ -161,11 +177,11 @@ func (sp *WebSocketServerPool) dialServer(svr *Server, req *httpprot.Request) (* opts.HTTPHeader.Set(xForwardedProto, "https") } - conn, _, err := websocket.Dial(stdctx.Background(), u, opts) + conn, resp, err := websocket.Dial(stdctx.Background(), u, opts) if err == nil && (sp.spec.ServerMaxMsgSize > 0 || sp.spec.ServerMaxMsgSize == -1) { conn.SetReadLimit(sp.spec.ServerMaxMsgSize) } - return conn, err + return conn, resp, err } func (sp *WebSocketServerPool) handle(ctx *context.Context) (result string) { @@ -214,7 +230,7 @@ func (sp *WebSocketServerPool) handle(ctx *context.Context) (result string) { clntConn.SetReadLimit(sp.spec.ClientMaxMsgSize) } - svrConn, err := sp.dialServer(svr, req) + svrConn, resp, err := sp.dialServer(svr, req) if err != nil { logger.Errorf("%s: dial to %s failed: %v", sp.Name, svr.URL, err) clntConn.Close(websocket.StatusGoingAway, "") @@ -296,6 +312,7 @@ func (sp *WebSocketServerPool) handle(ctx *context.Context) (result string) { wg.Wait() close(stop) + sp.buildSuccessResponse(ctx, resp) metric.StatusCode = http.StatusSwitchingProtocols ctx.SetData("HTTP_METRIC", metric) return diff --git a/pkg/filters/proxies/httpproxy/wspool_test.go b/pkg/filters/proxies/httpproxy/wspool_test.go index 5053e8ce3f..05e2dc79ac 100644 --- a/pkg/filters/proxies/httpproxy/wspool_test.go +++ b/pkg/filters/proxies/httpproxy/wspool_test.go @@ -39,36 +39,36 @@ func TestDialServer(t *testing.T) { req, _ := httpprot.NewRequest(stdr) svr.URL = "####" - _, err := sp.dialServer(svr, req) + _, _, err := sp.dialServer(svr, req) assert.Error(err) svr.URL = "http://127.0.0.1:9999" - _, err = sp.dialServer(svr, req) + _, _, err = sp.dialServer(svr, req) assert.Error(err) svr.URL = "https://127.0.0.1:9999" - _, err = sp.dialServer(svr, req) + _, _, err = sp.dialServer(svr, req) assert.Error(err) svr.URL = "tcp://127.0.0.1:9999" - _, err = sp.dialServer(svr, req) + _, _, err = sp.dialServer(svr, req) assert.Error(err) svr.URL = "ws://127.0.0.1:9999" - _, err = sp.dialServer(svr, req) + _, _, err = sp.dialServer(svr, req) assert.Error(err) stdr.Header.Add("Origin", "$#@#@$#$#") - _, err = sp.dialServer(svr, req) + _, _, err = sp.dialServer(svr, req) assert.Error(err) stdr.Header.Set("Origin", "http://127.0.0.1/hello") stdr.RemoteAddr = "127.0.0.1:8080" stdr.TLS = &tls.ConnectionState{} - _, err = sp.dialServer(svr, req) + _, _, err = sp.dialServer(svr, req) assert.Error(err) stdr.Header.Set("X-Forwarded-For", "192.168.1.1") - _, err = sp.dialServer(svr, req) + _, _, err = sp.dialServer(svr, req) assert.Error(err) }