From 11111228cfd7297c42001442c7c091a95937d23a Mon Sep 17 00:00:00 2001
From: Marten Seemann
Date: Fri, 3 May 2024 13:58:18 +0200
Subject: [PATCH] quicvarint: add function to parse a varint from a byte slice (#4475)

---
 quicvarint/varint.go      |  33 ++++++++--
 quicvarint/varint_test.go | 123 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 151 insertions(+), 5 deletions(-)

diff --git a/quicvarint/varint.go b/quicvarint/varint.go
index ff99d8592f3..9a22e334f35 100644
--- a/quicvarint/varint.go
+++ b/quicvarint/varint.go
@@ -26,16 +26,16 @@ func Read(r io.ByteReader) (uint64, error) {
 		return 0, err
 	}
 	// the first two bits of the first byte encode the length
-	len := 1 << ((firstByte & 0xc0) >> 6)
+	l := 1 << ((firstByte & 0xc0) >> 6)
 	b1 := firstByte & (0xff - 0xc0)
-	if len == 1 {
+	if l == 1 {
 		return uint64(b1), nil
 	}
 	b2, err := r.ReadByte()
 	if err != nil {
 		return 0, err
 	}
-	if len == 2 {
+	if l == 2 {
 		return uint64(b2) + uint64(b1)<<8, nil
 	}
 	b3, err := r.ReadByte()
@@ -46,7 +46,7 @@ func Read(r io.ByteReader) (uint64, error) {
 	if err != nil {
 		return 0, err
 	}
-	if len == 4 {
+	if l == 4 {
 		return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil
 	}
 	b5, err := r.ReadByte()
@@ -68,6 +68,31 @@ func Read(r io.ByteReader) (uint64, error) {
 	return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil
 }
 
+// Parse reads a number in the QUIC varint format from the start of b and returns it with the number of bytes consumed.
+// It returns io.EOF if b is empty and io.ErrUnexpectedEOF if b is shorter than the encoded length.
+func Parse(b []byte) (uint64 /* value */, int /* bytes consumed */, error) {
+	if len(b) == 0 {
+		return 0, 0, io.EOF
+	}
+	firstByte := b[0]
+	// the first two bits of the first byte encode the length
+	l := 1 << ((firstByte & 0xc0) >> 6)
+	if len(b) < l {
+		return 0, 0, io.ErrUnexpectedEOF
+	}
+	b0 := firstByte & (0xff - 0xc0)
+	if l == 1 {
+		return uint64(b0), 1, nil
+	}
+	if l == 2 {
+		return uint64(b[1]) + uint64(b0)<<8, 2, nil
+	}
+	if l == 4 {
+		return uint64(b[3]) + uint64(b[2])<<8 + uint64(b[1])<<16 + uint64(b0)<<24, 4, nil
+	}
+	return uint64(b[7]) + uint64(b[6])<<8 + uint64(b[5])<<16 + uint64(b[4])<<24 + uint64(b[3])<<32 + uint64(b[2])<<40 + uint64(b[1])<<48 + uint64(b0)<<56, 8, nil
+}
+
 // Append appends i in the QUIC varint format.
 func Append(b []byte, i uint64) []byte {
 	if i <= maxVarInt1 {
diff --git a/quicvarint/varint_test.go b/quicvarint/varint_test.go
index 2519d26a878..104b0e620e1 100644
--- a/quicvarint/varint_test.go
+++ b/quicvarint/varint_test.go
@@ -2,6 +2,10 @@ package quicvarint
 
 import (
 	"bytes"
+	"io"
+	"testing"
+
+	"golang.org/x/exp/rand"
 
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
@@ -18,7 +22,7 @@ var _ = Describe("Varint encoding / decoding", func() {
 		})
 	})
 
-	Context("decoding", func() {
+	Context("reading", func() {
 		It("reads a 1 byte number", func() {
 			b := bytes.NewReader([]byte{0b00011001})
 			val, err := Read(b)
@@ -60,6 +64,59 @@ var _ = Describe("Varint encoding / decoding", func() {
 		})
 	})
 
+	Context("parsing", func() {
+		It("fails on an empty slice", func() {
+			_, _, err := Parse([]byte{})
+			Expect(err).To(Equal(io.EOF))
+		})
+
+		It("parses a 1 byte number", func() {
+			b := []byte{0b00011001}
+			val, n, err := Parse(b)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(val).To(Equal(uint64(25)))
+			Expect(n).To(Equal(1))
+		})
+
+		It("parses a number that is encoded too long", func() {
+			b := []byte{0b01000000, 0x25}
+			val, n, err := Parse(b)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(val).To(Equal(uint64(37)))
+			Expect(n).To(Equal(2))
+		})
+
+		It("parses a 2 byte number", func() {
+			b := []byte{0b01111011, 0xbd}
+			val, n, err := Parse(b)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(val).To(Equal(uint64(15293)))
+			Expect(n).To(Equal(2))
+		})
+
+		It("parses a 4 byte number", func() {
+			b := []byte{0b10011101, 0x7f, 0x3e, 0x7d}
+			val, n, err := Parse(b)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(val).To(Equal(uint64(494878333)))
+			Expect(n).To(Equal(4))
+		})
+
+		It("parses an 8 byte number", func() {
+			b := []byte{0b11000010, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c}
+			val, n, err := Parse(b)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(val).To(Equal(uint64(151288809941952652)))
+			Expect(n).To(Equal(8))
+		})
+
+		It("fails if the slice is too short", func() {
+			b := Append(nil, maxVarInt2*10)
+			_, _, err := Parse(b[:3])
+			Expect(err).To(Equal(io.ErrUnexpectedEOF))
+		})
+	})
+
 	Context("encoding", func() {
 		Context("with minimal length", func() {
 			It("writes a 1 byte number", func() {
@@ -192,3 +249,67 @@ var _ = Describe("Varint encoding / decoding", func() {
 		})
 	})
 })
+
+type benchmarkValue struct {
+	b []byte
+	v uint64
+}
+
+func randomValues(num int, maxValue uint64) []benchmarkValue {
+	r := rand.New(rand.NewSource(1))
+
+	bv := make([]benchmarkValue, num)
+	for i := 0; i < num; i++ {
+		v := r.Uint64() % maxValue
+		bv[i].v = v
+		bv[i].b = Append([]byte{}, v)
+	}
+	return bv
+}
+
+func BenchmarkRead(b *testing.B) {
+	b.Run("1-byte", func(b *testing.B) { benchmarkRead(b, randomValues(min(b.N, 1024), maxVarInt1)) })
+	b.Run("2-byte", func(b *testing.B) { benchmarkRead(b, randomValues(min(b.N, 1024), maxVarInt2)) })
+	b.Run("4-byte", func(b *testing.B) { benchmarkRead(b, randomValues(min(b.N, 1024), maxVarInt4)) })
+	b.Run("8-byte", func(b *testing.B) { benchmarkRead(b, randomValues(min(b.N, 1024), maxVarInt8)) })
+}
+
+func benchmarkRead(b *testing.B, inputs []benchmarkValue) {
+	r := bytes.NewReader([]byte{})
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		index := i % len(inputs)
+		r.Reset(inputs[index].b)
+		val, err := Read(r)
+		if err != nil {
+			b.Fatal(err)
+		}
+		if val != inputs[index].v {
+			b.Fatalf("expected %d, got %d", inputs[index].v, val)
+		}
+	}
+}
+
+func BenchmarkParse(b *testing.B) {
+	b.Run("1-byte", func(b *testing.B) { benchmarkParse(b, randomValues(min(b.N, 1024), maxVarInt1)) })
+	b.Run("2-byte", func(b *testing.B) { benchmarkParse(b, randomValues(min(b.N, 1024), maxVarInt2)) })
+	b.Run("4-byte", func(b *testing.B) { benchmarkParse(b, randomValues(min(b.N, 1024), maxVarInt4)) })
+	b.Run("8-byte", func(b *testing.B) { benchmarkParse(b, randomValues(min(b.N, 1024), maxVarInt8)) })
+}
+
+func benchmarkParse(b *testing.B, inputs []benchmarkValue) {
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		index := i % len(inputs)
+		val, n, err := Parse(inputs[index].b)
+		if err != nil {
+			b.Fatal(err)
+		}
+		if n != len(inputs[index].b) {
+			b.Fatalf("expected to consume %d bytes, consumed %d", len(inputs[index].b), n)
+		}
+		if val != inputs[index].v {
+			b.Fatalf("expected %d, got %d", inputs[index].v, val)
+		}
+	}
+}