Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use Transport.VerifySourceAddress to control the Retry Mechanism #4362

Merged
merged 2 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
pass the remote address to Transport.VerifySourceAddress
  • Loading branch information
marten-seemann committed Mar 13, 2024
commit b3037da5575bb9c2789be302a75a2bd467103085
2 changes: 1 addition & 1 deletion integrationtests/self/handshake_drop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ var _ = Describe("Handshake drop tests", func() {
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{Conn: conn}
if doRetry {
tr.VerifySourceAddress = func() bool { return true }
tr.VerifySourceAddress = func(net.Addr) bool { return true }
}
ln, err = tr.Listen(tlsConf, conf)
Expect(err).ToNot(HaveOccurred())
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/handshake_rtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ var _ = Describe("Handshake RTT tests", func() {
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
VerifySourceAddress: func() bool { return true },
VerifySourceAddress: func(net.Addr) bool { return true },
}
addTracer(tr)
defer tr.Close()
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ var _ = Describe("Handshake tests", func() {
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
VerifySourceAddress: func() bool { return true },
VerifySourceAddress: func(net.Addr) bool { return true },
}
addTracer(tr)
defer tr.Close()
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/mitm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var _ = Describe("MITM test", func() {
}
addTracer(serverTransport)
if forceAddressValidation {
serverTransport.VerifySourceAddress = func() bool { return true }
serverTransport.VerifySourceAddress = func(net.Addr) bool { return true }
}
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/zero_rtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ var _ = Describe("0-RTT", func() {
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
VerifySourceAddress: func() bool { return true },
VerifySourceAddress: func(net.Addr) bool { return true },
}
addTracer(tr)
defer tr.Close()
Expand Down
2 changes: 1 addition & 1 deletion interop/http09/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (s *Server) ListenAndServe() error {
tlsConf.NextProtos = []string{h09alpn}
tr := quic.Transport{Conn: conn}
if s.ForceRetry {
tr.VerifySourceAddress = func() bool { return true }
tr.VerifySourceAddress = func(net.Addr) bool { return true }
}
ln, err := tr.ListenEarly(tlsConf, s.QuicConfig)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ type baseServer struct {
connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket

verifySourceAddress func() bool
verifySourceAddress func(net.Addr) bool

connQueue chan quicConn

Expand Down Expand Up @@ -237,7 +237,7 @@ func newServer(
onClose func(),
tokenGeneratorKey TokenGeneratorKey,
maxTokenAge time.Duration,
verifySourceAddress func() bool,
verifySourceAddress func(net.Addr) bool,
disableVersionNegotiation bool,
acceptEarly bool,
) *baseServer {
Expand Down Expand Up @@ -598,7 +598,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
}

if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress() {
if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) {
// Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
Expand All @@ -614,7 +614,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{
RemoteAddr: p.remoteAddr,
AddrVerified: clientAddrValidated,
AddrVerified: clientAddrVerified,
})
if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
Expand Down
25 changes: 15 additions & 10 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ var _ = Describe("Server", func() {
})

It("creates a connection when the token is accepted", func() {
serv.verifySourceAddress = func() bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
retryToken, err := serv.tokenGenerator.NewRetryToken(
raddr,
Expand Down Expand Up @@ -435,15 +435,20 @@ var _ = Describe("Server", func() {

It("replies with a Retry packet, if a token is required", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
serv.verifySourceAddress = func() bool { return true }
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
var called bool
serv.verifySourceAddress = func(addr net.Addr) bool {
Expect(addr).To(Equal(raddr))
called = true
return true
}
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: connID,
Version: protocol.Version1,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expand All @@ -465,6 +470,7 @@ var _ = Describe("Server", func() {
phm.EXPECT().Get(connID)
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
Expect(called).To(BeTrue())
})

It("creates a connection, if no token is required", func() {
Expand Down Expand Up @@ -542,7 +548,7 @@ var _ = Describe("Server", func() {
})

It("drops packets if the receive queue is full", func() {
serv.verifySourceAddress = func() bool { return false }
serv.verifySourceAddress = func(net.Addr) bool { return false }

phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
Expand Down Expand Up @@ -644,7 +650,7 @@ var _ = Describe("Server", func() {
It("limits the number of unvalidated handshakes", func() {
const limit = 3
limiter := rate.NewLimiter(0, limit)
serv.verifySourceAddress = func() bool { return !limiter.Allow() }
serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() }

phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
Expand Down Expand Up @@ -768,7 +774,7 @@ var _ = Describe("Server", func() {
})

It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
serv.verifySourceAddress = func() bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
Expand Down Expand Up @@ -804,7 +810,7 @@ var _ = Describe("Server", func() {
})

It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
serv.verifySourceAddress = func() bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
Expand Down Expand Up @@ -842,7 +848,7 @@ var _ = Describe("Server", func() {
})

It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
serv.verifySourceAddress = func() bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
Expand Down Expand Up @@ -871,7 +877,7 @@ var _ = Describe("Server", func() {
})

It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
serv.verifySourceAddress = func() bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
serv.maxTokenAge = time.Millisecond
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
token, err := serv.tokenGenerator.NewToken(raddr)
Expand Down Expand Up @@ -900,7 +906,6 @@ var _ = Describe("Server", func() {
})

It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
serv.verifySourceAddress = func() bool { return true }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
Expand Down
4 changes: 3 additions & 1 deletion transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ type Transport struct {
// VerifySourceAddress decides if a connection attempt originating from unvalidated source
// addresses first needs to go through source address validation using QUIC's Retry mechanism,
// as described in RFC 9000 section 8.1.2.
// Note that the address passed to this callback is unvalidated, and might be spoofed in case
// of an attack.
// Validating the source address adds one additional network roundtrip to the handshake,
// and should therefore only be used if a suspiciously high number of incoming connection is recorded.
// For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable
// implementation of this callback (negating its return value).
VerifySourceAddress func() bool
VerifySourceAddress func(net.Addr) bool

// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.
Expand Down
Loading