Skip to content

Commit

Permalink
fix the server's 0-RTT rejection logic when using GetConfigForClient (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jun 3, 2024
1 parent dea2eaf commit 459a6f3
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 194 deletions.
4 changes: 4 additions & 0 deletions http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
if config == nil {
return nil, nil
}
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = config.DecryptTicket(nil, tls.ConnectionState{})

config = config.Clone()
config.NextProtos = []string{proto}
return config, nil
Expand Down
154 changes: 80 additions & 74 deletions integrationtests/self/zero_rtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,87 +813,93 @@ var _ = Describe("0-RTT", func() {
Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }),
)

for _, l := range []int{0, 15} {
connIDLen := l

It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() {
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
// now dial new connection with different transport parameters
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: 1,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
test0RTTRejection := func(tlsConf *tls.Config) {
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
// now dial new connection with different transport parameters
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: 1,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()

conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
// The client remembers that it was allowed to open 2 uni-directional streams.
firstStr, err := conn.OpenUniStream()
conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf,
getQuicConfig(&quic.Config{}),
)
Expect(err).ToNot(HaveOccurred())
// The client remembers that it was allowed to open 2 uni-directional streams.
firstStr, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
written := make(chan struct{}, 2)
go func() {
defer GinkgoRecover()
defer func() { written <- struct{}{} }()
_, err := firstStr.Write([]byte("first flight"))
Expect(err).ToNot(HaveOccurred())
written := make(chan struct{}, 2)
go func() {
defer GinkgoRecover()
defer func() { written <- struct{}{} }()
_, err := firstStr.Write([]byte("first flight"))
Expect(err).ToNot(HaveOccurred())
}()
secondStr, err := conn.OpenUniStream()
}()
secondStr, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer func() { written <- struct{}{} }()
_, err := secondStr.Write([]byte("first flight"))
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer func() { written <- struct{}{} }()
_, err := secondStr.Write([]byte("first flight"))
Expect(err).ToNot(HaveOccurred())
}()
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err = conn.AcceptStream(ctx)
Expect(err).To(MatchError(quic.Err0RTTRejected))
Eventually(written).Should(Receive())
Eventually(written).Should(Receive())
_, err = firstStr.Write([]byte("foobar"))
Expect(err).To(MatchError(quic.Err0RTTRejected))
_, err = conn.OpenUniStream()
Expect(err).To(MatchError(quic.Err0RTTRejected))

_, err = conn.AcceptStream(ctx)
Expect(err).To(Equal(quic.Err0RTTRejected))

newConn := conn.NextConnection()
str, err := newConn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = newConn.OpenUniStream()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("too many open streams"))
_, err = str.Write([]byte("second flight"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
Expect(conn.CloseWithError(0, "")).To(Succeed())
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err = conn.AcceptStream(ctx)
Expect(err).To(MatchError(quic.Err0RTTRejected))
Eventually(written).Should(Receive())
Eventually(written).Should(Receive())
_, err = firstStr.Write([]byte("foobar"))
Expect(err).To(MatchError(quic.Err0RTTRejected))
_, err = conn.OpenUniStream()
Expect(err).To(MatchError(quic.Err0RTTRejected))

// The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
_, err = conn.AcceptStream(ctx)
Expect(err).To(Equal(quic.Err0RTTRejected))

newConn := conn.NextConnection()
str, err := newConn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = newConn.OpenUniStream()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("too many open streams"))
_, err = str.Write([]byte("second flight"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
Expect(conn.CloseWithError(0, "")).To(Succeed())

// The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
}

It("correctly deals with 0-RTT rejections", func() {
test0RTTRejection(getTLSConfig())
})

It("correctly deals with 0-RTT rejections, when the server uses GetConfigForClient", func() {
tlsConf := getTLSConfig()
test0RTTRejection(&tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return tlsConf, nil },
})
})

It("queues 0-RTT packets, if the Initial is delayed", func() {
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
Expand Down
42 changes: 4 additions & 38 deletions internal/handshake/crypto_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,44 +123,12 @@ func NewCryptoSetupServer(
)
cs.allow0RTT = allow0RTT

quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.getDataForSessionTicket, cs.handleSessionTicket)
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)

