Skip to content

Commit

Permalink
Fix TCP-TLS/HTTPS routing precedence
Browse files Browse the repository at this point in the history
Co-authored-by: Mathieu Lonjaret <[email protected]>
  • Loading branch information
rtribotte and mpl committed May 19, 2022
1 parent ede2be1 commit ac4086d
Show file tree
Hide file tree
Showing 4 changed files with 1,031 additions and 36 deletions.
52 changes: 31 additions & 21 deletions pkg/muxer/tcp/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,32 +95,21 @@ func NewMuxer() (*Muxer, error) {
return &Muxer{parser: parser}, nil
}

// Match returns the handler of the first route matching the connection metadata.
func (m Muxer) Match(meta ConnData) tcp.Handler {
// Match returns the handler of the first route matching the connection metadata,
// and whether the match is exactly from the rule HostSNI(*).
func (m Muxer) Match(meta ConnData) (tcp.Handler, bool) {
for _, route := range m.routes {
if route.matchers.match(meta) {
return route.handler
return route.handler, route.catchAll
}
}

return nil
return nil, false
}

// AddRoute adds a new route, associated to the given handler, at the given
// priority, to the muxer.
func (m *Muxer) AddRoute(rule string, priority int, handler tcp.Handler) error {
// Special case for when the catchAll fallback is present.
// When no user-defined priority is found, the lowest computable priority minus one is used,
// in order to make the fallback the last to be evaluated.
if priority == 0 && rule == "HostSNI(`*`)" {
priority = -1
}

// Default value, which means the user has not set it, so we'll compute it.
if priority == 0 {
priority = len(rule)
}

parse, err := m.parser.Parse(rule)
if err != nil {
return fmt.Errorf("error while parsing rule %s: %w", rule, err)
Expand All @@ -131,16 +120,36 @@ func (m *Muxer) AddRoute(rule string, priority int, handler tcp.Handler) error {
return fmt.Errorf("error while parsing rule %s", rule)
}

ruleTree := buildTree()

var matchers matchersTree
err = addRule(&matchers, buildTree())
err = addRule(&matchers, ruleTree)
if err != nil {
return err
}

var catchAll bool
if ruleTree.RuleLeft == nil && ruleTree.RuleRight == nil && len(ruleTree.Value) == 1 {
catchAll = ruleTree.Value[0] == "*" && strings.EqualFold(ruleTree.Matcher, "HostSNI")
}

// Special case for when the catchAll fallback is present.
// When no user-defined priority is found, the lowest computable priority minus one is used,
// in order to make the fallback the last to be evaluated.
if priority == 0 && catchAll {
priority = -1
}

// Default value, which means the user has not set it, so we'll compute it.
if priority == 0 {
priority = len(rule)
}

newRoute := &route{
handler: handler,
priority: priority,
matchers: matchers,
catchAll: catchAll,
priority: priority,
}
m.routes = append(m.routes, newRoute)

Expand Down Expand Up @@ -207,9 +216,10 @@ type route struct {
matchers matchersTree
// handler responsible for handling the route.
handler tcp.Handler

// Used to disambiguate between two (or more) rules that would both match for a
// given request.
// catchAll indicates whether the route rule has exactly the catchAll value (HostSNI(`*`)).
catchAll bool
// priority is used to disambiguate between two (or more) rules that would
// all match for a given request.
// Computed from the matching rule length, if not user-set.
priority int
}
Expand Down
52 changes: 50 additions & 2 deletions pkg/muxer/tcp/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ func Test_addTCPRoute(t *testing.T) {
connData, err := NewConnData(test.serverName, conn)
require.NoError(t, err)

matchingHandler := router.Match(connData)
matchingHandler, _ := router.Match(connData)
if test.matchErr {
require.Nil(t, matchingHandler)
return
Expand Down Expand Up @@ -568,6 +568,54 @@ func TestParseHostSNI(t *testing.T) {
}
}

func Test_HostSNICatchAll(t *testing.T) {
testCases := []struct {
desc string
rule string
isCatchAll bool
}{
{
desc: "HostSNI(`foobar`) is not catchAll",
rule: "HostSNI(`foobar`)",
},
{
desc: "HostSNI(`*`) is catchAll",
rule: "HostSNI(`*`)",
isCatchAll: true,
},
{
desc: "HOSTSNI(`*`) is catchAll",
rule: "HOSTSNI(`*`)",
isCatchAll: true,
},
{
desc: `HostSNI("*") is catchAll`,
rule: `HostSNI("*")`,
isCatchAll: true,
},
}

for _, test := range testCases {
test := test

t.Run(test.desc, func(t *testing.T) {
t.Parallel()

muxer, err := NewMuxer()
require.NoError(t, err)

err = muxer.AddRoute(test.rule, 0, tcp.HandlerFunc(func(conn tcp.WriteCloser) {}))
require.NoError(t, err)

handler, catchAll := muxer.Match(ConnData{
serverName: "foobar",
})
require.NotNil(t, handler)
assert.Equal(t, test.isCatchAll, catchAll)
})
}
}

func Test_HostSNI(t *testing.T) {
testCases := []struct {
desc string
Expand Down Expand Up @@ -934,7 +982,7 @@ func Test_Priority(t *testing.T) {
require.NoError(t, err)
}

handler := muxer.Match(ConnData{
handler, _ := muxer.Match(ConnData{
serverName: test.serverName,
remoteIP: test.remoteIP,
})
Expand Down
44 changes: 31 additions & 13 deletions pkg/server/router/tcp/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
return
}

handler := r.muxerTCP.Match(connData)
handler, _ := r.muxerTCP.Match(connData)
// If there is a handler matching the connection metadata,
// we let it handle the connection.
if handler != nil {
Expand Down Expand Up @@ -133,7 +133,7 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
}

if !tls {
handler := r.muxerTCP.Match(connData)
handler, _ := r.muxerTCP.Match(connData)
switch {
case handler != nil:
handler.ServeTCP(r.GetConn(conn, peeked))
Expand All @@ -145,20 +145,38 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
return
}

handler := r.muxerTCPTLS.Match(connData)
if handler != nil {
handler.ServeTCP(r.GetConn(conn, peeked))
// For real, the handler eventually used for HTTPS is (almost) always the same:
// it is the httpsForwarder that is used for all HTTPS connections that match
// (which is also incidentally the same used in the last block below for 404s).
// The added value from doing Match is to find and use the specific TLS config
// (wrapped inside the returned handler) requested for the given HostSNI.
handlerHTTPS, catchAllHTTPS := r.muxerHTTPS.Match(connData)
if handlerHTTPS != nil && !catchAllHTTPS {
// In order not to depart from the behavior in 2.6, we only allow an HTTPS router
// to take precedence over a TCP-TLS router if it is _not_ an HostSNI(*) router (so
// basically any router that has a specific HostSNI based rule).
handlerHTTPS.ServeTCP(r.GetConn(conn, peeked))
return
}

// for real, the handler returned here is (almost) always the same:
// it is the httpsForwarder that is used for all HTTPS connections that match
// (which is also incidentally the same used in the last block below for 404s).
// The added value from doing Match, is to find and use the specific TLS config
// requested for the given HostSNI.
handler = r.muxerHTTPS.Match(connData)
if handler != nil {
handler.ServeTCP(r.GetConn(conn, peeked))
// Contains also TCP TLS passthrough routes.
handlerTCPTLS, catchAllTCPTLS := r.muxerTCPTLS.Match(connData)
if handlerTCPTLS != nil && !catchAllTCPTLS {
handlerTCPTLS.ServeTCP(r.GetConn(conn, peeked))
return
}

// Fallback on HTTPS catchAll.
// We end up here for e.g. an HTTPS router that only has a PathPrefix rule,
// which under the scenes is counted as an HostSNI(*) rule.
if handlerHTTPS != nil {
handlerHTTPS.ServeTCP(r.GetConn(conn, peeked))
return
}

// Fallback on TCP TLS catchAll.
if handlerTCPTLS != nil {
handlerTCPTLS.ServeTCP(r.GetConn(conn, peeked))
return
}

Expand Down
Loading

0 comments on commit ac4086d

Please sign in to comment.