Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix the server's 0-RTT rejection logic when using GetConfigForClient #4550

Merged
merged 1 commit into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config {
if config == nil {
return nil, nil
}
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = config.DecryptTicket(nil, tls.ConnectionState{})

config = config.Clone()
config.NextProtos = []string{proto}
return config, nil
Expand Down
154 changes: 80 additions & 74 deletions integrationtests/self/zero_rtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,87 +813,93 @@ var _ = Describe("0-RTT", func() {
Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }),
)

for _, l := range []int{0, 15} {
connIDLen := l
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This didn't work.


It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() {
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
// now dial new connection with different transport parameters
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: 1,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
test0RTTRejection := func(tlsConf *tls.Config) {
clientConf := getTLSClientConfig()
dialAndReceiveSessionTicket(tlsConf, nil, clientConf)
// now dial new connection with different transport parameters
counter, tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: 1,
Tracer: newTracer(tracer),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()

conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf,
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
// The client remembers that it was allowed to open 2 uni-directional streams.
firstStr, err := conn.OpenUniStream()
conn, err := quic.DialAddrEarly(
context.Background(),
fmt.Sprintf("localhost:%d", proxy.LocalPort()),
clientConf,
getQuicConfig(&quic.Config{}),
)
Expect(err).ToNot(HaveOccurred())
// The client remembers that it was allowed to open 2 uni-directional streams.
firstStr, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
written := make(chan struct{}, 2)
go func() {
defer GinkgoRecover()
defer func() { written <- struct{}{} }()
_, err := firstStr.Write([]byte("first flight"))
Expect(err).ToNot(HaveOccurred())
written := make(chan struct{}, 2)
go func() {
defer GinkgoRecover()
defer func() { written <- struct{}{} }()
_, err := firstStr.Write([]byte("first flight"))
Expect(err).ToNot(HaveOccurred())
}()
secondStr, err := conn.OpenUniStream()
}()
secondStr, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer func() { written <- struct{}{} }()
_, err := secondStr.Write([]byte("first flight"))
Expect(err).ToNot(HaveOccurred())
go func() {
defer GinkgoRecover()
defer func() { written <- struct{}{} }()
_, err := secondStr.Write([]byte("first flight"))
Expect(err).ToNot(HaveOccurred())
}()
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err = conn.AcceptStream(ctx)
Expect(err).To(MatchError(quic.Err0RTTRejected))
Eventually(written).Should(Receive())
Eventually(written).Should(Receive())
_, err = firstStr.Write([]byte("foobar"))
Expect(err).To(MatchError(quic.Err0RTTRejected))
_, err = conn.OpenUniStream()
Expect(err).To(MatchError(quic.Err0RTTRejected))

_, err = conn.AcceptStream(ctx)
Expect(err).To(Equal(quic.Err0RTTRejected))

newConn := conn.NextConnection()
str, err := newConn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = newConn.OpenUniStream()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("too many open streams"))
_, err = str.Write([]byte("second flight"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
Expect(conn.CloseWithError(0, "")).To(Succeed())
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err = conn.AcceptStream(ctx)
Expect(err).To(MatchError(quic.Err0RTTRejected))
Eventually(written).Should(Receive())
Eventually(written).Should(Receive())
_, err = firstStr.Write([]byte("foobar"))
Expect(err).To(MatchError(quic.Err0RTTRejected))
_, err = conn.OpenUniStream()
Expect(err).To(MatchError(quic.Err0RTTRejected))

// The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
})
_, err = conn.AcceptStream(ctx)
Expect(err).To(Equal(quic.Err0RTTRejected))

newConn := conn.NextConnection()
str, err := newConn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
_, err = newConn.OpenUniStream()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("too many open streams"))
_, err = str.Write([]byte("second flight"))
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
Expect(conn.CloseWithError(0, "")).To(Succeed())

// The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := num0RTTPackets.Load()
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty())
}

It("correctly deals with 0-RTT rejections", func() {
test0RTTRejection(getTLSConfig())
})

It("correctly deals with 0-RTT rejections, when the server uses GetConfigForClient", func() {
tlsConf := getTLSConfig()
test0RTTRejection(&tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return tlsConf, nil },
})
})

It("queues 0-RTT packets, if the Initial is delayed", func() {
tlsConf := getTLSConfig()
clientConf := getTLSClientConfig()
Expand Down
42 changes: 4 additions & 38 deletions internal/handshake/crypto_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,44 +123,12 @@ func NewCryptoSetupServer(
)
cs.allow0RTT = allow0RTT

quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.getDataForSessionTicket, cs.handleSessionTicket)
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)

