Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wire: refactor header parsing to use quicvarint.Parse #4481

Merged
merged 3 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 62 additions & 58 deletions internal/wire/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"

"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)

Expand Down Expand Up @@ -139,18 +138,18 @@ type Header struct {
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}

// ParsePacket parses a packet.
// If the packet has a long header, the packet is cut according to the length field.
// If we understand the version, the packet is header up unto the packet number.
// ParsePacket parses a long header packet.
// The packet is cut according to the length field.
// If we understand the version, the packet is parsed up unto the packet number.
// Otherwise, only the invariant part of the header is parsed.
func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
return nil, nil, nil, errors.New("not a long header packet")
}
hdr, err := parseHeader(bytes.NewReader(data))
hdr, err := parseHeader(data)
if err != nil {
if err == ErrUnsupportedVersion {
return hdr, nil, nil, ErrUnsupportedVersion
if errors.Is(err, ErrUnsupportedVersion) {
return hdr, nil, nil, err
}
return nil, nil, nil, err
}
Expand All @@ -161,55 +160,59 @@ func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
return hdr, data[:packetLen], data[packetLen:], nil
}

// ParseHeader parses the header.
// For short header packets: up to the packet number.
// For long header packets:
// ParseHeader parses the header:
// * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func parseHeader(b *bytes.Reader) (*Header, error) {
startLen := b.Len()
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
func parseHeader(b []byte) (*Header, error) {
if len(b) == 0 {
return nil, io.EOF
}
typeByte := b[0]

h := &Header{typeByte: typeByte}
err = h.parseLongHeader(b)
h.parsedLen = protocol.ByteCount(startLen - b.Len())
l, err := h.parseLongHeader(b[1:])
h.parsedLen = l + 1
return h, err
}

func (h *Header) parseLongHeader(b *bytes.Reader) error {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
func (h *Header) parseLongHeader(b []byte) (l protocol.ByteCount, err error) {
if len(b) < 5 {
return 0, io.EOF
}
h.Version = protocol.Version(v)
h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4]))
l = 4
if h.Version != 0 && h.typeByte&0x40 == 0 {
return errors.New("not a QUIC packet")
}
destConnIDLen, err := b.ReadByte()
if err != nil {
return err
}
h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
if err != nil {
return err
}
srcConnIDLen, err := b.ReadByte()
if err != nil {
return err
}
h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
if err != nil {
return err
return l, errors.New("not a QUIC packet")
}
destConnIDLen := int(b[4])
l++
if destConnIDLen > protocol.MaxConnIDLen {
return l, protocol.ErrInvalidConnectionIDLen
}
b = b[5:]
if len(b) < destConnIDLen+1 {
return l, io.EOF
}
h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen])
l += protocol.ByteCount(destConnIDLen)
srcConnIDLen := int(b[destConnIDLen])
l++
if srcConnIDLen > protocol.MaxConnIDLen {
return l, protocol.ErrInvalidConnectionIDLen
}
b = b[destConnIDLen+1:]
if len(b) < srcConnIDLen {
return l, io.EOF
}
h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen])
l += protocol.ByteCount(srcConnIDLen)
b = b[srcConnIDLen:]
if h.Version == 0 { // version negotiation packet
return nil
return l, nil
}
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
return ErrUnsupportedVersion
return l, ErrUnsupportedVersion
}

if h.Version == protocol.Version2 {
Expand Down Expand Up @@ -237,38 +240,39 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
}

if h.Type == protocol.PacketTypeRetry {
tokenLen := b.Len() - 16
tokenLen := len(b) - 16
if tokenLen <= 0 {
return io.EOF
return l, io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
_, err := b.Seek(16, io.SeekCurrent)
return err
copy(h.Token, b[:tokenLen])
l += protocol.ByteCount(tokenLen)
return l + 16, nil
}

if h.Type == protocol.PacketTypeInitial {
tokenLen, err := quicvarint.Read(b)
tokenLen, n, err := quicvarint.Parse(b)
l += protocol.ByteCount(n)
if err != nil {
return err
return l, err
}
if tokenLen > uint64(b.Len()) {
return io.EOF
b = b[n:]
if tokenLen > uint64(len(b)) {
return l, io.EOF
}
l += protocol.ByteCount(tokenLen)
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
copy(h.Token, b[:tokenLen])
b = b[tokenLen:]
}

pl, err := quicvarint.Read(b)
pl, n, err := quicvarint.Parse(b)
l += protocol.ByteCount(n)
if err != nil {
return err
return 0, err
}
h.Length = protocol.ByteCount(pl)
return nil
return l, nil
marten-seemann marked this conversation as resolved.
Show resolved Hide resolved
}

// ParsedLen returns the number of bytes that were consumed when parsing the header
Expand Down
75 changes: 75 additions & 0 deletions internal/wire/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,78 @@ func BenchmarkIs0RTTPacket(b *testing.B) {
Is0RTTPacket(packets[i%len(packets)])
}
}

func BenchmarkParseInitial(b *testing.B) {
b.Run("without token", func(b *testing.B) {
benchmarkInitialPacketParsing(b, nil)
})
b.Run("with token", func(b *testing.B) {
token := make([]byte, 32)
rand.Read(token)
benchmarkInitialPacketParsing(b, token)
})
}

func benchmarkInitialPacketParsing(b *testing.B, token []byte) {
hdr := Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
SrcConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
Length: 1000,
Token: token,
Version: protocol.Version1,
}
data, err := (&ExtendedHeader{
Header: hdr,
PacketNumber: 0x1337,
PacketNumberLen: 4,
}).Append(nil, protocol.Version1)
if err != nil {
b.Fatal(err)
}
data = append(data, make([]byte, 1000)...)

b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
h, _, _, err := ParsePacket(data)
if err != nil {
b.Fatal(err)
}
if h.Type != hdr.Type || h.DestConnectionID != hdr.DestConnectionID || h.SrcConnectionID != hdr.SrcConnectionID ||
!bytes.Equal(h.Token, hdr.Token) {
b.Fatalf("headers don't match: %v vs %v", h, hdr)
}
}
}

func BenchmarkParseRetry(b *testing.B) {
token := make([]byte, 64)
rand.Read(token)
hdr := &ExtendedHeader{
Header: Header{
Type: protocol.PacketTypeRetry,
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
DestConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}),
Token: token,
Version: protocol.Version1,
},
}
data, err := hdr.Append(nil, hdr.Version)
if err != nil {
b.Fatal(err)
}

b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
h, _, _, err := ParsePacket(data)
if err != nil {
b.Fatal(err)
}
if h.Type != hdr.Type || h.DestConnectionID != hdr.DestConnectionID || h.SrcConnectionID != hdr.SrcConnectionID ||
!bytes.Equal(h.Token, hdr.Token[:len(hdr.Token)-16]) {
b.Fatalf("headers don't match: %#v vs %#v", h, hdr)
}
}
}
Loading