cs.tlsConf = quicConf.TLSConfig
cs.conn = tls.QUICServer(quicConf)

tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket)
cs.tlsConf = tlsConf
cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf})
return cs
}

// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) {
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
c = c.Clone()
// This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted.
c.MinVersion = tls.VersionTLS13
// We're returning a tls.Config here, so we need to apply this recursively.
addConnToClientHelloInfo(c, localAddr, remoteAddr)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
}

func newCryptoSetup(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
Expand Down Expand Up @@ -376,9 +344,7 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte {
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
EarlyData: h.allow0RTT,
}); err != nil {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil {
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
// We can't check h.tlsConfig here, since the actual config might have been obtained from
// the GetConfigForClient callback.
Expand Down
74 changes: 0 additions & 74 deletions internal/handshake/crypto_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"crypto/x509/pkix"
"math/big"
"net"
"reflect"
"time"

mocktls "github.com/quic-go/quic-go/internal/mocks/tls"
Expand Down Expand Up @@ -106,79 +105,6 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
})

Context("filling in a net.Conn in tls.ClientHelloInfo", func() {
var (
local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
)

It("wraps GetCertificate", func() {
var localAddr, remoteAddr net.Addr
tlsConf := &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
localAddr = info.Conn.LocalAddr()
remoteAddr = info.Conn.RemoteAddr()
cert := generateCert()
return &cert, nil
},
}
addConnToClientHelloInfo(tlsConf, local, remote)
_, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(localAddr).To(Equal(local))
Expect(remoteAddr).To(Equal(remote))
})

It("wraps GetConfigForClient", func() {
var localAddr, remoteAddr net.Addr
tlsConf := &tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
localAddr = info.Conn.LocalAddr()
remoteAddr = info.Conn.RemoteAddr()
return &tls.Config{}, nil
},
}
addConnToClientHelloInfo(tlsConf, local, remote)
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(localAddr).To(Equal(local))
Expect(remoteAddr).To(Equal(remote))
Expect(conf).ToNot(BeNil())
Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
})

It("wraps GetConfigForClient, recursively", func() {
var localAddr, remoteAddr net.Addr
tlsConf := &tls.Config{}
var innerConf *tls.Config
getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
localAddr = info.Conn.LocalAddr()
remoteAddr = info.Conn.RemoteAddr()
cert := generateCert()
return &cert, nil
}
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
innerConf = tlsConf.Clone()
// set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config
innerConf.MaxVersion = tls.VersionTLS12
innerConf.GetCertificate = getCert
return innerConf, nil
}
addConnToClientHelloInfo(tlsConf, local, remote)
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf).ToNot(BeNil())
Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
_, err = conf.GetCertificate(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(localAddr).To(Equal(local))
Expect(remoteAddr).To(Equal(remote))
// make sure that the tls.Config returned by GetConfigForClient isn't modified
Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue())
Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
})
})

Context("doing the handshake", func() {
newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
rttStats := &utils.RTTStats{}
Expand Down
2 changes: 1 addition & 1 deletion internal/handshake/conn.go → internal/qtls/conn.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package handshake
package qtls

import (
"net"
Expand Down
34 changes: 30 additions & 4 deletions internal/qtls/qtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@ import (
"bytes"
"crypto/tls"
"fmt"
"net"

"github.com/quic-go/quic-go/internal/protocol"
)

func SetupConfigForServer(qconf *tls.QUICConfig, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
conf := qconf.TLSConfig

func SetupConfigForServer(
conf *tls.Config,
localAddr, remoteAddr net.Addr,
getData func() []byte,
handleSessionTicket func([]byte, bool) bool,
) *tls.Config {
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})

conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
qconf.TLSConfig = conf

// add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession
Expand Down Expand Up @@ -58,6 +61,29 @@ func SetupConfigForServer(qconf *tls.QUICConfig, getData func() []byte, handleSe

return state, nil
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
// We're returning a tls.Config here, so we need to apply this recursively.
c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
return conf
}

func SetupConfigForClient(
Expand Down
Loading

0 comments on commit 459a6f3

Please sign in to comment.