cs.tlsConf = quicConf.TLSConfig
cs.conn = tls.QUICServer(quicConf)

tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket)
cs.tlsConf = tlsConf
cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf})
return cs
}

// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) {
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
c = c.Clone()
// This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted.
c.MinVersion = tls.VersionTLS13
// We're returning a tls.Config here, so we need to apply this recursively.
addConnToClientHelloInfo(c, localAddr, remoteAddr)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
}

func newCryptoSetup(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
Expand Down Expand Up @@ -376,9 +344,7 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte {
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
EarlyData: h.allow0RTT,
}); err != nil {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil {
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
// We can't check h.tlsConfig here, since the actual config might have been obtained from
// the GetConfigForClient callback.
Expand Down
74 changes: 0 additions & 74 deletions internal/handshake/crypto_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"crypto/x509/pkix"
"math/big"
"net"
"reflect"
"time"

mocktls "github.com/quic-go/quic-go/internal/mocks/tls"
Expand Down Expand Up @@ -106,79 +105,6 @@ var _ = Describe("Crypto Setup TLS", func() {
Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
})

Context("filling in a net.Conn in tls.ClientHelloInfo", func() {
var (
local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
)

It("wraps GetCertificate", func() {
var localAddr, remoteAddr net.Addr
tlsConf := &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
localAddr = info.Conn.LocalAddr()
remoteAddr = info.Conn.RemoteAddr()
cert := generateCert()
return &cert, nil
},
}
addConnToClientHelloInfo(tlsConf, local, remote)
_, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(localAddr).To(Equal(local))
Expect(remoteAddr).To(Equal(remote))
})

It("wraps GetConfigForClient", func() {
var localAddr, remoteAddr net.Addr
tlsConf := &tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
localAddr = info.Conn.LocalAddr()
remoteAddr = info.Conn.RemoteAddr()
return &tls.Config{}, nil
},
}
addConnToClientHelloInfo(tlsConf, local, remote)
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(localAddr).To(Equal(local))
Expect(remoteAddr).To(Equal(remote))
Expect(conf).ToNot(BeNil())
Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
})

It("wraps GetConfigForClient, recursively", func() {
var localAddr, remoteAddr net.Addr
tlsConf := &tls.Config{}
var innerConf *tls.Config
getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
localAddr = info.Conn.LocalAddr()
remoteAddr = info.Conn.RemoteAddr()
cert := generateCert()
return &cert, nil
}
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
innerConf = tlsConf.Clone()
// set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config
innerConf.MaxVersion = tls.VersionTLS12
innerConf.GetCertificate = getCert
return innerConf, nil
}
addConnToClientHelloInfo(tlsConf, local, remote)
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(conf).ToNot(BeNil())
Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
_, err = conf.GetCertificate(&tls.ClientHelloInfo{})
Expect(err).ToNot(HaveOccurred())
Expect(localAddr).To(Equal(local))
Expect(remoteAddr).To(Equal(remote))
// make sure that the tls.Config returned by GetConfigForClient isn't modified
Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue())
Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
})
})

Context("doing the handshake", func() {
newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
rttStats := &utils.RTTStats{}
Expand Down
2 changes: 1 addition & 1 deletion internal/handshake/conn.go → internal/qtls/conn.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package handshake
package qtls

import (
"net"
Expand Down
34 changes: 30 additions & 4 deletions internal/qtls/qtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@ import (
"bytes"
"crypto/tls"
"fmt"
"net"

"github.com/quic-go/quic-go/internal/protocol"
)

func SetupConfigForServer(qconf *tls.QUICConfig, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
conf := qconf.TLSConfig

func SetupConfigForServer(
conf *tls.Config,
localAddr, remoteAddr net.Addr,
getData func() []byte,
handleSessionTicket func([]byte, bool) bool,
) *tls.Config {
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})

conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
qconf.TLSConfig = conf

// add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession
Expand Down Expand Up @@ -58,6 +61,29 @@ func SetupConfigForServer(qconf *tls.QUICConfig, getData func() []byte, handleSe

return state, nil
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
// We're returning a tls.Config here, so we need to apply this recursively.
c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
return conf
}

func SetupConfigForClient(
Expand Down
Loading
Loading