Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reject sending of DATAGRAM frames that exceed the current MTU #4497

Merged
merged 2 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ type connection struct {
packer packer
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes

maxPayloadSizeEstimate atomic.Uint32

initialStream cryptoStream
handshakeStream cryptoStream
oneRTTStream cryptoStream // only set for the server
Expand Down Expand Up @@ -274,17 +276,19 @@ var newConnection = func(
)
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.preSetup()
initialPacketSize := getMaxPacketSize(s.conn.RemoteAddr())
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
getMaxPacketSize(s.conn.RemoteAddr()),
initialPacketSize,
s.rttStats,
clientAddressValidated,
s.conn.capabilities().ECN,
s.perspective,
s.tracer,
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, initialPacketSize, s.onMTUIncreased)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(initialPacketSize)))
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
Expand Down Expand Up @@ -383,17 +387,19 @@ var newClientConnection = func(
)
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.preSetup()
initialPacketSize := getMaxPacketSize(s.conn.RemoteAddr())
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()),
initialPacketSize,
s.rttStats,
false, // has no effect
s.conn.capabilities().ECN,
s.perspective,
s.tracer,
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, initialPacketSize, s.onMTUIncreased)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(initialPacketSize)))
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
Expand Down Expand Up @@ -2352,13 +2358,23 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) {
}
}

func (s *connection) onMTUIncreased(mtu protocol.ByteCount) {
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(mtu)))
s.sentPacketHandler.SetMaxDatagramSize(mtu)
}

func (s *connection) SendDatagram(p []byte) error {
if !s.supportsDatagrams() {
return errors.New("datagram support disabled")
}

f := &wire.DatagramFrame{DataLenPresent: true}
maxDataLen := f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version)
// The payload size estimate is conservative.
// Under many circumstances we could send a few more bytes.
maxDataLen := min(
f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version),
protocol.ByteCount(s.maxPayloadSizeEstimate.Load()),
)
if protocol.ByteCount(len(p)) > maxDataLen {
return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)}
}
Expand Down Expand Up @@ -2391,3 +2407,10 @@ func (s *connection) NextConnection() Connection {
s.streamsMap.UseResetMaps()
return s
}

// estimateMaxPayloadSize estimates the maximum payload size for short header packets.
// It is not very sophisticated: it just subtracts the size of header (assuming the maximum
// connection ID length), and the size of the encryption tag.
func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount {
return mtu - 1 /* type byte */ - 20 /* maximum connection ID length */ - 16 /* tag size */
}
22 changes: 19 additions & 3 deletions integrationtests/self/dplpmtud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ import (

var _ = Describe("DPLPMTUD", func() {
It("discovers the MTU", func() {
const rtt = 100 * time.Millisecond
rtt := scaleDuration(10 * time.Millisecond)
const mtu = 1400

ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}))
ln, err := quic.ListenAddr(
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, EnableDatagrams: true}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
go func() {
Expand Down Expand Up @@ -73,7 +77,7 @@ var _ = Describe("DPLPMTUD", func() {
context.Background(),
proxy.LocalAddr(),
getTLSClientConfig(),
getQuicConfig(nil),
getQuicConfig(&quic.Config{EnableDatagrams: true}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
Expand All @@ -87,15 +91,27 @@ var _ = Describe("DPLPMTUD", func() {
Expect(err).ToNot(HaveOccurred())
Expect(data).To(Equal(PRDataLong))
}()
err = conn.SendDatagram(make([]byte, 2000))
Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{}))
initialMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize
_, err = str.Write(PRDataLong)
Expect(err).ToNot(HaveOccurred())
str.Close()
Eventually(done, 20*time.Second).Should(BeClosed())
err = conn.SendDatagram(make([]byte, 2000))
Expect(err).To(BeAssignableToTypeOf(&quic.DatagramTooLargeError{}))
finalMaxDatagramSize := err.(*quic.DatagramTooLargeError).MaxDatagramPayloadSize

mx.Lock()
defer mx.Unlock()
fmt.Fprintf(GinkgoWriter, "max client packet size: %d, MTU: %d\n", maxPacketSizeClient, mtu)
fmt.Fprintf(GinkgoWriter, "max datagram size: initial: %d, final: %d\n", initialMaxDatagramSize, finalMaxDatagramSize)
fmt.Fprintf(GinkgoWriter, "max server packet size: %d, MTU: %d\n", maxPacketSizeServer, mtu)
Expect(maxPacketSizeClient).To(BeNumerically(">=", mtu-25))
const maxDiff = 40 // this includes the 21 bytes for the short header, 16 bytes for the encryption tag, and framing overhead
Expect(initialMaxDatagramSize).To(BeNumerically(">=", 1252-maxDiff))
Expect(finalMaxDatagramSize).To(BeNumerically(">=", maxPacketSizeClient-maxDiff))
// MTU discovery was disabled on the server side
Expect(maxPacketSizeServer).To(Equal(1252))
})
})
Loading