Skip to content

Commit

Permalink
initialize the MTU discoverer when processing the transport parameters
Browse files Browse the repository at this point in the history
On the client side, we always use the configured packet size. This comes
with the risk of failing the handshake if the path doesn't support this
MTU. If the server sends a max_udp_payload_size that's smaller than this
size, we can safely ignore this: Obviously, the server still processed
the (fully padded) Initial packet, despite claiming that it wouldn't do
so.

On the server side, there's no downside to using 1200 bytes until we
received the client's transport parameters:
* If the first packet didn't contain the entire ClientHello, all we can
do is ACK that packet. We don't need a lot of bytes for that.
* If it did, we will have processed the transport parameters and
initialized the MTU discoverer.
  • Loading branch information
marten-seemann committed May 14, 2024
1 parent 508b402 commit b4a6d66
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 47 deletions.
58 changes: 40 additions & 18 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ type connection struct {
unpacker unpacker
frameParser wire.FrameParser
packer packer
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received

maxPayloadSizeEstimate atomic.Uint32

Expand Down Expand Up @@ -286,7 +286,6 @@ var newConnection = func(
s.tracer,
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, protocol.ByteCount(s.config.InitialPacketSize), s.onMTUIncreased)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
Expand Down Expand Up @@ -397,7 +396,6 @@ var newClientConnection = func(
s.tracer,
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, protocol.ByteCount(s.config.InitialPacketSize), s.onMTUIncreased)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
Expand Down Expand Up @@ -787,11 +785,7 @@ func (s *connection) handleHandshakeConfirmed() error {
s.cryptoStreamHandler.SetHandshakeConfirmed()

if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF {
maxPacketSize := s.peerParams.MaxUDPPayloadSize
if maxPacketSize == 0 {
maxPacketSize = protocol.MaxByteCount
}
s.mtuDiscoverer.Start(min(maxPacketSize, protocol.MaxPacketBufferSize))
s.mtuDiscoverer.Start()
}
return nil
}
Expand Down Expand Up @@ -1780,6 +1774,16 @@ func (s *connection) applyTransportParameters() {
// Retire the connection ID.
s.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken)
}
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
if params.MaxUDPPayloadSize > 0 && params.MaxUDPPayloadSize < maxPacketSize {
maxPacketSize = params.MaxUDPPayloadSize
}
s.mtuDiscoverer = newMTUDiscoverer(
s.rttStats,
protocol.ByteCount(s.config.InitialPacketSize),
maxPacketSize,
s.onMTUIncreased,
)
}

