Skip to content

Commit

Permalink
WebSocket support close #99 (#103)
Browse files Browse the repository at this point in the history
* [ws] add WebSocketServer

* [ws] close megaese#99

* [wc] fix make test failed

* [ws] correct log format

* [ws] refactor TLS validate and update comment

Co-authored-by: benjaminwu <[email protected]>
  • Loading branch information
benja-wu and benjaminwu committed Jul 20, 2021
1 parent 573eb49 commit a36ff16
Show file tree
Hide file tree
Showing 6 changed files with 483 additions and 0 deletions.
5 changes: 5 additions & 0 deletions example/config/websocket-server-example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
kind: WebSocketServer
name: websocket-server-example
port: 10081
https: false
backend: ws:https://localhost:9880
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
github.com/ghodss/yaml v1.0.0
github.com/go-chi/chi/v5 v5.0.3
github.com/go-zookeeper/zk v1.0.2
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/consul/api v1.8.1
github.com/hashicorp/golang-lru v0.5.4
github.com/json-iterator/go v1.1.11
Expand Down
299 changes: 299 additions & 0 deletions pkg/object/websocketserver/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
/*
* Copyright (c) 2017, MegaEase
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http:https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package websocketserver

import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"

"github.com/gorilla/websocket"

"github.com/megaease/easegress/pkg/logger"
"github.com/megaease/easegress/pkg/supervisor"
)

var (
// defaultUpgrader specifies the parameters for upgrading an HTTP
// connection to a WebSocket connection.
defaultUpgrader = &websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

// defaultDialer is a dialer with all fields set to the default zero values.
defaultDialer = websocket.DefaultDialer

// defaultInterval is the default interval for polling websocket client and server which is
// 200ms right now.
defaultInterval = 200 * time.Millisecond
)

// Proxy is a handler that takes an incoming WebSocket
// connection and proxies it to the backend server.
type Proxy struct {
// server is the HTTPServer
server http.Server
superSpec *supervisor.Spec

// backendURL URL is the URL of target websocket server.
backendURL *url.URL

// upgrader specifies the parameters for upgrading a incoming HTTP
// connection to a WebSocket connection.
upgrader *websocket.Upgrader

// dialer contains options for connecting to the backend WebSocket server.
dialer *websocket.Dialer

// done is the channel for shutdowning this proxy.
done chan struct{}
}

// NewProxy returns a new Websocket proxy.
func newProxy(superSpec *supervisor.Spec) *Proxy {
proxy := &Proxy{
superSpec: superSpec,
done: make(chan struct{}),
}
go proxy.run()
return proxy
}

// buildRequestURL builds an URL with backend in spec and original HTTP request.
func (p *Proxy) buildRequestURL(r *http.Request) *url.URL {
u := *p.backendURL
u.Fragment = r.URL.Fragment
u.Path = r.URL.Path
u.RawQuery = r.URL.RawQuery
return &u
}

// passMsg passes websocket message from src to dst.
func (p *Proxy) passMsg(src, dst *websocket.Conn, errc chan error, stop chan struct{}) {
handle := func() bool {
msgType, msg, err := src.ReadMessage()
if err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
if e, ok := err.(*websocket.CloseError); ok {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
}
}
dst.WriteMessage(websocket.CloseMessage, m)
errc <- err
return false
}
err = dst.WriteMessage(msgType, msg)
if err != nil {
errc <- err
return false
}
return true
}

for {
select {
// this request handling is stopped due to some error or websocketserver shutdown.
case <-stop:
return
case <-time.After(defaultInterval):
if !handle() {
return
}
}
}
}

// run runs the websocket proxy.
func (p *Proxy) run() {
spec := p.superSpec.ObjectSpec().(*Spec)
backendURL, err := url.Parse(spec.Backend)
if err != nil {
logger.Errorf("BUG: %s get invalid websocketserver backend URL: %s",
p.superSpec.Name(), spec.Backend)
return
}

p.backendURL = backendURL
dialer := defaultDialer
if strings.HasPrefix(spec.Backend, "wss") {
tlsConfig, err := spec.wssTLSConfig()
if err != nil {
logger.Errorf("%s gen websocketserver backend tls failed: %v, spec :%#v",
p.superSpec.Name(), spec)
return
}
dialer.TLSClientConfig = tlsConfig
}
p.dialer = dialer
p.upgrader = defaultUpgrader

http.HandleFunc("/", p.handle)
addr := fmt.Sprintf(":%d", spec.Port)
svr := &http.Server{
Addr: addr,
Handler: nil,
}

if spec.HTTPS {
tlsConfig, err := spec.tlsConfig()
if err != nil {
logger.Errorf("%s gen websocketserver's httpserver tlsConfig: %#v, failed: %v",
p.superSpec.Name(), spec, err)
}
svr.TLSConfig = tlsConfig
}

if err := svr.ListenAndServe(); err != nil {
logger.Errorf("%s websocketserver ListenAndServe failed: %v", p.superSpec.Name(), err)
}
}

// copyHeader copies headers from the incoming request to the dialer and forward them to
// the destination.
func (p *Proxy) copyHeader(req *http.Request) http.Header {

requestHeader := http.Header{}
if origin := req.Header.Get("Origin"); origin != "" {
requestHeader.Add("Origin", origin)
}
for _, prot := range req.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] {
requestHeader.Add("Sec-WebSocket-Protocol", prot)
}
for _, cookie := range req.Header[http.CanonicalHeaderKey("Cookie")] {
requestHeader.Add("Cookie", cookie)
}
if req.Host != "" {
requestHeader.Set("Host", req.Host)
}

if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if prior, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
requestHeader.Set("X-Forwarded-For", clientIP)
}

requestHeader.Set("X-Forwarded-Proto", "http")
if req.TLS != nil {
requestHeader.Set("X-Forwarded-Proto", "https")
}

return requestHeader
}

// upgradeRspHeader passes only selected headers as return.
func (p *Proxy) upgradeRspHeader(resp *http.Response) http.Header {
upgradeHeader := http.Header{}
if hdr := resp.Header.Get("Sec-Websocket-Protocol"); hdr != "" {
upgradeHeader.Set("Sec-Websocket-Protocol", hdr)
}
if hdr := resp.Header.Get("Set-Cookie"); hdr != "" {
upgradeHeader.Set("Set-Cookie", hdr)
}
return upgradeHeader
}

// handle implements the http.Handler that proxies WebSocket connections.
func (p *Proxy) handle(rw http.ResponseWriter, req *http.Request) {
connBackend, resp, err := p.dialer.Dial(p.buildRequestURL(req).String(), p.copyHeader(req))
if err != nil {
logger.Errorf("%s dials %s failed: %v", p.superSpec.Name(), p.backendURL.String(), err)
if resp != nil {
// Handle WebSocket handshake failed scenario.
// Should send back a non-nil *http.Response for callers to handle
// `redirects`, `authentication` operations and so on.
if err := copyResponse(rw, resp); err != nil {
logger.Errorf("%s writes response failed at remote backend: %s handshake: %v",
p.superSpec.Name(), p.backendURL.String(), err)
}
} else {
http.Error(rw, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
}
return
}
defer connBackend.Close()

// Upgrade the incoming request to a WebSocket connection(Protocol Switching).
// Also pass the header from the Dial handshake.
connClient, err := p.upgrader.Upgrade(rw, req, p.upgradeRspHeader(resp))
if err != nil {
logger.Errorf("%s upgrades req: %#v failed: %s", p.superSpec.Name(), err)
return
}
defer connClient.Close()

errClient := make(chan error, 1)
errBackend := make(chan error, 1)
stop := make(chan struct{})

defer close(stop)

// pass msg from backend to client via WebSocket protocol.
go p.passMsg(connBackend, connClient, errBackend, stop)
// pass msg from client to backend via WebSocket protocol.
go p.passMsg(connClient, connBackend, errClient, stop)

var errMsg string
select {
case err = <-errBackend:
errMsg = "%s passes msg from backend: %s to client failed: %v"
case err = <-errClient:
errMsg = "%s passes msg client to backend: %s failed: %v"
case <-p.done:
logger.Debugf("shutdown websocketserver in request handling")
return
}

if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
logger.Errorf(errMsg, p.superSpec.Name(), p.backendURL.String(), err)
}
// other error type is expected, not need to log
}

// Close close websocket proxy.
func (p *Proxy) Close() {
close(p.done)

ctx, cancelFunc := context.WithTimeout(context.Background(), 30*time.Second)
defer cancelFunc()
err := p.server.Shutdown(ctx)
if err != nil {
logger.Warnf("%s shutdowns http server failed: %v",
p.superSpec.Name(), err)
}
}

func copyResponse(rw http.ResponseWriter, resp *http.Response) error {
for k, vv := range resp.Header {
for _, v := range vv {
rw.Header().Add(k, v)
}
}
rw.WriteHeader(resp.StatusCode)
defer resp.Body.Close()

_, err := io.Copy(rw, resp.Body)
return err
}
Loading

0 comments on commit a36ff16

Please sign in to comment.