Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initialize the MTU discoverer when processing the transport parameters #4514

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
initialize the MTU discoverer when processing the transport parameters
On the client side, we always use the configured packet size. This comes
with the risk of failing the handshake if the path doesn't support this
MTU. If the server sends a max_udp_payload_size that's smaller than this
size, we can safely ignore this: Obviously, the server still processed
the (fully padded) Initial packet, despite claiming that it wouldn't do
so.

On the server side, there's no downside to using 1200 bytes until we
received the client's transport parameters:
* If the first packet didn't contain the entire ClientHello, all we can
do is ACK that packet. We don't need a lot of bytes for that.
* If it did, we will have processed the transport parameters and
initialized the MTU discoverer.
  • Loading branch information
marten-seemann committed May 14, 2024
commit b4a6d66bee50d0af68c93ccd43e0d6a4a626b1b0
58 changes: 40 additions & 18 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ type connection struct {
unpacker unpacker
frameParser wire.FrameParser
packer packer
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the original comment here was incorrect: We used to initialize the mtuDiscoverer with the connection, not on handshake completion.


maxPayloadSizeEstimate atomic.Uint32

Expand Down Expand Up @@ -286,7 +286,6 @@ var newConnection = func(
s.tracer,
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, protocol.ByteCount(s.config.InitialPacketSize), s.onMTUIncreased)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
Expand Down Expand Up @@ -397,7 +396,6 @@ var newClientConnection = func(
s.tracer,
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, protocol.ByteCount(s.config.InitialPacketSize), s.onMTUIncreased)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
Expand Down Expand Up @@ -787,11 +785,7 @@ func (s *connection) handleHandshakeConfirmed() error {
s.cryptoStreamHandler.SetHandshakeConfirmed()

if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF {
maxPacketSize := s.peerParams.MaxUDPPayloadSize
if maxPacketSize == 0 {
maxPacketSize = protocol.MaxByteCount
}
s.mtuDiscoverer.Start(min(maxPacketSize, protocol.MaxPacketBufferSize))
s.mtuDiscoverer.Start()
}
return nil
}
Expand Down Expand Up @@ -1780,6 +1774,16 @@ func (s *connection) applyTransportParameters() {
// Retire the connection ID.
s.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken)
}
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
if params.MaxUDPPayloadSize > 0 && params.MaxUDPPayloadSize < maxPacketSize {
maxPacketSize = params.MaxUDPPayloadSize
}
s.mtuDiscoverer = newMTUDiscoverer(
s.rttStats,
protocol.ByteCount(s.config.InitialPacketSize),
maxPacketSize,
s.onMTUIncreased,
)
}

