Skip to content

Commit

Permalink
todo realUpstreamProxy
Browse files Browse the repository at this point in the history
  • Loading branch information
lqqyt2423 committed Jul 18, 2023
1 parent ad19f39 commit 3975856
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
23 changes: 16 additions & 7 deletions proxy/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import (
// client connection
type ClientConn struct {
Id uuid.UUID
Conn net.Conn
Conn net.Conn // rawClientConnContextKey is this
Tls bool
}

var rawClientConnContextKey = new(struct{})

func newClientConn(c net.Conn) *ClientConn {
return &ClientConn{
Id: uuid.NewV4(),
Expand Down Expand Up @@ -115,7 +117,7 @@ func (connCtx *ConnContext) initHttpServerConn() {
serverConn := newServerConn()
serverConn.client = &http.Client{
Transport: &http.Transport{
Proxy: clientProxy(connCtx.proxy.Opts.Upstream),
Proxy: connCtx.proxy.realUpstreamProxy(),
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := (&net.Dialer{}).DialContext(ctx, network, addr)
if err != nil {
Expand Down Expand Up @@ -150,13 +152,13 @@ func (connCtx *ConnContext) initHttpServerConn() {
connCtx.ServerConn = serverConn
}

func (connCtx *ConnContext) initServerTcpConn(req *http.Request) error {
func (connCtx *ConnContext) initServerTcpConn(req *http.Request, rawClientConn net.Conn) error {
log.Debugln("in initServerTcpConn")
ServerConn := newServerConn()
connCtx.ServerConn = ServerConn
ServerConn.Address = connCtx.pipeConn.host

plainConn, err := getConnFrom(req.Host, connCtx.proxy.Opts.Upstream)
plainConn, err := getConnFrom(req.Host, connCtx.proxy, rawClientConn)
if err != nil {
return err
}
Expand Down Expand Up @@ -369,9 +371,16 @@ func getProxyConn(proxyUrl *url.URL, address string) (net.Conn, error) {
return conn, nil
}

func getConnFrom(address string, upstream string) (net.Conn, error) {
clientReq := &http.Request{URL: &url.URL{Scheme: "https", Host: address}}
proxyUrl, err := clientProxy(upstream)(clientReq)
func getConnFrom(address string, proxy *Proxy, rawClientConn net.Conn) (net.Conn, error) {
clientReqCtx := context.WithValue(context.Background(), rawClientConnContextKey, rawClientConn)
clientReq, err := http.NewRequestWithContext(clientReqCtx, "CONNECT", "https://"+address, nil)
if err != nil {
return nil, err
}
clientReq.URL.Scheme = "https"
clientReq.URL.Host = address

proxyUrl, err := proxy.realUpstreamProxy()(clientReq)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions proxy/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ func (m *middle) close() error {
return nil
}

func (m *middle) dial(req *http.Request) (net.Conn, error) {
func (m *middle) dial(req *http.Request, rawClientConn net.Conn) (net.Conn, error) {
pipeClientConn, pipeServerConn := newPipes(req)
err := pipeServerConn.connContext.initServerTcpConn(req)
err := pipeServerConn.connContext.initServerTcpConn(req, rawClientConn)
if err != nil {
pipeClientConn.Close()
pipeServerConn.Close()
Expand Down
26 changes: 22 additions & 4 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net"
"net/http"
"net/url"

log "github.com/sirupsen/logrus"
)
Expand All @@ -30,6 +31,7 @@ type Proxy struct {
server *http.Server
interceptor *middle
shouldIntercept func(address string) bool
upstreamProxy func(*http.Request, net.Conn) (*url.URL, error)
}

func NewProxy(opts *Options) (*Proxy, error) {
Expand All @@ -45,7 +47,7 @@ func NewProxy(opts *Options) (*Proxy, error) {

proxy.client = &http.Client{
Transport: &http.Transport{
Proxy: clientProxy(opts.Upstream),
Proxy: proxy.realUpstreamProxy(),
ForceAttemptHTTP2: false, // disable http2
DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true.
TLSClientConfig: &tls.Config{
Expand Down Expand Up @@ -229,7 +231,9 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
for _, addon := range proxy.Addons {
reqBody = addon.StreamRequestModifier(f, reqBody)
}
proxyReq, err := http.NewRequest(f.Request.Method, f.Request.URL.String(), reqBody)

proxyReqCtx := context.WithValue(context.Background(), rawClientConnContextKey, f.ConnContext.ClientConn.Conn)
proxyReq, err := http.NewRequestWithContext(proxyReqCtx, f.Request.Method, f.Request.URL.String(), reqBody)
if err != nil {
log.Error(err)
res.WriteHeader(502)
Expand Down Expand Up @@ -327,10 +331,10 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) {
var err error
if shouldIntercept {
log.Debugf("begin intercept %v", req.Host)
conn, err = proxy.interceptor.dial(req)
conn, err = proxy.interceptor.dial(req, f.ConnContext.ClientConn.Conn)
} else {
log.Debugf("begin transpond %v", req.Host)
conn, err = getConnFrom(req.Host, proxy.Opts.Upstream)
conn, err = getConnFrom(req.Host, proxy, f.ConnContext.ClientConn.Conn)
}
if err != nil {
log.Error(err)
Expand Down Expand Up @@ -382,3 +386,17 @@ func (proxy *Proxy) GetCertificate() x509.Certificate {
func (proxy *Proxy) SetShouldInterceptRule(rule func(address string) bool) {
proxy.shouldIntercept = rule
}

func (proxy *Proxy) SetUpstreamProxy(fn func(*http.Request, net.Conn) (*url.URL, error)) {
proxy.upstreamProxy = fn
}

func (proxy *Proxy) realUpstreamProxy() func(*http.Request) (*url.URL, error) {
return func(req *http.Request) (*url.URL, error) {
if proxy.upstreamProxy != nil {
rawClientConn := req.Context().Value(rawClientConnContextKey).(net.Conn)
return proxy.upstreamProxy(req, rawClientConn)
}
return clientProxy(proxy.Opts.Upstream)(req)
}
}

0 comments on commit 3975856

Please sign in to comment.