func (s *connection) triggerSending(now time.Time) error {
Expand Down Expand Up @@ -1868,7 +1872,7 @@ func (s *connection) sendPackets(now time.Time) error {
}

if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err := s.packer.PackCoalescedPacket(false, s.maxPacketSize(), s.version)
if err != nil || packet == nil {
return err
}
Expand All @@ -1895,7 +1899,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
for {
buf := getPacketBuffer()
ecn := s.sentPacketHandler.ECNMode(true)
if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil {
if _, err := s.appendOneShortHeaderPacket(buf, s.maxPacketSize(), ecn, now); err != nil {
if err == errNothingToPack {
buf.Release()
return nil
Expand Down Expand Up @@ -1926,7 +1930,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {

func (s *connection) sendPacketsWithGSO(now time.Time) error {
buf := getLargePacketBuffer()
maxSize := s.mtuDiscoverer.CurrentSize()
maxSize := s.maxPacketSize()

ecn := s.sentPacketHandler.ECNMode(true)
for {
Expand Down Expand Up @@ -1995,7 +1999,7 @@ func (s *connection) resetPacingDeadline() {
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if !s.handshakeConfirmed {
ecn := s.sentPacketHandler.ECNMode(false)
packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err := s.packer.PackCoalescedPacket(true, s.maxPacketSize(), s.version)
if err != nil {
return err
}
Expand All @@ -2006,7 +2010,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
}

ecn := s.sentPacketHandler.ECNMode(true)
p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
p, buf, err := s.packer.PackAckOnlyPacket(s.maxPacketSize(), s.version)
if err != nil {
if err == errNothingToPack {
return nil
Expand All @@ -2028,7 +2032,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
break
}
var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version)
if err != nil {
return err
}
Expand All @@ -2039,7 +2043,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
if packet == nil {
s.retransmissionQueue.AddPing(encLevel)
var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version)
if err != nil {
return err
}
Expand Down Expand Up @@ -2118,14 +2122,14 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
var transportErr *qerr.TransportError
var applicationErr *qerr.ApplicationError
if errors.As(e, &transportErr) {
packet, err = s.packer.PackConnectionClose(transportErr, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.PackConnectionClose(transportErr, s.maxPacketSize(), s.version)
} else if errors.As(e, &applicationErr) {
packet, err = s.packer.PackApplicationClose(applicationErr, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.PackApplicationClose(applicationErr, s.maxPacketSize(), s.version)
} else {
packet, err = s.packer.PackConnectionClose(&qerr.TransportError{
ErrorCode: qerr.InternalError,
ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()),
}, s.mtuDiscoverer.CurrentSize(), s.version)
}, s.maxPacketSize(), s.version)
}
if err != nil {
return nil, err
Expand All @@ -2135,6 +2139,24 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn)
}

func (s *connection) maxPacketSize() protocol.ByteCount {
if s.mtuDiscoverer == nil {
// Use the configured packet size on the client side.
// If the server sends a max_udp_payload_size that's smaller than this size, we can ignore this:
// Apparently the server still processed the (fully padded) Initial packet anyway.
if s.perspective == protocol.PerspectiveClient {
return protocol.ByteCount(s.config.InitialPacketSize)
}
// On the server side, there's no downside to using 1200 bytes until we received the client's transport
// parameters:
// * If the first packet didn't contain the entire ClientHello, all we can do is ACK that packet. We don't
// need a lot of bytes for that.
// * If it did, we will have processed the transport parameters and initialized the MTU discoverer.
return protocol.MinInitialPacketSize
}
return s.mtuDiscoverer.CurrentSize()
}

func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
// quic-go logging
if s.logger.Debug() {
Expand Down
27 changes: 13 additions & 14 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1439,15 +1439,15 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
sph.EXPECT().ECNMode(true).Return(protocol.ECT1).Times(4)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload1 := make([]byte, conn.maxPacketSize())
rand.Read(payload1)
payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload2 := make([]byte, conn.maxPacketSize())
rand.Read(payload2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.maxPacketSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...)))
})
go func() {
Expand All @@ -1466,20 +1466,20 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
sph.EXPECT().ECNMode(true).Times(4)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload1 := make([]byte, conn.maxPacketSize())
rand.Read(payload1)
payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize()-1)
payload2 := make([]byte, conn.maxPacketSize()-1)
rand.Read(payload2)
payload3 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload3 := make([]byte, conn.maxPacketSize())
rand.Read(payload3)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 12}, payload3)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.maxPacketSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...)))
})
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.maxPacketSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(payload3))
})
go func() {
Expand All @@ -1499,20 +1499,20 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
sph.EXPECT().ECNMode(true).Return(protocol.ECT1).Times(2)
sph.EXPECT().ECNMode(true).Return(protocol.ECT0).Times(2)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload1 := make([]byte, conn.maxPacketSize())
rand.Read(payload1)
payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload2 := make([]byte, conn.maxPacketSize())
rand.Read(payload2)
payload3 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload3 := make([]byte, conn.maxPacketSize())
rand.Read(payload3)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload3)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.maxPacketSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...)))
})
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.maxPacketSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(payload3))
})
go func() {
Expand Down Expand Up @@ -2504,7 +2504,6 @@ var _ = Describe("Connection", func() {
Expect(err).To(BeAssignableToTypeOf(&DatagramTooLargeError{}))
derr := err.(*DatagramTooLargeError)
Expect(derr.MaxDatagramPayloadSize).To(BeNumerically("<", 1000))
fmt.Println(derr.MaxDatagramPayloadSize)
Expect(conn.SendDatagram(make([]byte, derr.MaxDatagramPayloadSize))).To(Succeed())
})

Expand Down
12 changes: 6 additions & 6 deletions mock_mtu_discoverer_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions mtu_discoverer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
type mtuDiscoverer interface {
// Start starts the MTU discovery process.
// It's unnecessary to call ShouldSendProbe before that.
Start(maxPacketSize protocol.ByteCount)
Start()
ShouldSendProbe(now time.Time) bool
CurrentSize() protocol.ByteCount
GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount)
Expand All @@ -38,10 +38,11 @@ type mtuFinder struct {

var _ mtuDiscoverer = &mtuFinder{}

func newMTUDiscoverer(rttStats *utils.RTTStats, start protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder {
func newMTUDiscoverer(rttStats *utils.RTTStats, start, max protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder {
return &mtuFinder{
inFlight: protocol.InvalidByteCount,
current: start,
max: max,
rttStats: rttStats,
mtuIncreased: mtuIncreased,
}
Expand All @@ -51,9 +52,15 @@ func (f *mtuFinder) done() bool {
return f.max-f.current <= maxMTUDiff+1
}

func (f *mtuFinder) Start(maxPacketSize protocol.ByteCount) {
func (f *mtuFinder) SetMax(max protocol.ByteCount) {
f.max = max
}

func (f *mtuFinder) Start() {
if f.max == protocol.InvalidByteCount {
panic("invalid")
}
f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately
f.max = maxPacketSize
}

func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
Expand Down
10 changes: 5 additions & 5 deletions mtu_discoverer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ var _ = Describe("MTU Discoverer", func() {
rttStats = &utils.RTTStats{}
rttStats.SetInitialRTT(rtt)
Expect(rttStats.SmoothedRTT()).To(Equal(rtt))
d = newMTUDiscoverer(rttStats, startMTU, func(s protocol.ByteCount) { discoveredMTU = s })
d.Start(maxMTU)
d = newMTUDiscoverer(rttStats, startMTU, maxMTU, func(s protocol.ByteCount) { discoveredMTU = s })
d.Start()
now = time.Now()
})

Expand Down Expand Up @@ -78,7 +78,7 @@ var _ = Describe("MTU Discoverer", func() {
})

It("doesn't do discovery before being started", func() {
d := newMTUDiscoverer(rttStats, startMTU, func(s protocol.ByteCount) {})
d := newMTUDiscoverer(rttStats, startMTU, protocol.MaxByteCount, func(s protocol.ByteCount) {})
for i := 0; i < 5; i++ {
Expect(d.ShouldSendProbe(time.Now())).To(BeFalse())
}
Expand All @@ -90,8 +90,8 @@ var _ = Describe("MTU Discoverer", func() {
for i := 0; i < rep; i++ {
maxMTU := protocol.ByteCount(rand.Intn(int(3000-startMTU))) + startMTU + 1
currentMTU := startMTU
d := newMTUDiscoverer(rttStats, startMTU, func(s protocol.ByteCount) { currentMTU = s })
d.Start(maxMTU)
d := newMTUDiscoverer(rttStats, startMTU, maxMTU, func(s protocol.ByteCount) { currentMTU = s })
d.Start()
now := time.Now()
realMTU := protocol.ByteCount(rand.Intn(int(maxMTU-startMTU))) + startMTU
t := now.Add(mtuProbeDelay * rtt)
Expand Down

0 comments on commit b4a6d66

Please sign in to comment.