func (s *connection) triggerSending(now time.Time) error {
Expand Down Expand Up @@ -1868,7 +1872,7 @@ func (s *connection) sendPackets(now time.Time) error {
}

if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err := s.packer.PackCoalescedPacket(false, s.maxPacketSize(), s.version)
if err != nil || packet == nil {
return err
}
Expand All @@ -1895,7 +1899,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
for {
buf := getPacketBuffer()
ecn := s.sentPacketHandler.ECNMode(true)
if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil {
if _, err := s.appendOneShortHeaderPacket(buf, s.maxPacketSize(), ecn, now); err != nil {
if err == errNothingToPack {
buf.Release()
return nil
Expand Down Expand Up @@ -1926,7 +1930,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {

func (s *connection) sendPacketsWithGSO(now time.Time) error {
buf := getLargePacketBuffer()
maxSize := s.mtuDiscoverer.CurrentSize()
maxSize := s.maxPacketSize()

ecn := s.sentPacketHandler.ECNMode(true)
for {
Expand Down Expand Up @@ -1995,7 +1999,7 @@ func (s *connection) resetPacingDeadline() {
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if !s.handshakeConfirmed {
ecn := s.sentPacketHandler.ECNMode(false)
packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err := s.packer.PackCoalescedPacket(true, s.maxPacketSize(), s.version)
if err != nil {
return err
}
Expand All @@ -2006,7 +2010,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
}

ecn := s.sentPacketHandler.ECNMode(true)
p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
p, buf, err := s.packer.PackAckOnlyPacket(s.maxPacketSize(), s.version)
if err != nil {
if err == errNothingToPack {
return nil
Expand All @@ -2028,7 +2032,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
break
}
var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version)
if err != nil {
return err
}
Expand All @@ -2039,7 +2043,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
if packet == nil {
s.retransmissionQueue.AddPing(encLevel)
var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version)
if err != nil {
return err
}
Expand Down Expand Up @@ -2118,14 +2122,14 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
var transportErr *qerr.TransportError
var applicationErr *qerr.ApplicationError
if errors.As(e, &transportErr) {
packet, err = s.packer.PackConnectionClose(transportErr, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.PackConnectionClose(transportErr, s.maxPacketSize(), s.version)
} else if errors.As(e, &applicationErr) {
packet, err = s.packer.PackApplicationClose(applicationErr, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.PackApplicationClose(applicationErr, s.maxPacketSize(), s.version)
} else {
packet, err = s.packer.PackConnectionClose(&qerr.TransportError{
ErrorCode: qerr.InternalError,
ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()),
}, s.mtuDiscoverer.CurrentSize(), s.version)
}, s.maxPacketSize(), s.version)
}
if err != nil {
return nil, err
Expand All @@ -2135,6 +2139,24 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn)
}

func (s *connection) maxPacketSize() protocol.ByteCount {
if s.mtuDiscoverer == nil {
// Use the configured packet size on the client side.
// If the server sends a max_udp_payload_size that's smaller than this size, we can ignore this:
// Apparently the server still processed the (fully padded) Initial packet anyway.
if s.perspective == protocol.PerspectiveClient {
return protocol.ByteCount(s.config.InitialPacketSize)
}
// On the server side, there's no downside to using 1200 bytes until we received the client's transport
// parameters:
// * If the first packet didn't contain the entire ClientHello, all we can do is ACK that packet. We don't
// need a lot of bytes for that.
// * If it did, we will have processed the transport parameters and initialized the MTU discoverer.
return protocol.MinInitialPacketSize
}
return s.mtuDiscoverer.CurrentSize()
}

func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
// quic-go logging
if s.logger.Debug() {
Expand Down
27 changes: 13 additions & 14 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1439,15 +1439,15 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
sph.EXPECT().ECNMode(true).Return(protocol.ECT1).Times(4)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload1 := make([]byte, conn.maxPacketSize())
rand.Read(payload1)
payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload2 := make([]byte, conn.maxPacketSize())
rand.Read(payload2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.maxPacketSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...)))
})
go func() {
Expand All @@ -1466,20 +1466,20 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(3)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
sph.EXPECT().ECNMode(true).Times(4)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload1 := make([]byte, conn.maxPacketSize())
rand.Read(payload1)
payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize()-1)
payload2 := make([]byte, conn.maxPacketSize()-1)
rand.Read(payload2)
payload3 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload3 := make([]byte, conn.maxPacketSize())
rand.Read(payload3)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 12}, payload3)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.maxPacketSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...)))
})
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.maxPacketSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(payload3))
})
go func() {
Expand All @@ -1499,20 +1499,20 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
sph.EXPECT().ECNMode(true).Return(protocol.ECT1).Times(2)
sph.EXPECT().ECNMode(true).Return(protocol.ECT0).Times(2)
payload1 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload1 := make([]byte, conn.maxPacketSize())
rand.Read(payload1)
payload2 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload2 := make([]byte, conn.maxPacketSize())
rand.Read(payload2)
payload3 := make([]byte, conn.mtuDiscoverer.CurrentSize())
payload3 := make([]byte, conn.maxPacketSize())
rand.Read(payload3)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 10}, payload1)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload2)
expectAppendPacket(packer, shortHeaderPacket{PacketNumber: 11}, payload3)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.maxPacketSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(append(payload1, payload2...)))
})
sender.EXPECT().Send(gomock.Any(), uint16(conn.mtuDiscoverer.CurrentSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
sender.EXPECT().Send(gomock.Any(), uint16(conn.maxPacketSize()), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
Expect(b.Data).To(Equal(payload3))
})
go func() {
Expand Down Expand Up @@ -2504,7 +2504,6 @@ var _ = Describe("Connection", func() {
Expect(err).To(BeAssignableToTypeOf(&DatagramTooLargeError{}))
derr := err.(*DatagramTooLargeError)
Expect(derr.MaxDatagramPayloadSize).To(BeNumerically("<", 1000))
fmt.Println(derr.MaxDatagramPayloadSize)
Expect(conn.SendDatagram(make([]byte, derr.MaxDatagramPayloadSize))).To(Succeed())
})

Expand Down
12 changes: 6 additions & 6 deletions mock_mtu_discoverer_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions mtu_discoverer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
type mtuDiscoverer interface {
// Start starts the MTU discovery process.
// It's unnecessary to call ShouldSendProbe before that.
Start(maxPacketSize protocol.ByteCount)
Start()
ShouldSendProbe(now time.Time) bool
CurrentSize() protocol.ByteCount
GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount)
Expand All @@ -38,10 +38,11 @@ type mtuFinder struct {

var _ mtuDiscoverer = &mtuFinder{}

func newMTUDiscoverer(rttStats *utils.RTTStats, start protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder {
func newMTUDiscoverer(rttStats *utils.RTTStats, start, max protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder {
return &mtuFinder{
inFlight: protocol.InvalidByteCount,
current: start,
max: max,
rttStats: rttStats,
mtuIncreased: mtuIncreased,
}
Expand All @@ -51,9 +52,15 @@ func (f *mtuFinder) done() bool {
return f.max-f.current <= maxMTUDiff+1
}

func (f *mtuFinder) Start(maxPacketSize protocol.ByteCount) {
func (f *mtuFinder) SetMax(max protocol.ByteCount) {
f.max = max
}

func (f *mtuFinder) Start() {
if f.max == protocol.InvalidByteCount {
panic("invalid")
}
f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately
f.max = maxPacketSize
}

func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
Expand Down
10 changes: 5 additions & 5 deletions mtu_discoverer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ var _ = Describe("MTU Discoverer", func() {
rttStats = &utils.RTTStats{}
rttStats.SetInitialRTT(rtt)
Expect(rttStats.SmoothedRTT()).To(Equal(rtt))
d = newMTUDiscoverer(rttStats, startMTU, func(s protocol.ByteCount) { discoveredMTU = s })
d.Start(maxMTU)
d = newMTUDiscoverer(rttStats, startMTU, maxMTU, func(s protocol.ByteCount) { discoveredMTU = s })
d.Start()
now = time.Now()
})

Expand Down Expand Up @@ -78,7 +78,7 @@ var _ = Describe("MTU Discoverer", func() {
})

It("doesn't do discovery before being started", func() {
d := newMTUDiscoverer(rttStats, startMTU, func(s protocol.ByteCount) {})
d := newMTUDiscoverer(rttStats, startMTU, protocol.MaxByteCount, func(s protocol.ByteCount) {})
for i := 0; i < 5; i++ {
Expect(d.ShouldSendProbe(time.Now())).To(BeFalse())
}
Expand All @@ -90,8 +90,8 @@ var _ = Describe("MTU Discoverer", func() {
for i := 0; i < rep; i++ {
maxMTU := protocol.ByteCount(rand.Intn(int(3000-startMTU))) + startMTU + 1
currentMTU := startMTU
d := newMTUDiscoverer(rttStats, startMTU, func(s protocol.ByteCount) { currentMTU = s })
d.Start(maxMTU)
d := newMTUDiscoverer(rttStats, startMTU, maxMTU, func(s protocol.ByteCount) { currentMTU = s })
d.Start()
now := time.Now()
realMTU := protocol.ByteCount(rand.Intn(int(maxMTU-startMTU))) + startMTU
t := now.Add(mtuProbeDelay * rtt)
Expand Down
Loading