Skip to content

Commit

Permalink
pass the remote address to Transport.VerifySourceAddress
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Mar 13, 2024
1 parent f18f897 commit 49e2cdc
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 20 deletions.
2 changes: 1 addition & 1 deletion integrationtests/self/handshake_drop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ var _ = Describe("Handshake drop tests", func() {
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{Conn: conn}
if doRetry {
tr.VerifySourceAddress = func() bool { return true }
tr.VerifySourceAddress = func(net.Addr) bool { return true }
}
ln, err = tr.Listen(tlsConf, conf)
Expect(err).ToNot(HaveOccurred())
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/handshake_rtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ var _ = Describe("Handshake RTT tests", func() {
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
VerifySourceAddress: func() bool { return true },
VerifySourceAddress: func(net.Addr) bool { return true },
}
addTracer(tr)
defer tr.Close()
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ var _ = Describe("Handshake tests", func() {
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
VerifySourceAddress: func() bool { return true },
VerifySourceAddress: func(net.Addr) bool { return true },
}
addTracer(tr)
defer tr.Close()
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/mitm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var _ = Describe("MITM test", func() {
}
addTracer(serverTransport)
if forceAddressValidation {
serverTransport.VerifySourceAddress = func() bool { return true }
serverTransport.VerifySourceAddress = func(net.Addr) bool { return true }
}
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/zero_rtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ var _ = Describe("0-RTT", func() {
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
VerifySourceAddress: func() bool { return true },
VerifySourceAddress: func(net.Addr) bool { return true },
}
addTracer(tr)
defer tr.Close()
Expand Down
2 changes: 1 addition & 1 deletion interop/http09/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (s *Server) ListenAndServe() error {
tlsConf.NextProtos = []string{h09alpn}
tr := quic.Transport{Conn: conn}
if s.ForceRetry {
tr.VerifySourceAddress = func() bool { return true }
tr.VerifySourceAddress = func(net.Addr) bool { return true }
}
ln, err := tr.ListenEarly(tlsConf, s.QuicConfig)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ type baseServer struct {
connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket

verifySourceAddress func() bool
verifySourceAddress func(net.Addr) bool

connQueue chan quicConn

Expand Down Expand Up @@ -237,7 +237,7 @@ func newServer(
onClose func(),
tokenGeneratorKey TokenGeneratorKey,
maxTokenAge time.Duration,
verifySourceAddress func() bool,
verifySourceAddress func(net.Addr) bool,
disableVersionNegotiation bool,
acceptEarly bool,
) *baseServer {
Expand Down Expand Up @@ -598,7 +598,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
}

if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress() {
if token == nil && s.verifySourceAddress != nil && s.verifySourceAddress(p.remoteAddr) {
// Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
Expand Down
25 changes: 15 additions & 10 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ var _ = Describe("Server", func() {
})

It("creates a connection when the token is accepted", func() {
serv.verifySourceAddress = func() bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
retryToken, err := serv.tokenGenerator.NewRetryToken(
raddr,
Expand Down Expand Up @@ -435,15 +435,20 @@ var _ = Describe("Server", func() {

It("replies with a Retry packet, if a token is required", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
serv.verifySourceAddress = func() bool { return true }
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
var called bool
serv.verifySourceAddress = func(addr net.Addr) bool {
Expect(addr).To(Equal(raddr))
called = true
return true
}
hdr := &wire.Header{
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: connID,
Version: protocol.Version1,
}
packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
packet.remoteAddr = raddr
tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) {
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expand All @@ -465,6 +470,7 @@ var _ = Describe("Server", func() {
phm.EXPECT().Get(connID)
serv.handlePacket(packet)
Eventually(done).Should(BeClosed())
Expect(called).To(BeTrue())
})

It("creates a connection, if no token is required", func() {
Expand Down Expand Up @@ -542,7 +548,7 @@ var _ = Describe("Server", func() {
})

It("drops packets if the receive queue is full", func() {
serv.verifySourceAddress = func() bool { return false }
serv.verifySourceAddress = func(net.Addr) bool { return false }

phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
Expand Down Expand Up @@ -644,7 +650,7 @@ var _ = Describe("Server", func() {
It("limits the number of unvalidated handshakes", func() {
const limit = 3
limiter := rate.NewLimiter(0, limit)
serv.verifySourceAddress = func() bool { return !limiter.Allow() }
serv.verifySourceAddress = func(net.Addr) bool { return !limiter.Allow() }

phm.EXPECT().Get(gomock.Any()).AnyTimes()
phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes()
Expand Down Expand Up @@ -768,7 +774,7 @@ var _ = Describe("Server", func() {
})

It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() {
serv.verifySourceAddress = func() bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
Expand Down Expand Up @@ -804,7 +810,7 @@ var _ = Describe("Server", func() {
})

It("sends an INVALID_TOKEN error, if an expired retry token is received", func() {
serv.verifySourceAddress = func() bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
serv.config.HandshakeIdleTimeout = time.Millisecond / 2 // the maximum retry token age is equivalent to the handshake timeout
Expect(serv.config.maxRetryTokenAge()).To(Equal(time.Millisecond))
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
Expand Down Expand Up @@ -842,7 +848,7 @@ var _ = Describe("Server", func() {
})

It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() {
serv.verifySourceAddress = func() bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
Expand Down Expand Up @@ -871,7 +877,7 @@ var _ = Describe("Server", func() {
})

It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() {
serv.verifySourceAddress = func() bool { return true }
serv.verifySourceAddress = func(net.Addr) bool { return true }
serv.maxTokenAge = time.Millisecond
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
token, err := serv.tokenGenerator.NewToken(raddr)
Expand Down Expand Up @@ -900,7 +906,6 @@ var _ = Describe("Server", func() {
})

It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() {
serv.verifySourceAddress = func() bool { return true }
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
Expand Down
4 changes: 3 additions & 1 deletion transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ type Transport struct {
// VerifySourceAddress decides if a connection attempt originating from unvalidated source
// addresses first needs to go through source address validation using QUIC's Retry mechanism,
// as described in RFC 9000 section 8.1.2.
// Note that the address passed to this callback is unvalidated, and might be spoofed in case
// of an attack.
// Validating the source address adds one additional network roundtrip to the handshake,
// and should therefore only be used if a suspiciously high number of incoming connection is recorded.
// For most use cases, wrapping the Allow function of a rate.Limiter will be a reasonable
// implementation of this callback (negating its return value).
VerifySourceAddress func() bool
VerifySourceAddress func(net.Addr) bool

// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.
Expand Down

0 comments on commit 49e2cdc

Please sign in to comment.