From d28c0b17437aac95ef32d25b3ff2c919ee165ebb Mon Sep 17 00:00:00 2001 From: cui fliter Date: Sat, 6 May 2023 11:50:14 +0800 Subject: [PATCH 01/28] all: fix some comments Change-Id: I005e210f0ae030c507b4bfd1548c5a885df4c6b9 Reviewed-on: https://go-review.googlesource.com/c/net/+/493155 Run-TryBot: shuang cui Reviewed-by: Ian Lance Taylor TryBot-Result: Gopher Robot Reviewed-by: Cherry Mui Run-TryBot: Ian Lance Taylor Auto-Submit: Ian Lance Taylor --- http2/h2c/h2c.go | 2 +- http2/server.go | 4 ++-- http2/transport.go | 2 +- webdav/if.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/http2/h2c/h2c.go b/http2/h2c/h2c.go index a72bbed1b..2d6bf861b 100644 --- a/http2/h2c/h2c.go +++ b/http2/h2c/h2c.go @@ -44,7 +44,7 @@ func init() { // HTTP/1, but unlikely to occur in practice and (2) Upgrading from HTTP/1 to // h2c - this works by using the HTTP/1 Upgrade header to request an upgrade to // h2c. When either of those situations occur we hijack the HTTP/1 connection, -// convert it to a HTTP/2 connection and pass the net.Conn to http2.ServeConn. +// convert it to an HTTP/2 connection and pass the net.Conn to http2.ServeConn. type h2cHandler struct { Handler http.Handler s *http2.Server diff --git a/http2/server.go b/http2/server.go index cd057f398..0c5bdeec5 100644 --- a/http2/server.go +++ b/http2/server.go @@ -2429,7 +2429,7 @@ type requestBody struct { conn *serverConn closeOnce sync.Once // for use by Close only sawEOF bool // for use by Read only - pipe *pipe // non-nil if we have a HTTP entity message body + pipe *pipe // non-nil if we have an HTTP entity message body needsContinue bool // need to send a 100-continue } @@ -2774,7 +2774,7 @@ func (w *responseWriter) FlushError() error { err = rws.bw.Flush() } else { // The bufio.Writer won't call chunkWriter.Write - // (writeChunk with zero bytes, so we have to do it + // (writeChunk with zero bytes), so we have to do it // ourselves to force the HTTP response header and/or // final DATA frame (with END_STREAM) to be sent. _, err = chunkWriter{rws}.Write(nil) diff --git a/http2/transport.go b/http2/transport.go index ac90a2631..ff86a765e 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -1899,7 +1899,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail // 8.1.2.3 Request Pseudo-Header Fields // The :path pseudo-header field includes the path and query parts of the // target URI (the path-absolute production and optionally a '?' character - // followed by the query production (see Sections 3.3 and 3.4 of + // followed by the query production, see Sections 3.3 and 3.4 of // [RFC3986]). f(":authority", host) m := req.Method diff --git a/webdav/if.go b/webdav/if.go index 416e81cdf..e646570bb 100644 --- a/webdav/if.go +++ b/webdav/if.go @@ -24,7 +24,7 @@ type ifList struct { // parseIfHeader parses the "If: foo bar" HTTP header. The httpHeader string // should omit the "If:" prefix and have any "\r\n"s collapsed to a " ", as is -// returned by req.Header.Get("If") for a http.Request req. +// returned by req.Header.Get("If") for an http.Request req. func parseIfHeader(httpHeader string) (h ifHeader, ok bool) { s := strings.TrimSpace(httpHeader) switch tokenType, _, _ := lex(s); tokenType { From 2b0b97d53f17ba6f8bd4b7bc61b1ec4f3971e072 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Mon, 15 May 2023 12:16:19 +0000 Subject: [PATCH 02/28] dns/dnsmessage: reject packing of 255B rooted names, reject unpacking of 256B (dns encoded) names Packing a 255B (rooted) name will create an 256B (dns encoded) name, which is an invalid name. Similar with unpacking, we can't unpack 256B (dns encoded) name, because it is too long. Change-Id: I17cc93a93a17a879a2a789629e56ad39999da9ac GitHub-Last-Rev: ddf151af6c650160f2583f66c07d41ed18c7ae5e GitHub-Pull-Request: golang/net#156 Reviewed-on: https://go-review.googlesource.com/c/net/+/448156 Reviewed-by: Michael Knyszek Reviewed-by: Damien Neil TryBot-Result: Gopher Robot Run-TryBot: Mateusz Poliwczak --- dns/dnsmessage/message.go | 19 ++++++---- dns/dnsmessage/message_test.go | 65 ++++++++++++++++++++++++++-------- 2 files changed, 62 insertions(+), 22 deletions(-) diff --git a/dns/dnsmessage/message.go b/dns/dnsmessage/message.go index ffdf19d5d..69c611bda 100644 --- a/dns/dnsmessage/message.go +++ b/dns/dnsmessage/message.go @@ -263,6 +263,7 @@ var ( errNilResouceBody = errors.New("nil resource body") errResourceLen = errors.New("insufficient data for resource body length") errSegTooLong = errors.New("segment length too long") + errNameTooLong = errors.New("name too long") errZeroSegLen = errors.New("zero length segment") errResTooLong = errors.New("resource length too long") errTooManyQuestions = errors.New("too many Questions to pack (>65535)") @@ -1728,7 +1729,7 @@ const ( // // The provided extRCode must be an extended RCode. func (h *ResourceHeader) SetEDNS0(udpPayloadLen int, extRCode RCode, dnssecOK bool) error { - h.Name = Name{Data: [nameLen]byte{'.'}, Length: 1} // RFC 6891 section 6.1.2 + h.Name = Name{Data: [255]byte{'.'}, Length: 1} // RFC 6891 section 6.1.2 h.Type = TypeOPT h.Class = Class(udpPayloadLen) h.TTL = uint32(extRCode) >> 4 << 24 @@ -1888,21 +1889,21 @@ func unpackBytes(msg []byte, off int, field []byte) (int, error) { return newOff, nil } -const nameLen = 255 +const nonEncodedNameMax = 254 // A Name is a non-encoded domain name. It is used instead of strings to avoid // allocations. type Name struct { - Data [nameLen]byte // 255 bytes + Data [255]byte Length uint8 } // NewName creates a new Name from a string. func NewName(name string) (Name, error) { - if len(name) > nameLen { + n := Name{Length: uint8(len(name))} + if len(name) > len(n.Data) { return Name{}, errCalcLen } - n := Name{Length: uint8(len(name))} copy(n.Data[:], name) return n, nil } @@ -1936,6 +1937,10 @@ func (n *Name) GoString() string { func (n *Name) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { oldMsg := msg + if n.Length > nonEncodedNameMax { + return nil, errNameTooLong + } + // Add a trailing dot to canonicalize name. if n.Length == 0 || n.Data[n.Length-1] != '.' { return oldMsg, errNonCanonicalName @@ -2057,8 +2062,8 @@ Loop: if len(name) == 0 { name = append(name, '.') } - if len(name) > len(n.Data) { - return off, errCalcLen + if len(name) > nonEncodedNameMax { + return off, errNameTooLong } n.Length = uint8(len(name)) if ptr == 0 { diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go index 3cddfca99..ef5326db8 100644 --- a/dns/dnsmessage/message_test.go +++ b/dns/dnsmessage/message_test.go @@ -212,25 +212,28 @@ func TestName(t *testing.T) { } func TestNamePackUnpack(t *testing.T) { + const suffix = ".go.dev." + var longDNSPrefix = strings.Repeat("verylongdomainlabel.", 20) + tests := []struct { - in string - want string - err error + in string + err error }{ - {"", "", errNonCanonicalName}, - {".", ".", nil}, - {"google..com", "", errNonCanonicalName}, - {"google.com", "", errNonCanonicalName}, - {"google..com.", "", errZeroSegLen}, - {"google.com.", "google.com.", nil}, - {".google.com.", "", errZeroSegLen}, - {"www..google.com.", "", errZeroSegLen}, - {"www.google.com.", "www.google.com.", nil}, + {"", errNonCanonicalName}, + {".", nil}, + {"google..com", errNonCanonicalName}, + {"google.com", errNonCanonicalName}, + {"google..com.", errZeroSegLen}, + {"google.com.", nil}, + {".google.com.", errZeroSegLen}, + {"www..google.com.", errZeroSegLen}, + {"www.google.com.", nil}, + {in: longDNSPrefix[:254-len(suffix)] + suffix}, // 254B name, with ending dot. + {in: longDNSPrefix[:255-len(suffix)] + suffix, err: errNameTooLong}, // 255B name, with ending dot. } for _, test := range tests { in := MustNewName(test.in) - want := MustNewName(test.want) buf, err := in.pack(make([]byte, 0, 30), map[string]int{}, 0) if err != test.err { t.Errorf("got %q.pack() = %v, want = %v", test.in, err, test.err) @@ -253,8 +256,40 @@ func TestNamePackUnpack(t *testing.T) { len(buf), ) } - if got != want { - t.Errorf("unpacking packing of %q: got = %#v, want = %#v", test.in, got, want) + if got != in { + t.Errorf("unpacking packing of %q: got = %#v, want = %#v", test.in, got, in) + } + } +} + +func TestNameUnpackTooLongName(t *testing.T) { + var suffix = []byte{2, 'g', 'o', 3, 'd', 'e', 'v', 0} + + const label = "longdnslabel" + labelBinary := append([]byte{byte(len(label))}, []byte(label)...) + var longDNSPrefix = bytes.Repeat(labelBinary, 18) + longDNSPrefix = longDNSPrefix[:len(longDNSPrefix):len(longDNSPrefix)] + + prepName := func(length int) []byte { + missing := length - (len(longDNSPrefix) + len(suffix) + 1) + name := append(longDNSPrefix, byte(missing)) + name = append(name, bytes.Repeat([]byte{'a'}, missing)...) + return append(name, suffix...) + } + + tests := []struct { + name []byte + err error + }{ + {name: prepName(255)}, + {name: prepName(256), err: errNameTooLong}, + } + + for i, test := range tests { + var got Name + _, err := got.unpack(test.name, 0) + if err != test.err { + t.Errorf("%v: %v: expected error: %v, got %v", i, test.name, test.err, err) } } } From 120fc906b30bade8c220769da77801566d7f4ec8 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 22 Mar 2023 14:42:28 -0700 Subject: [PATCH 03/28] http2: change default frame scheduler to round robin The priority scheduler allows stream starvation (see golang/go#58804) and is CPU intensive. In addition, the RFC 7540 prioritization scheme it implements was deprecated in RFC 9113 and does not appear to have ever had significant adoption. Add a simple round-robin scheduler and enable it by default. For golang/go#58804 Change-Id: I5c5143aa9bc339fc0894f70d773fa7c0d7d87eef Reviewed-on: https://go-review.googlesource.com/c/net/+/478735 TryBot-Result: Gopher Robot Reviewed-by: Bryan Mills Run-TryBot: Damien Neil --- http2/server.go | 2 +- http2/writesched.go | 3 +- http2/writesched_roundrobin.go | 119 ++++++++++++++++++++++++++++ http2/writesched_roundrobin_test.go | 65 +++++++++++++++ 4 files changed, 187 insertions(+), 2 deletions(-) create mode 100644 http2/writesched_roundrobin.go create mode 100644 http2/writesched_roundrobin_test.go diff --git a/http2/server.go b/http2/server.go index 0c5bdeec5..396d53b90 100644 --- a/http2/server.go +++ b/http2/server.go @@ -441,7 +441,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { if s.NewWriteScheduler != nil { sc.writeSched = s.NewWriteScheduler() } else { - sc.writeSched = NewPriorityWriteScheduler(nil) + sc.writeSched = newRoundRobinWriteScheduler() } // These start at the RFC-specified defaults. If there is a higher diff --git a/http2/writesched.go b/http2/writesched.go index c7cd00173..cc893adc2 100644 --- a/http2/writesched.go +++ b/http2/writesched.go @@ -184,7 +184,8 @@ func (wr *FrameWriteRequest) replyToWriter(err error) { // writeQueue is used by implementations of WriteScheduler. type writeQueue struct { - s []FrameWriteRequest + s []FrameWriteRequest + prev, next *writeQueue } func (q *writeQueue) empty() bool { return len(q.s) == 0 } diff --git a/http2/writesched_roundrobin.go b/http2/writesched_roundrobin.go new file mode 100644 index 000000000..54fe86322 --- /dev/null +++ b/http2/writesched_roundrobin.go @@ -0,0 +1,119 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "fmt" + "math" +) + +type roundRobinWriteScheduler struct { + // control contains control frames (SETTINGS, PING, etc.). + control writeQueue + + // streams maps stream ID to a queue. + streams map[uint32]*writeQueue + + // stream queues are stored in a circular linked list. + // head is the next stream to write, or nil if there are no streams open. + head *writeQueue + + // pool of empty queues for reuse. + queuePool writeQueuePool +} + +// newRoundRobinWriteScheduler constructs a new write scheduler. +// The round robin scheduler priorizes control frames +// like SETTINGS and PING over DATA frames. +// When there are no control frames to send, it performs a round-robin +// selection from the ready streams. +func newRoundRobinWriteScheduler() WriteScheduler { + ws := &roundRobinWriteScheduler{ + streams: make(map[uint32]*writeQueue), + } + return ws +} + +func (ws *roundRobinWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { + if ws.streams[streamID] != nil { + panic(fmt.Errorf("stream %d already opened", streamID)) + } + q := ws.queuePool.get() + ws.streams[streamID] = q + if ws.head == nil { + ws.head = q + q.next = q + q.prev = q + } else { + // Queues are stored in a ring. + // Insert the new stream before ws.head, putting it at the end of the list. + q.prev = ws.head.prev + q.next = ws.head + q.prev.next = q + q.next.prev = q + } +} + +func (ws *roundRobinWriteScheduler) CloseStream(streamID uint32) { + q := ws.streams[streamID] + if q == nil { + return + } + if q.next == q { + // This was the only open stream. + ws.head = nil + } else { + q.prev.next = q.next + q.next.prev = q.prev + if ws.head == q { + ws.head = q.next + } + } + delete(ws.streams, streamID) + ws.queuePool.put(q) +} + +func (ws *roundRobinWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) {} + +func (ws *roundRobinWriteScheduler) Push(wr FrameWriteRequest) { + if wr.isControl() { + ws.control.push(wr) + return + } + q := ws.streams[wr.StreamID()] + if q == nil { + // This is a closed stream. + // wr should not be a HEADERS or DATA frame. + // We push the request onto the control queue. + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + ws.control.push(wr) + return + } + q.push(wr) +} + +func (ws *roundRobinWriteScheduler) Pop() (FrameWriteRequest, bool) { + // Control and RST_STREAM frames first. + if !ws.control.empty() { + return ws.control.shift(), true + } + if ws.head == nil { + return FrameWriteRequest{}, false + } + q := ws.head + for { + if wr, ok := q.consume(math.MaxInt32); ok { + ws.head = q.next + return wr, true + } + q = q.next + if q == ws.head { + break + } + } + return FrameWriteRequest{}, false +} diff --git a/http2/writesched_roundrobin_test.go b/http2/writesched_roundrobin_test.go new file mode 100644 index 000000000..032b2bc6c --- /dev/null +++ b/http2/writesched_roundrobin_test.go @@ -0,0 +1,65 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "reflect" + "testing" +) + +func TestRoundRobinScheduler(t *testing.T) { + const maxFrameSize = 16 + sc := &serverConn{maxFrameSize: maxFrameSize} + ws := newRoundRobinWriteScheduler() + streams := make([]*stream, 4) + for i := range streams { + streamID := uint32(i) + 1 + streams[i] = &stream{ + id: streamID, + sc: sc, + } + streams[i].flow.add(1 << 20) // arbitrary large value + ws.OpenStream(streamID, OpenStreamOptions{}) + wr := FrameWriteRequest{ + write: &writeData{ + streamID: streamID, + p: make([]byte, maxFrameSize*(i+1)), + endStream: false, + }, + stream: streams[i], + } + ws.Push(wr) + } + const controlFrames = 2 + for i := 0; i < controlFrames; i++ { + ws.Push(makeWriteNonStreamRequest()) + } + + // We should get the control frames first. + for i := 0; i < controlFrames; i++ { + wr, ok := ws.Pop() + if !ok || wr.StreamID() != 0 { + t.Fatalf("wr.Pop() = stream %v, %v; want 0, true", wr.StreamID(), ok) + } + } + + // Each stream should write maxFrameSize bytes until it runs out of data. + // Stream 1 has one frame of data, 2 has two frames, etc. + want := []uint32{1, 2, 3, 4, 2, 3, 4, 3, 4, 4} + var got []uint32 + for { + wr, ok := ws.Pop() + if !ok { + break + } + if wr.DataSize() != maxFrameSize { + t.Fatalf("wr.Pop() = %v data bytes, want %v", wr.DataSize(), maxFrameSize) + } + got = append(got, wr.StreamID()) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("popped streams %v, want %v", got, want) + } +} From 23ce3b89bcb1f0c5e9bd486c6d8002150d693a66 Mon Sep 17 00:00:00 2001 From: Laurent Senta Date: Mon, 27 Feb 2023 10:49:56 +0100 Subject: [PATCH 04/28] http2: disable Content-Length when nilled Change-Id: Iefef8dc1004a8e889d0e9f7243f594ae7b727a07 Reviewed-on: https://go-review.googlesource.com/c/net/+/471535 Reviewed-by: Damien Neil Auto-Submit: Damien Neil TryBot-Result: Gopher Robot Run-TryBot: Jorropo Reviewed-by: Heschi Kreinick Run-TryBot: Damien Neil --- http2/server.go | 3 ++- http2/server_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/http2/server.go b/http2/server.go index 396d53b90..033b6e6db 100644 --- a/http2/server.go +++ b/http2/server.go @@ -2569,7 +2569,8 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) { clen = "" } } - if clen == "" && rws.handlerDone && bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { + _, hasContentLength := rws.snapHeader["Content-Length"] + if !hasContentLength && clen == "" && rws.handlerDone && bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { clen = strconv.Itoa(len(p)) } _, hasContentType := rws.snapHeader["Content-Type"] diff --git a/http2/server_test.go b/http2/server_test.go index 40ab750fc..cd73291ea 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -3555,6 +3555,30 @@ func TestServerNoDuplicateContentType(t *testing.T) { } } +func TestServerContentLengthCanBeDisabled(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.Header()["Content-Length"] = nil + fmt.Fprintf(w, "OK") + }) + defer st.Close() + st.greet() + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + }) + h := st.wantHeaders() + headers := st.decodeHeader(h.HeaderBlockFragment()) + want := [][2]string{ + {":status", "200"}, + {"content-type", "text/plain; charset=utf-8"}, + } + if !reflect.DeepEqual(headers, want) { + t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want) + } +} + func disableGoroutineTracking() (restore func()) { old := DebugGoroutines DebugGoroutines = false From 3b31286d862a5c296653920dc484e4051ed5e85d Mon Sep 17 00:00:00 2001 From: Nick Figgins Date: Fri, 5 May 2023 12:26:43 +0000 Subject: [PATCH 05/28] ipv4,ipv6: remove unneeded deadlines added for flaky tests Deadlines were added in https://go.dev/cl/21360043, but these are unneeded as the tests will fail anyways as a result of the timeout. This prevents these timeouts from causing further test flakes. Fixes #58955 Change-Id: I76ebf7452bf326a09f1a7665d362fe68f345a4be GitHub-Last-Rev: 9ca732a5ea0f17bb80434ea32c0c933995dd7e38 GitHub-Pull-Request: golang/net#172 Reviewed-on: https://go-review.googlesource.com/c/net/+/492620 Reviewed-by: Ian Lance Taylor Run-TryBot: WANG Xuerui Run-TryBot: Ian Lance Taylor Reviewed-by: Bryan Mills TryBot-Result: Gopher Robot Auto-Submit: Ian Lance Taylor --- ipv4/multicast_test.go | 10 ---------- ipv4/unicast_test.go | 18 ------------------ ipv6/unicast_test.go | 12 ------------ 3 files changed, 40 deletions(-) diff --git a/ipv4/multicast_test.go b/ipv4/multicast_test.go index d056ff6de..ddd85def6 100644 --- a/ipv4/multicast_test.go +++ b/ipv4/multicast_test.go @@ -11,7 +11,6 @@ import ( "os" "runtime" "testing" - "time" "golang.org/x/net/icmp" "golang.org/x/net/internal/iana" @@ -100,9 +99,6 @@ func TestPacketConnReadWriteMulticastUDP(t *testing.T) { } t.Fatal(err) } - if err := p.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil { - t.Fatal(err) - } if err := p.SetMulticastTTL(i + 1); err != nil { t.Fatal(err) } @@ -217,9 +213,6 @@ func TestPacketConnReadWriteMulticastICMP(t *testing.T) { } t.Fatal(err) } - if err := p.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil { - t.Fatal(err) - } if err := p.SetMulticastTTL(i + 1); err != nil { t.Fatal(err) } @@ -337,9 +330,6 @@ func TestRawConnReadWriteMulticastICMP(t *testing.T) { } t.Fatal(err) } - if err := r.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil { - t.Fatal(err) - } r.SetMulticastTTL(i + 1) if err := r.WriteTo(wh, wb, nil); err != nil { t.Fatal(err) diff --git a/ipv4/unicast_test.go b/ipv4/unicast_test.go index a68f4cc02..2e55f2e5b 100644 --- a/ipv4/unicast_test.go +++ b/ipv4/unicast_test.go @@ -50,9 +50,6 @@ func TestPacketConnReadWriteUnicastUDP(t *testing.T) { t.Fatal(err) } p.SetTTL(i + 1) - if err := p.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatal(err) - } backoff := time.Millisecond for { @@ -72,9 +69,6 @@ func TestPacketConnReadWriteUnicastUDP(t *testing.T) { } rb := make([]byte, 128) - if err := p.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatal(err) - } if n, _, _, err := p.ReadFrom(rb); err != nil { t.Fatal(err) } else if !bytes.Equal(rb[:n], wb) { @@ -130,9 +124,6 @@ func TestPacketConnReadWriteUnicastICMP(t *testing.T) { t.Fatal(err) } p.SetTTL(i + 1) - if err := p.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatal(err) - } backoff := time.Millisecond for { @@ -153,9 +144,6 @@ func TestPacketConnReadWriteUnicastICMP(t *testing.T) { rb := make([]byte, 128) loop: - if err := p.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatal(err) - } if n, _, _, err := p.ReadFrom(rb); err != nil { t.Fatal(err) } else { @@ -226,17 +214,11 @@ func TestRawConnReadWriteUnicastICMP(t *testing.T) { } t.Fatal(err) } - if err := r.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatal(err) - } if err := r.WriteTo(wh, wb, nil); err != nil { t.Fatal(err) } rb := make([]byte, ipv4.HeaderLen+128) loop: - if err := r.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatal(err) - } if _, b, _, err := r.ReadFrom(rb); err != nil { t.Fatal(err) } else { diff --git a/ipv6/unicast_test.go b/ipv6/unicast_test.go index 04e5b06fa..79de14c5a 100644 --- a/ipv6/unicast_test.go +++ b/ipv6/unicast_test.go @@ -57,9 +57,6 @@ func TestPacketConnReadWriteUnicastUDP(t *testing.T) { t.Fatal(err) } cm.HopLimit = i + 1 - if err := p.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatal(err) - } backoff := time.Millisecond for { @@ -79,9 +76,6 @@ func TestPacketConnReadWriteUnicastUDP(t *testing.T) { } rb := make([]byte, 128) - if err := p.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatal(err) - } if n, _, _, err := p.ReadFrom(rb); err != nil { t.Fatal(err) } else if !bytes.Equal(rb[:n], wb) { @@ -168,9 +162,6 @@ func TestPacketConnReadWriteUnicastICMP(t *testing.T) { t.Fatal(err) } cm.HopLimit = i + 1 - if err := p.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatal(err) - } backoff := time.Millisecond for { @@ -190,9 +181,6 @@ func TestPacketConnReadWriteUnicastICMP(t *testing.T) { } rb := make([]byte, 128) - if err := p.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { - t.Fatal(err) - } if n, _, _, err := p.ReadFrom(rb); err != nil { t.Fatal(err) } else { From abee42a2abc502cd1e66ada35c62cef6aec3c3a8 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 17 May 2023 16:48:06 -0700 Subject: [PATCH 06/28] http2: deflake TestTransportReuseAfterError This test issues a request with a short timeout, and expects that the request timing out will result in the connection it was sent on being marked as unusable. However, it is possible for the request to time out before it is sent, with no effect on the connection. The test's next request then uses the same connection and hangs. Rather than a timeout, cancel the request after it is received on the server. Fixes golang/go#59934 Change-Id: I1144686377158d0654e0f91a1b0312021a02a01d Reviewed-on: https://go-review.googlesource.com/c/net/+/496055 Reviewed-by: Bryan Mills TryBot-Result: Gopher Robot Run-TryBot: Damien Neil --- http2/transport_test.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/http2/transport_test.go b/http2/transport_test.go index 54d455148..68e17fd5d 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -6434,15 +6434,17 @@ func TestTransportReuseAfterError(t *testing.T) { // Request 2 is also made on conn 1. // Reading the response will block. - // The request fails when the context deadline expires. + // The request is canceled once the server receives it. // Conn 1 should now be flagged as unfit for reuse. - timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) - defer cancel() - _, err := tr.RoundTrip(req.Clone(timeoutCtx)) + req2Ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-serverReqc + cancel() + }() + _, err := tr.RoundTrip(req.Clone(req2Ctx)) if err == nil { - t.Errorf("request 2 unexpectedly succeeded (want timeout)") + t.Errorf("request 2 unexpectedly succeeded (want cancel)") } - time.Sleep(1 * time.Millisecond) // Request 3 is made on a new conn, and succeeds. res3, err := tr.RoundTrip(req.Clone(context.Background())) From 056145cf6257608ff79a4db43b01f8d7e384243e Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 18 May 2023 10:15:31 -0700 Subject: [PATCH 07/28] net/http: deflake TestTransportRetryAfterGOAWAY Drop a redundant Close of a net.Conn. On Windows, writing to a closed connection will cause future reads from the connection to fail, even if there is buffered data available. When the test server writes a GOAWAY frame and immediately closes the connection, this can result in the client never seeing the GOAWAY. To avoid this, don't close server connections until after all test functions have returned. Fixes golang/go#59919 Change-Id: I82ed15870f3e6cd47f833a7a60b007b2fa2e15b0 Reviewed-on: https://go-review.googlesource.com/c/net/+/496056 TryBot-Result: Gopher Robot Run-TryBot: Damien Neil Reviewed-by: Bryan Mills --- http2/transport_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/http2/transport_test.go b/http2/transport_test.go index 68e17fd5d..53999f6a0 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3783,7 +3783,6 @@ func testClientMultipleDials(t *testing.T, client func(*Transport), server func( go func(count int) { defer wg.Done() server(count, ct) - sc.Close() }(count) return cc, nil } From ca96da60189be7363ed3570858befb1a9bab5892 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Thu, 18 May 2023 13:40:58 +0000 Subject: [PATCH 08/28] dns/dnsmessage: reject names with dots inside label Fixes golang/go#56246 Change-Id: I9c8d611d1305536a7510bf6c4a02a5e551aa657a GitHub-Last-Rev: 8a8703a1a7bde3457682c27db3e7d63856bd0cbc GitHub-Pull-Request: golang/net#154 Reviewed-on: https://go-review.googlesource.com/c/net/+/443215 TryBot-Result: Gopher Robot Reviewed-by: Roland Shoemaker Reviewed-by: Matthew Dempsky Run-TryBot: Mateusz Poliwczak Auto-Submit: Roland Shoemaker --- dns/dnsmessage/message.go | 10 ++++++++++ dns/dnsmessage/message_test.go | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/dns/dnsmessage/message.go b/dns/dnsmessage/message.go index 69c611bda..1577d4a19 100644 --- a/dns/dnsmessage/message.go +++ b/dns/dnsmessage/message.go @@ -260,6 +260,7 @@ var ( errReserved = errors.New("segment prefix is reserved") errTooManyPtr = errors.New("too many pointers (>10)") errInvalidPtr = errors.New("invalid pointer") + errInvalidName = errors.New("invalid dns name") errNilResouceBody = errors.New("nil resource body") errResourceLen = errors.New("insufficient data for resource body length") errSegTooLong = errors.New("segment length too long") @@ -2034,6 +2035,15 @@ Loop: if endOff > len(msg) { return off, errCalcLen } + + // Reject names containing dots. + // See issue golang/go#56246 + for _, v := range msg[currOff:endOff] { + if v == '.' { + return off, errInvalidName + } + } + name = append(name, msg[currOff:endOff]...) name = append(name, '.') currOff = endOff diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go index ef5326db8..ce2716e42 100644 --- a/dns/dnsmessage/message_test.go +++ b/dns/dnsmessage/message_test.go @@ -211,6 +211,15 @@ func TestName(t *testing.T) { } } +func TestNameWithDotsUnpack(t *testing.T) { + name := []byte{3, 'w', '.', 'w', 2, 'g', 'o', 3, 'd', 'e', 'v', 0} + var n Name + _, err := n.unpack(name, 0) + if err != errInvalidName { + t.Fatalf("expected %v, got %v", errInvalidName, err) + } +} + func TestNamePackUnpack(t *testing.T) { const suffix = ".go.dev." var longDNSPrefix = strings.Repeat("verylongdomainlabel.", 20) From 6826f5a7dbc43dc0bbeab569a7fc5a698f32e254 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 18 May 2023 14:39:45 -0700 Subject: [PATCH 09/28] http2: close request bodies before RoundTrip error return When returning an error from RoundTrip, wait for the close of the request body to complete before returning. This avoids a race between the HTTP/2 transport closing the request body and the net/http retry loop examining the readTrackingBody to see if it has been closed. For golang/go#60041 Change-Id: I8be69ff5056806406716e01e02d1f631deeca088 Reviewed-on: https://go-review.googlesource.com/c/net/+/496335 Run-TryBot: Damien Neil TryBot-Result: Gopher Robot Reviewed-by: Bryan Mills --- http2/transport.go | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/http2/transport.go b/http2/transport.go index ff86a765e..4f08ccba9 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -1268,8 +1268,8 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { cancelRequest := func(cs *clientStream, err error) error { cs.cc.mu.Lock() - defer cs.cc.mu.Unlock() cs.abortStreamLocked(err) + bodyClosed := cs.reqBodyClosed if cs.ID != 0 { // This request may have failed because of a problem with the connection, // or for some unrelated reason. (For example, the user might have canceled @@ -1284,6 +1284,23 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { // will not help. cs.cc.doNotReuse = true } + cs.cc.mu.Unlock() + // Wait for the request body to be closed. + // + // If nothing closed the body before now, abortStreamLocked + // will have started a goroutine to close it. + // + // Closing the body before returning avoids a race condition + // with net/http checking its readTrackingBody to see if the + // body was read from or closed. See golang/go#60041. + // + // The body is closed in a separate goroutine without the + // connection mutex held, but dropping the mutex before waiting + // will keep us from holding it indefinitely if the body + // close is slow for some reason. + if bodyClosed != nil { + <-bodyClosed + } return err } From ee6956ff9fcd901591e0e5a5ba6c8b94530f14b3 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 14 Oct 2022 09:12:44 -0700 Subject: [PATCH 10/28] quic: add internal/quic package This package will eventually contain an implementation of the QUIC protocol. Place it under internal/ to begin with to avoid accidental use while it is in an incomplete state. For golang/go#58547 Change-Id: Ib3526e0bbe433e91283859913818d3e72fc194b6 Reviewed-on: https://go-review.googlesource.com/c/net/+/468402 Reviewed-by: Matt Layher Reviewed-by: Roland Shoemaker Run-TryBot: Matt Layher TryBot-Result: Gopher Robot Run-TryBot: Damien Neil Reviewed-by: Cherry Mui --- internal/quic/doc.go | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 internal/quic/doc.go diff --git a/internal/quic/doc.go b/internal/quic/doc.go new file mode 100644 index 000000000..2fe17fe22 --- /dev/null +++ b/internal/quic/doc.go @@ -0,0 +1,9 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package quic is an experimental, incomplete implementation of the QUIC protocol. +// This package is a work in progress, and is not ready for use at this time. +// +// This package implements (or will implement) RFC 9000, RFC 9001, and RFC 9002. +package quic From 0d6f3cba5e5201e644e346793713dc3f0608bb2c Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 14 Oct 2022 09:42:42 -0700 Subject: [PATCH 11/28] quic: add various useful common constants and types For golang/go#58547 Change-Id: I178373329de20fe8e1b3d256638f0ae7ab366d03 Reviewed-on: https://go-review.googlesource.com/c/net/+/475435 Run-TryBot: Damien Neil Reviewed-by: Roland Shoemaker Reviewed-by: Cherry Mui TryBot-Result: Gopher Robot Reviewed-by: Matt Layher --- internal/quic/quic.go | 142 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 internal/quic/quic.go diff --git a/internal/quic/quic.go b/internal/quic/quic.go new file mode 100644 index 000000000..f7c1b765d --- /dev/null +++ b/internal/quic/quic.go @@ -0,0 +1,142 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "time" +) + +// connIDLen is the length in bytes of connection IDs chosen by this package. +// Since 1-RTT packets don't include a connection ID length field, +// we use a consistent length for all our IDs. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1-6 +const connIDLen = 8 + +// Local values of various transport parameters. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-18.2 +const ( + // The max_udp_payload_size transport parameter is the size of our + // network receive buffer. + // + // Set this to the largest UDP packet that can be sent over + // Ethernet without using jumbo frames: 1500 byte Ethernet frame, + // minus 20 byte IPv4 header and 8 byte UDP header. + // + // The maximum possible UDP payload is 65527 bytes. Supporting this + // without wasting memory in unused receive buffers will require some + // care. For now, just limit ourselves to the most common case. + maxUDPPayloadSize = 1472 + + ackDelayExponent = 3 // ack_delay_exponent + maxAckDelay = 25 * time.Millisecond // max_ack_delay +) + +// A side distinguishes between the client and server sides of a connection. +type side int8 + +const ( + clientSide = side(iota) + serverSide +) + +func (s side) String() string { + switch s { + case clientSide: + return "client" + case serverSide: + return "server" + default: + return "BUG" + } +} + +// A numberSpace is the context in which a packet number applies. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-12.3-7 +type numberSpace byte + +const ( + initialSpace = numberSpace(iota) + handshakeSpace + appDataSpace +) + +func (n numberSpace) String() string { + switch n { + case initialSpace: + return "Initial" + case handshakeSpace: + return "Handshake" + case appDataSpace: + return "AppData" + default: + return "BUG" + } +} + +// A streamType is the type of a stream: bidirectional or unidirectional. +type streamType uint8 + +const ( + bidiStream = streamType(iota) + uniStream +) + +func (s streamType) String() string { + switch s { + case bidiStream: + return "bidi" + case uniStream: + return "uni" + default: + return "BUG" + } +} + +// A streamID is a QUIC stream ID. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-2.1 +type streamID uint64 + +// The two least significant bits of a stream ID indicate the initiator +// and directionality of the stream. The upper bits are the stream number. +// Each of the four possible combinations of initiator and direction +// each has a distinct number space. +const ( + clientInitiatedStreamBit = 0x0 + serverInitiatedStreamBit = 0x1 + initiatorStreamBitMask = 0x1 + + bidiStreamBit = 0x0 + uniStreamBit = 0x2 + dirStreamBitMask = 0x2 +) + +func newStreamID(initiator side, typ streamType, num int64) streamID { + id := streamID(num << 2) + if typ == uniStream { + id |= uniStreamBit + } + if initiator == serverSide { + id |= serverInitiatedStreamBit + } + return id +} + +func (s streamID) initiator() side { + if s&initiatorStreamBitMask == serverInitiatedStreamBit { + return serverSide + } + return clientSide +} + +func (s streamID) num() int64 { + return int64(s) >> 2 +} + +func (s streamID) streamType() streamType { + if s&dirStreamBitMask == uniStreamBit { + return uniStream + } + return bidiStream +} From d4a2c13d06aa7ffbb05431ce72e6171090406c7e Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 13 Oct 2022 11:49:26 -0700 Subject: [PATCH 12/28] quic: packet number encoding/decoding QUIC packet numbers are integers in the range [0, 2^62). Packet numbers are encoded as the 1-4 least significant bytes of the full number, with the remaining bytes extrapolated based on the largest packet number seen by the receiver. RFC 9000, Section 17.1. For golang/go#58547 Change-Id: I9e111fe6c9c437fdcd9dc57336e094512c0b52b0 Reviewed-on: https://go-review.googlesource.com/c/net/+/475436 Reviewed-by: Matt Layher Reviewed-by: Cherry Mui Reviewed-by: Roland Shoemaker Run-TryBot: Damien Neil TryBot-Result: Gopher Robot Reviewed-by: Jonathan Amsterdam --- internal/quic/packet_number.go | 72 +++++++++++ internal/quic/packet_number_test.go | 183 ++++++++++++++++++++++++++++ 2 files changed, 255 insertions(+) create mode 100644 internal/quic/packet_number.go create mode 100644 internal/quic/packet_number_test.go diff --git a/internal/quic/packet_number.go b/internal/quic/packet_number.go new file mode 100644 index 000000000..9e9f0ad00 --- /dev/null +++ b/internal/quic/packet_number.go @@ -0,0 +1,72 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// A packetNumber is a QUIC packet number. +// Packet numbers are integers in the range [0, 2^62-1]. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-12.3 +type packetNumber int64 + +const maxPacketNumber = 1<<62 - 1 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-1 + +// decodePacketNumber decodes a truncated packet number, given +// the largest acknowledged packet number in this number space, +// the truncated number received in a packet, and the size of the +// number received in bytes. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1 +// https://www.rfc-editor.org/rfc/rfc9000.html#section-a.3 +func decodePacketNumber(largest, truncated packetNumber, numLenInBytes int) packetNumber { + expected := largest + 1 + win := packetNumber(1) << (uint(numLenInBytes) * 8) + hwin := win / 2 + mask := win - 1 + candidate := (expected &^ mask) | truncated + if candidate <= expected-hwin && candidate < (1<<62)-win { + return candidate + win + } + if candidate > expected+hwin && candidate >= win { + return candidate - win + } + return candidate +} + +// appendPacketNumber appends an encoded packet number to b. +// The packet number must be larger than the largest acknowledged packet number. +// When no packets have been acknowledged yet, largestAck is -1. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-5 +func appendPacketNumber(b []byte, pnum, largestAck packetNumber) []byte { + switch packetNumberLength(pnum, largestAck) { + case 1: + return append(b, byte(pnum)) + case 2: + return append(b, byte(pnum>>8), byte(pnum)) + case 3: + return append(b, byte(pnum>>16), byte(pnum>>8), byte(pnum)) + default: + return append(b, byte(pnum>>24), byte(pnum>>16), byte(pnum>>8), byte(pnum)) + } +} + +// packetNumberLength returns the minimum length, in bytes, needed to encode +// a packet number given the largest acknowledged packet number. +// The packet number must be larger than the largest acknowledged packet number. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.1-5 +func packetNumberLength(pnum, largestAck packetNumber) int { + d := pnum - largestAck + switch { + case d < 0x80: + return 1 + case d < 0x8000: + return 2 + case d < 0x800000: + return 3 + default: + return 4 + } +} diff --git a/internal/quic/packet_number_test.go b/internal/quic/packet_number_test.go new file mode 100644 index 000000000..7450e3988 --- /dev/null +++ b/internal/quic/packet_number_test.go @@ -0,0 +1,183 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "bytes" + "encoding/binary" + "testing" +) + +func TestDecodePacketNumber(t *testing.T) { + for _, test := range []struct { + largest packetNumber + truncated packetNumber + want packetNumber + size int + }{{ + largest: 0, + truncated: 1, + size: 4, + want: 1, + }, { + largest: 0, + truncated: 0, + size: 1, + want: 0, + }, { + largest: 0x00, + truncated: 0x01, + size: 1, + want: 0x01, + }, { + largest: 0x00, + truncated: 0xff, + size: 1, + want: 0xff, + }, { + largest: 0xff, + truncated: 0x01, + size: 1, + want: 0x101, + }, { + largest: 0x1000, + truncated: 0xff, + size: 1, + want: 0xfff, + }, { + largest: 0xa82f30ea, + truncated: 0x9b32, + size: 2, + want: 0xa82f9b32, + }} { + got := decodePacketNumber(test.largest, test.truncated, test.size) + if got != test.want { + t.Errorf("decodePacketNumber(largest=0x%x, truncated=0x%x, size=%v) = 0x%x, want 0x%x", test.largest, test.truncated, test.size, got, test.want) + } + } +} + +func TestEncodePacketNumber(t *testing.T) { + for _, test := range []struct { + largestAcked packetNumber + pnum packetNumber + wantSize int + }{{ + largestAcked: -1, + pnum: 0, + wantSize: 1, + }, { + largestAcked: 1000, + pnum: 1000 + 0x7f, + wantSize: 1, + }, { + largestAcked: 1000, + pnum: 1000 + 0x80, // 0x468 + wantSize: 2, + }, { + largestAcked: 0x12345678, + pnum: 0x12345678 + 0x7fff, // 0x305452663 + wantSize: 2, + }, { + largestAcked: 0x12345678, + pnum: 0x12345678 + 0x8000, + wantSize: 3, + }, { + largestAcked: 0, + pnum: 0x7fffff, + wantSize: 3, + }, { + largestAcked: 0, + pnum: 0x800000, + wantSize: 4, + }, { + largestAcked: 0xabe8bc, + pnum: 0xac5c02, + wantSize: 2, + }, { + largestAcked: 0xabe8bc, + pnum: 0xace8fe, + wantSize: 3, + }} { + size := packetNumberLength(test.pnum, test.largestAcked) + if got, want := size, test.wantSize; got != want { + t.Errorf("packetNumberLength(num=%x, maxAck=%x) = %v, want %v", test.pnum, test.largestAcked, got, want) + } + var enc packetNumber + switch size { + case 1: + enc = test.pnum & 0xff + case 2: + enc = test.pnum & 0xffff + case 3: + enc = test.pnum & 0xffffff + case 4: + enc = test.pnum & 0xffffffff + } + wantBytes := binary.BigEndian.AppendUint32(nil, uint32(enc))[4-size:] + gotBytes := appendPacketNumber(nil, test.pnum, test.largestAcked) + if !bytes.Equal(gotBytes, wantBytes) { + t.Errorf("appendPacketNumber(num=%v, maxAck=%x) = {%x}, want {%x}", test.pnum, test.largestAcked, gotBytes, wantBytes) + } + gotNum := decodePacketNumber(test.largestAcked, enc, size) + if got, want := gotNum, test.pnum; got != want { + t.Errorf("packetNumberLength(num=%x, maxAck=%x) = %v, but decoded number=%x", test.pnum, test.largestAcked, size, got) + } + } +} + +func FuzzPacketNumber(f *testing.F) { + truncatedNumber := func(in []byte) packetNumber { + var truncated packetNumber + for _, b := range in { + truncated = (truncated << 8) | packetNumber(b) + } + return truncated + } + f.Fuzz(func(t *testing.T, in []byte, largestAckedInt64 int64) { + largestAcked := packetNumber(largestAckedInt64) + if len(in) < 1 || len(in) > 4 || largestAcked < 0 || largestAcked > maxPacketNumber { + return + } + truncatedIn := truncatedNumber(in) + decoded := decodePacketNumber(largestAcked, truncatedIn, len(in)) + + // Check that the decoded packet number's least significant bits match the input. + var mask packetNumber + for i := 0; i < len(in); i++ { + mask = (mask << 8) | 0xff + } + if truncatedIn != decoded&mask { + t.Fatalf("decoding mismatch: input=%x largestAcked=%v decoded=0x%x", in, largestAcked, decoded) + } + + // We don't support encoding packet numbers less than largestAcked (since packet numbers + // never decrease), so skip the encoder tests if this would make us go backwards. + if decoded < largestAcked { + return + } + + // We might encode this number using a different length than we received, + // but the common portions should match. + encoded := appendPacketNumber(nil, decoded, largestAcked) + a, b := in, encoded + if len(b) < len(a) { + a, b = b, a + } + for len(a) < len(b) { + b = b[1:] + } + if len(a) == 0 || !bytes.Equal(a, b) { + t.Fatalf("encoding mismatch: input=%x largestAcked=%v decoded=%v reencoded=%x", in, largestAcked, decoded, encoded) + } + + if g := decodePacketNumber(largestAcked, truncatedNumber(encoded), len(encoded)); g != decoded { + t.Fatalf("packet encode/decode mismatch: pnum=%v largestAcked=%v encoded=%x got=%v", decoded, largestAcked, encoded, g) + } + if l := packetNumberLength(decoded, largestAcked); l != len(encoded) { + t.Fatalf("packet number length mismatch: pnum=%v largestAcked=%v encoded=%x len=%v", decoded, largestAcked, encoded, l) + } + }) +} From 6488c8f45768fbdf300fe3faabd75833c9bab21a Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 13 Oct 2022 12:39:26 -0700 Subject: [PATCH 13/28] quic: basic packet operations The type of a QUIC packet can be identified by inspecting its first byte, and the destination connection ID can be determined without decrypting and parsing the entire packet. For golang/go#58547 Change-Id: Ie298c0f6c0017343168a0974543e37ab7a569b0f Reviewed-on: https://go-review.googlesource.com/c/net/+/475437 Run-TryBot: Damien Neil TryBot-Result: Gopher Robot Reviewed-by: Matt Layher Reviewed-by: Jonathan Amsterdam --- internal/quic/packet.go | 159 +++++++++++++++++++++++++++++++++++ internal/quic/packet_test.go | 125 +++++++++++++++++++++++++++ 2 files changed, 284 insertions(+) create mode 100644 internal/quic/packet.go create mode 100644 internal/quic/packet_test.go diff --git a/internal/quic/packet.go b/internal/quic/packet.go new file mode 100644 index 000000000..4645ae709 --- /dev/null +++ b/internal/quic/packet.go @@ -0,0 +1,159 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// packetType is a QUIC packet type. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17 +type packetType byte + +const ( + packetTypeInvalid = packetType(iota) + packetTypeInitial + packetType0RTT + packetTypeHandshake + packetTypeRetry + packetType1RTT + packetTypeVersionNegotiation +) + +// Bits set in the first byte of a packet. +const ( + headerFormLong = 0x80 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.2.1 + headerFormShort = 0x00 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.3.1-4.2.1 + fixedBit = 0x40 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.4.1 + reservedBits = 0x0c // https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1 +) + +// Long Packet Type bits. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.6.1 +const ( + longPacketTypeInitial = 0 << 4 + longPacketType0RTT = 1 << 4 + longPacketTypeHandshake = 2 << 4 + longPacketTypeRetry = 3 << 4 +) + +// Frame types. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-19 +const ( + frameTypePadding = 0x00 + frameTypePing = 0x01 + frameTypeAck = 0x02 + frameTypeAckECN = 0x03 + frameTypeResetStream = 0x04 + frameTypeStopSending = 0x05 + frameTypeCrypto = 0x06 + frameTypeNewToken = 0x07 + frameTypeStreamBase = 0x08 // low three bits carry stream flags + frameTypeMaxData = 0x10 + frameTypeMaxStreamData = 0x11 + frameTypeMaxStreamsBidi = 0x12 + frameTypeMaxStreamsUni = 0x13 + frameTypeDataBlocked = 0x14 + frameTypeStreamDataBlocked = 0x15 + frameTypeStreamsBlockedBidi = 0x16 + frameTypeStreamsBlockedUni = 0x17 + frameTypeNewConnectionID = 0x18 + frameTypeRetireConnectionID = 0x19 + frameTypePathChallenge = 0x1a + frameTypePathResponse = 0x1b + frameTypeConnectionCloseTransport = 0x1c + frameTypeConnectionCloseApplication = 0x1d + frameTypeHandshakeDone = 0x1e +) + +// The low three bits of STREAM frames. +// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.8 +const ( + streamOffBit = 0x04 + streamLenBit = 0x02 + streamFinBit = 0x01 +) + +// isLongHeader returns true if b is the first byte of a long header. +func isLongHeader(b byte) bool { + return b&headerFormLong == headerFormLong +} + +// getPacketType returns the type of a packet. +func getPacketType(b []byte) packetType { + if len(b) == 0 { + return packetTypeInvalid + } + if !isLongHeader(b[0]) { + if b[0]&fixedBit != fixedBit { + return packetTypeInvalid + } + return packetType1RTT + } + if len(b) < 5 { + return packetTypeInvalid + } + if b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 { + // Version Negotiation packets don't necessarily set the fixed bit. + return packetTypeVersionNegotiation + } + if b[0]&fixedBit != fixedBit { + return packetTypeInvalid + } + switch b[0] & 0x30 { + case longPacketTypeInitial: + return packetTypeInitial + case longPacketType0RTT: + return packetType0RTT + case longPacketTypeHandshake: + return packetTypeHandshake + case longPacketTypeRetry: + return packetTypeRetry + } + return packetTypeInvalid +} + +// dstConnIDForDatagram returns the destination connection ID field of the +// first QUIC packet in a datagram. +func dstConnIDForDatagram(pkt []byte) (id []byte, ok bool) { + if len(pkt) < 1 { + return nil, false + } + var n int + var b []byte + if isLongHeader(pkt[0]) { + if len(pkt) < 6 { + return nil, false + } + n = int(pkt[5]) + b = pkt[6:] + } else { + n = connIDLen + b = pkt[1:] + } + if len(b) < n { + return nil, false + } + return b[:n], true +} + +// A longPacket is a long header packet. +type longPacket struct { + ptype packetType + reservedBits uint8 + version uint32 + num packetNumber + dstConnID []byte + srcConnID []byte + payload []byte + + // The extra data depends on the packet type: + // Initial: Token. + // Retry: Retry token and integrity tag. + extra []byte +} + +// A shortPacket is a short header (1-RTT) packet. +type shortPacket struct { + reservedBits uint8 + num packetNumber + payload []byte +} diff --git a/internal/quic/packet_test.go b/internal/quic/packet_test.go new file mode 100644 index 000000000..3011dda1d --- /dev/null +++ b/internal/quic/packet_test.go @@ -0,0 +1,125 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "bytes" + "encoding/hex" + "strings" + "testing" +) + +func TestPacketHeader(t *testing.T) { + for _, test := range []struct { + name string + packet []byte + isLongHeader bool + packetType packetType + dstConnID []byte + }{{ + // Initial packet from https://www.rfc-editor.org/rfc/rfc9001#section-a.1 + // (truncated) + name: "rfc9001_a1", + packet: unhex(` + c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11 + `), + isLongHeader: true, + packetType: packetTypeInitial, + dstConnID: unhex(`8394c8f03e515708`), + }, { + // Initial packet from https://www.rfc-editor.org/rfc/rfc9001#section-a.3 + // (truncated) + name: "rfc9001_a3", + packet: unhex(` + cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a + `), + isLongHeader: true, + packetType: packetTypeInitial, + dstConnID: []byte{}, + }, { + // Retry packet from https://www.rfc-editor.org/rfc/rfc9001#section-a.4 + name: "rfc9001_a4", + packet: unhex(` + ff000000010008f067a5502a4262b574 6f6b656e04a265ba2eff4d829058fb3f + 0f2496ba + `), + isLongHeader: true, + packetType: packetTypeRetry, + dstConnID: []byte{}, + }, { + // Short header packet from https://www.rfc-editor.org/rfc/rfc9001#section-a.5 + name: "rfc9001_a5", + packet: unhex(` + 4cfe4189655e5cd55c41f69080575d7999c25a5bfb + `), + isLongHeader: false, + packetType: packetType1RTT, + dstConnID: unhex(`fe4189655e5cd55c`), + }, { + // Version Negotiation packet. + name: "version_negotiation", + packet: unhex(` + 80 00000000 01ff0001020304 + `), + isLongHeader: true, + packetType: packetTypeVersionNegotiation, + dstConnID: []byte{0xff}, + }, { + // Too-short packet. + name: "truncated_after_connid_length", + packet: unhex(` + cf0000000105 + `), + isLongHeader: true, + packetType: packetTypeInitial, + dstConnID: nil, + }, { + // Too-short packet. + name: "truncated_after_version", + packet: unhex(` + cf00000001 + `), + isLongHeader: true, + packetType: packetTypeInitial, + dstConnID: nil, + }, { + // Much too short packet. + name: "truncated_in_version", + packet: unhex(` + cf000000 + `), + isLongHeader: true, + packetType: packetTypeInvalid, + dstConnID: nil, + }} { + t.Run(test.name, func(t *testing.T) { + if got, want := isLongHeader(test.packet[0]), test.isLongHeader; got != want { + t.Errorf("packet %x:\nisLongHeader(packet) = %v, want %v", test.packet, got, want) + } + if got, want := getPacketType(test.packet), test.packetType; got != want { + t.Errorf("packet %x:\ngetPacketType(packet) = %v, want %v", test.packet, got, want) + } + gotConnID, gotOK := dstConnIDForDatagram(test.packet) + wantConnID, wantOK := test.dstConnID, test.dstConnID != nil + if !bytes.Equal(gotConnID, wantConnID) || gotOK != wantOK { + t.Errorf("packet %x:\ndstConnIDForDatagram(packet) = {%x}, %v; want {%x}, %v", test.packet, gotConnID, gotOK, wantConnID, wantOK) + } + }) + } +} + +func unhex(s string) []byte { + b, err := hex.DecodeString(strings.Map(func(c rune) rune { + switch c { + case ' ', '\t', '\n': + return -1 + } + return c + }, s)) + if err != nil { + panic(err) + } + return b +} From f71a821cfa7cd0a98b6bc3c5522aa06e39aba5b7 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 13 Oct 2022 12:03:05 -0700 Subject: [PATCH 14/28] quic: packet protection Encrypt and decrypt QUIC packets according to RFC 9001. For golang/go#58547 Change-Id: Ib7f824cf08f8520400bd38d3b3ab89e8a968114e Reviewed-on: https://go-review.googlesource.com/c/net/+/475438 Reviewed-by: Roland Shoemaker Run-TryBot: Damien Neil TryBot-Result: Gopher Robot Reviewed-by: Jonathan Amsterdam --- go.mod | 1 + go.sum | 3 + internal/quic/packet_protection.go | 266 ++++++++++++++++++++++++ internal/quic/packet_protection_test.go | 162 +++++++++++++++ 4 files changed, 432 insertions(+) create mode 100644 internal/quic/packet_protection.go create mode 100644 internal/quic/packet_protection_test.go diff --git a/go.mod b/go.mod index f661af405..b1d3e5474 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module golang.org/x/net go 1.17 require ( + golang.org/x/crypto v0.9.0 golang.org/x/sys v0.8.0 golang.org/x/term v0.8.0 golang.org/x/text v0.9.0 diff --git a/go.sum b/go.sum index 6408b66ea..af21d7cac 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,15 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/internal/quic/packet_protection.go b/internal/quic/packet_protection.go new file mode 100644 index 000000000..7d96d69cd --- /dev/null +++ b/internal/quic/packet_protection.go @@ -0,0 +1,266 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "crypto/tls" + "errors" + "fmt" + "hash" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/hkdf" +) + +var errInvalidPacket = errors.New("quic: invalid packet") + +// keys holds the cryptographic material used to protect packets +// at an encryption level and direction. (e.g., Initial client keys.) +// +// keys are not safe for concurrent use. +type keys struct { + // AEAD function used for packet protection. + aead cipher.AEAD + + // The header_protection function as defined in: + // https://www.rfc-editor.org/rfc/rfc9001#section-5.4.1 + // + // This function takes a sample of the packet ciphertext + // and returns a 5-byte mask which will be applied to the + // protected portions of the packet header. + headerProtection func(sample []byte) (mask [5]byte) + + // IV used to construct the AEAD nonce. + iv []byte +} + +// newKeys creates keys for a given cipher suite and secret. +// +// It returns an error if the suite is unknown. +func newKeys(suite uint16, secret []byte) (keys, error) { + switch suite { + case tls.TLS_AES_128_GCM_SHA256: + return newAESKeys(secret, crypto.SHA256, 128/8), nil + case tls.TLS_AES_256_GCM_SHA384: + return newAESKeys(secret, crypto.SHA384, 256/8), nil + case tls.TLS_CHACHA20_POLY1305_SHA256: + return newChaCha20Keys(secret), nil + } + return keys{}, fmt.Errorf("unknown cipher suite %x", suite) +} + +func newAESKeys(secret []byte, h crypto.Hash, keyBytes int) keys { + // https://www.rfc-editor.org/rfc/rfc9001#section-5.1 + key := hkdfExpandLabel(h.New, secret, "quic key", nil, keyBytes) + c, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(c) + if err != nil { + panic(err) + } + iv := hkdfExpandLabel(h.New, secret, "quic iv", nil, aead.NonceSize()) + // https://www.rfc-editor.org/rfc/rfc9001#section-5.4.3 + hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keyBytes) + hp, err := aes.NewCipher(hpKey) + if err != nil { + panic(err) + } + var scratch [aes.BlockSize]byte + headerProtection := func(sample []byte) (mask [5]byte) { + hp.Encrypt(scratch[:], sample) + copy(mask[:], scratch[:]) + return mask + } + return keys{ + aead: aead, + iv: iv, + headerProtection: headerProtection, + } +} + +func newChaCha20Keys(secret []byte) keys { + // https://www.rfc-editor.org/rfc/rfc9001#section-5.1 + key := hkdfExpandLabel(sha256.New, secret, "quic key", nil, chacha20poly1305.KeySize) + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + iv := hkdfExpandLabel(sha256.New, secret, "quic iv", nil, aead.NonceSize()) + // https://www.rfc-editor.org/rfc/rfc9001#section-5.4.4 + hpKey := hkdfExpandLabel(sha256.New, secret, "quic hp", nil, chacha20.KeySize) + headerProtection := func(sample []byte) [5]byte { + counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0]) + nonce := sample[4:16] + c, err := chacha20.NewUnauthenticatedCipher(hpKey, nonce) + if err != nil { + panic(err) + } + c.SetCounter(counter) + var mask [5]byte + c.XORKeyStream(mask[:], mask[:]) + return mask + } + return keys{ + aead: aead, + iv: iv, + headerProtection: headerProtection, + } +} + +// https://www.rfc-editor.org/rfc/rfc9001#section-5.2-2 +var initialSalt = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} + +// initialKeys returns the keys used to protect Initial packets. +// +// The Initial packet keys are derived from the Destination Connection ID +// field in the client's first Initial packet. +// +// https://www.rfc-editor.org/rfc/rfc9001#section-5.2 +func initialKeys(cid []byte) (clientKeys, serverKeys keys) { + initialSecret := hkdf.Extract(sha256.New, cid, initialSalt) + clientInitialSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size) + clientKeys, err := newKeys(tls.TLS_AES_128_GCM_SHA256, clientInitialSecret) + if err != nil { + panic(err) + } + + serverInitialSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size) + serverKeys, err = newKeys(tls.TLS_AES_128_GCM_SHA256, serverInitialSecret) + if err != nil { + panic(err) + } + + return clientKeys, serverKeys +} + +const headerProtectionSampleSize = 16 + +// aeadOverhead is the difference in size between the AEAD output and input. +// All cipher suites defined for use with QUIC have 16 bytes of overhead. +const aeadOverhead = 16 + +// xorIV xors the packet protection IV with the packet number. +func (k keys) xorIV(pnum packetNumber) { + k.iv[len(k.iv)-8] ^= uint8(pnum >> 56) + k.iv[len(k.iv)-7] ^= uint8(pnum >> 48) + k.iv[len(k.iv)-6] ^= uint8(pnum >> 40) + k.iv[len(k.iv)-5] ^= uint8(pnum >> 32) + k.iv[len(k.iv)-4] ^= uint8(pnum >> 24) + k.iv[len(k.iv)-3] ^= uint8(pnum >> 16) + k.iv[len(k.iv)-2] ^= uint8(pnum >> 8) + k.iv[len(k.iv)-1] ^= uint8(pnum) +} + +// initialized returns true if valid keys are available. +func (k keys) initialized() bool { + return k.aead != nil +} + +// discard discards the keys (in the sense that we won't use them any more, +// not that the keys are securely erased). +// +// https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9 +func (k *keys) discard() { + *k = keys{} +} + +// protect applies packet protection to a packet. +// +// On input, hdr contains the packet header, pay the unencrypted payload, +// pnumOff the offset of the packet number in the header, and pnum the untruncated +// packet number. +// +// protect returns the result of appending the encrypted payload to hdr and +// applying header protection. +func (k keys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte { + k.xorIV(pnum) + hdr = k.aead.Seal(hdr, k.iv, pay, hdr) + k.xorIV(pnum) + + // Apply header protection. + pnumSize := int(hdr[0]&0x03) + 1 + sample := hdr[pnumOff+4:][:headerProtectionSampleSize] + mask := k.headerProtection(sample) + if isLongHeader(hdr[0]) { + hdr[0] ^= mask[0] & 0x0f + } else { + hdr[0] ^= mask[0] & 0x1f + } + for i := 0; i < pnumSize; i++ { + hdr[pnumOff+i] ^= mask[1+i] + } + + return hdr +} + +// unprotect removes packet protection from a packet. +// +// On input, pkt contains the full protected packet, pnumOff the offset of +// the packet number in the header, and pnumMax the largest packet number +// seen in the number space of this packet. +// +// unprotect removes header protection from the header in pkt, and returns +// the unprotected payload and packet number. +func (k keys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) { + if len(pkt) < pnumOff+4+headerProtectionSampleSize { + fmt.Println("too short") + return nil, 0, errInvalidPacket + } + numpay := pkt[pnumOff:] + sample := numpay[4:][:headerProtectionSampleSize] + mask := k.headerProtection(sample) + if isLongHeader(pkt[0]) { + pkt[0] ^= mask[0] & 0x0f + } else { + pkt[0] ^= mask[0] & 0x1f + } + pnumLen := int(pkt[0]&0x03) + 1 + pnum := packetNumber(0) + for i := 0; i < pnumLen; i++ { + numpay[i] ^= mask[1+i] + pnum = (pnum << 8) | packetNumber(numpay[i]) + } + pnum = decodePacketNumber(pnumMax, pnum, pnumLen) + + hdr := pkt[:pnumOff+pnumLen] + pay = numpay[pnumLen:] + k.xorIV(pnum) + pay, err = k.aead.Open(pay[:0], k.iv, pay, hdr) + k.xorIV(pnum) + if err != nil { + return nil, 0, err + } + + return pay, pnum, nil +} + +// hdkfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +// +// Copied from crypto/tls/key_schedule.go. +func hkdfExpandLabel(hash func() hash.Hash, secret []byte, label string, context []byte, length int) []byte { + var hkdfLabel cryptobyte.Builder + hkdfLabel.AddUint16(uint16(length)) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte("tls13 ")) + b.AddBytes([]byte(label)) + }) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(context) + }) + out := make([]byte, length) + n, err := hkdf.Expand(hash, secret, hkdfLabel.BytesOrPanic()).Read(out) + if err != nil || n != length { + panic("quic: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} diff --git a/internal/quic/packet_protection_test.go b/internal/quic/packet_protection_test.go new file mode 100644 index 000000000..f1d353d8e --- /dev/null +++ b/internal/quic/packet_protection_test.go @@ -0,0 +1,162 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "bytes" + "crypto/tls" + "testing" +) + +func TestPacketProtection(t *testing.T) { + // Test cases from: + // https://www.rfc-editor.org/rfc/rfc9001#section-appendix.a + cid := unhex(`8394c8f03e515708`) + initialClientKeys, initialServerKeys := initialKeys(cid) + for _, test := range []struct { + name string + k keys + pnum packetNumber + hdr []byte + pay []byte + prot []byte + }{{ + name: "Client Initial", + k: initialClientKeys, + pnum: 2, + hdr: unhex(` + c300000001088394c8f03e5157080000 449e00000002 + `), + pay: pad(1162, unhex(` + 060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 + 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 + 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 + 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba + baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 + 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 + 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 + 75300901100f088394c8f03e51570806 048000ffff + `)), + prot: unhex(` + c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11 + d242b123dc9bd8bab936b47d92ec356c 0bab7df5976d27cd449f63300099f399 + 1c260ec4c60d17b31f8429157bb35a12 82a643a8d2262cad67500cadb8e7378c + 8eb7539ec4d4905fed1bee1fc8aafba1 7c750e2c7ace01e6005f80fcb7df6212 + 30c83711b39343fa028cea7f7fb5ff89 eac2308249a02252155e2347b63d58c5 + 457afd84d05dfffdb20392844ae81215 4682e9cf012f9021a6f0be17ddd0c208 + 4dce25ff9b06cde535d0f920a2db1bf3 62c23e596d11a4f5a6cf3948838a3aec + 4e15daf8500a6ef69ec4e3feb6b1d98e 610ac8b7ec3faf6ad760b7bad1db4ba3 + 485e8a94dc250ae3fdb41ed15fb6a8e5 eba0fc3dd60bc8e30c5c4287e53805db + 059ae0648db2f64264ed5e39be2e20d8 2df566da8dd5998ccabdae053060ae6c + 7b4378e846d29f37ed7b4ea9ec5d82e7 961b7f25a9323851f681d582363aa5f8 + 9937f5a67258bf63ad6f1a0b1d96dbd4 faddfcefc5266ba6611722395c906556 + be52afe3f565636ad1b17d508b73d874 3eeb524be22b3dcbc2c7468d54119c74 + 68449a13d8e3b95811a198f3491de3e7 fe942b330407abf82a4ed7c1b311663a + c69890f4157015853d91e923037c227a 33cdd5ec281ca3f79c44546b9d90ca00 + f064c99e3dd97911d39fe9c5d0b23a22 9a234cb36186c4819e8b9c5927726632 + 291d6a418211cc2962e20fe47feb3edf 330f2c603a9d48c0fcb5699dbfe58964 + 25c5bac4aee82e57a85aaf4e2513e4f0 5796b07ba2ee47d80506f8d2c25e50fd + 14de71e6c418559302f939b0e1abd576 f279c4b2e0feb85c1f28ff18f58891ff + ef132eef2fa09346aee33c28eb130ff2 8f5b766953334113211996d20011a198 + e3fc433f9f2541010ae17c1bf202580f 6047472fb36857fe843b19f5984009dd + c324044e847a4f4a0ab34f719595de37 252d6235365e9b84392b061085349d73 + 203a4a13e96f5432ec0fd4a1ee65accd d5e3904df54c1da510b0ff20dcc0c77f + cb2c0e0eb605cb0504db87632cf3d8b4 dae6e705769d1de354270123cb11450e + fc60ac47683d7b8d0f811365565fd98c 4c8eb936bcab8d069fc33bd801b03ade + a2e1fbc5aa463d08ca19896d2bf59a07 1b851e6c239052172f296bfb5e724047 + 90a2181014f3b94a4e97d117b4381303 68cc39dbb2d198065ae3986547926cd2 + 162f40a29f0c3c8745c0f50fba3852e5 66d44575c29d39a03f0cda721984b6f4 + 40591f355e12d439ff150aab7613499d bd49adabc8676eef023b15b65bfc5ca0 + 6948109f23f350db82123535eb8a7433 bdabcb909271a6ecbcb58b936a88cd4e + 8f2e6ff5800175f113253d8fa9ca8885 c2f552e657dc603f252e1a8e308f76f0 + be79e2fb8f5d5fbbe2e30ecadd220723 c8c0aea8078cdfcb3868263ff8f09400 + 54da48781893a7e49ad5aff4af300cd8 04a6b6279ab3ff3afb64491c85194aab + 760d58a606654f9f4400e8b38591356f bf6425aca26dc85244259ff2b19c41b9 + f96f3ca9ec1dde434da7d2d392b905dd f3d1f9af93d1af5950bd493f5aa731b4 + 056df31bd267b6b90a079831aaf579be 0a39013137aac6d404f518cfd4684064 + 7e78bfe706ca4cf5e9c5453e9f7cfd2b 8b4c8d169a44e55c88d4a9a7f9474241 + e221af44860018ab0856972e194cd934 + `), + }, { + name: "Server Initial", + k: initialServerKeys, + pnum: 1, + hdr: unhex(` + c1000000010008f067a5502a4262b500 40750001 + `), + pay: unhex(` + 02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 + 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 + 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 + 020304 + `), + prot: unhex(` + cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a + 5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3 + dbcba3f6ea46c5b7684df3548e7ddeb9 c3bf9c73cc3f3bded74b562bfb19fb84 + 022f8ef4cdd93795d77d06edbb7aaf2f 58891850abbdca3d20398c276456cbc4 + 2158407dd074ee + `), + }, { + name: "ChaCha20_Poly1305 Short Header", + k: func() keys { + secret := unhex(` + 9ac312a7f877468ebe69422748ad00a1 + 5443f18203a07d6060f688f30f21632b + `) + k, err := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, secret) + if err != nil { + t.Fatal(err) + } + return k + }(), + pnum: 654360564, + hdr: unhex(`4200bff4`), + pay: unhex(`01`), + prot: unhex(` + 4cfe4189655e5cd55c41f69080575d79 99c25a5bfb + `), + }} { + test := test + t.Run(test.name, func(t *testing.T) { + pnumLen := int(test.hdr[0]&0x03) + 1 + pnumOff := len(test.hdr) - pnumLen + + b := append([]byte{}, test.hdr...) + gotProt := test.k.protect(b, test.pay, pnumOff, test.pnum) + if got, want := gotProt, test.prot; !bytes.Equal(got, want) { + t.Errorf("Protected payload does not match:") + t.Errorf("got: %x", got) + t.Errorf("want: %x", want) + } + + pkt := append([]byte{}, test.prot...) + gotPay, gotNum, err := test.k.unprotect(pkt, pnumOff, test.pnum-1) + if err != nil { + t.Fatalf("Unexpected error unprotecting packet: %v", err) + } + if got, want := pkt[:len(test.hdr)], test.hdr; !bytes.Equal(got, want) { + t.Errorf("Unprotected header does not match:") + t.Errorf("got: %x", got) + t.Errorf("want: %x", want) + } + if got, want := gotPay, test.pay; !bytes.Equal(got, want) { + t.Errorf("Unprotected payload does not match:") + t.Errorf("got: %x", got) + t.Errorf("want: %x", want) + } + if got, want := gotNum, test.pnum; got != want { + t.Errorf("Unprotected packet number does not match: got %v, want %v", got, want) + } + }) + } +} + +func pad(n int, b []byte) []byte { + for len(b) < n { + b = append(b, 0) + } + return b +} From 10d90690bcb215f6270c032bbbba31c38be80bf2 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 13 Oct 2022 13:02:55 -0700 Subject: [PATCH 15/28] quic: add rangeset type A rangeset is an ordered list of non-overlapping int64 ranges. This type will be used for tracking which packet numbers need to be acknowledged and which parts of a stream have been sent/received. For golang/go#58547 Change-Id: Ia4ab3a47e82d0e7aea738a0f857b2129d4ea1f63 Reviewed-on: https://go-review.googlesource.com/c/net/+/478295 TryBot-Result: Gopher Robot Run-TryBot: Damien Neil Reviewed-by: Jonathan Amsterdam --- internal/quic/rangeset.go | 183 ++++++++++++++++++++ internal/quic/rangeset_test.go | 295 +++++++++++++++++++++++++++++++++ 2 files changed, 478 insertions(+) create mode 100644 internal/quic/rangeset.go create mode 100644 internal/quic/rangeset_test.go diff --git a/internal/quic/rangeset.go b/internal/quic/rangeset.go new file mode 100644 index 000000000..9d6b63a74 --- /dev/null +++ b/internal/quic/rangeset.go @@ -0,0 +1,183 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// A rangeset is a set of int64s, stored as an ordered list of non-overlapping, +// non-empty ranges. +// +// Rangesets are efficient for small numbers of ranges, +// which is expected to be the common case. +// +// Once we're willing to drop support for pre-generics versions of Go, this can +// be made into a parameterized type to permit use with packetNumber without casts. +type rangeset []i64range + +type i64range struct { + start, end int64 // [start, end) +} + +// size returns the size of the range. +func (r i64range) size() int64 { + return r.end - r.start +} + +// contains reports whether v is in the range. +func (r i64range) contains(v int64) bool { + return r.start <= v && v < r.end +} + +// add adds [start, end) to the set, combining it with existing ranges if necessary. +func (s *rangeset) add(start, end int64) { + if start == end { + return + } + for i := range *s { + r := &(*s)[i] + if r.start > end { + // The new range comes before range i. + s.insertrange(i, start, end) + return + } + if start > r.end { + // The new range comes after range i. + continue + } + // The new range is adjacent to or overlapping range i. + if start < r.start { + r.start = start + } + if end <= r.end { + return + } + // Possibly coalesce subsquent ranges into range i. + r.end = end + j := i + 1 + for ; j < len(*s) && r.end >= (*s)[j].start; j++ { + if e := (*s)[j].end; e > r.end { + // Range j ends after the new range. + r.end = e + } + } + s.removeranges(i+1, j) + return + } + *s = append(*s, i64range{start, end}) +} + +// sub removes [start, end) from the set. +func (s *rangeset) sub(start, end int64) { + removefrom, removeto := -1, -1 + for i := range *s { + r := &(*s)[i] + if end < r.start { + break + } + if r.end < start { + continue + } + switch { + case start <= r.start && end >= r.end: + // Remove the entire range. + if removefrom == -1 { + removefrom = i + } + removeto = i + 1 + case start <= r.start: + // Remove a prefix. + r.start = end + case end >= r.end: + // Remove a suffix. + r.end = start + default: + // Remove the middle, leaving two new ranges. + rend := r.end + r.end = start + s.insertrange(i+1, end, rend) + return + } + } + if removefrom != -1 { + s.removeranges(removefrom, removeto) + } +} + +// contains reports whether s contains v. +func (s rangeset) contains(v int64) bool { + for _, r := range s { + if v >= r.end { + continue + } + if r.start <= v { + return true + } + return false + } + return false +} + +// rangeContaining returns the range containing v, or the range [0,0) if v is not in s. +func (s rangeset) rangeContaining(v int64) i64range { + for _, r := range s { + if v >= r.end { + continue + } + if r.start <= v { + return r + } + break + } + return i64range{0, 0} +} + +// min returns the minimum value in the set, or 0 if empty. +func (s rangeset) min() int64 { + if len(s) == 0 { + return 0 + } + return s[0].start +} + +// max returns the maximum value in the set, or 0 if empty. +func (s rangeset) max() int64 { + if len(s) == 0 { + return 0 + } + return s[len(s)-1].end - 1 +} + +// end returns the end of the last range in the set, or 0 if empty. +func (s rangeset) end() int64 { + if len(s) == 0 { + return 0 + } + return s[len(s)-1].end +} + +// isrange reports if the rangeset covers exactly the range [start, end). +func (s rangeset) isrange(start, end int64) bool { + switch len(s) { + case 0: + return start == 0 && end == 0 + case 1: + return s[0].start == start && s[0].end == end + } + return false +} + +// removeranges removes ranges [i,j). +func (s *rangeset) removeranges(i, j int) { + if i == j { + return + } + copy((*s)[i:], (*s)[j:]) + *s = (*s)[:len(*s)-(j-i)] +} + +// insert adds a new range at index i. +func (s *rangeset) insertrange(i int, start, end int64) { + *s = append(*s, i64range{}) + copy((*s)[i+1:], (*s)[i:]) + (*s)[i] = i64range{start, end} +} diff --git a/internal/quic/rangeset_test.go b/internal/quic/rangeset_test.go new file mode 100644 index 000000000..292284813 --- /dev/null +++ b/internal/quic/rangeset_test.go @@ -0,0 +1,295 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "reflect" + "testing" +) + +func TestRangeSize(t *testing.T) { + for _, test := range []struct { + r i64range + want int64 + }{{ + r: i64range{0, 100}, + want: 100, + }, { + r: i64range{10, 20}, + want: 10, + }} { + if got := test.r.size(); got != test.want { + t.Errorf("%+v.size = %v, want %v", test.r, got, test.want) + } + } +} + +func TestRangeContains(t *testing.T) { + r := i64range{5, 10} + for _, i := range []int64{0, 4, 10, 15} { + if r.contains(i) { + t.Errorf("%v.contains(%v) = true, want false", r, i) + } + } + for _, i := range []int64{5, 6, 7, 8, 9} { + if !r.contains(i) { + t.Errorf("%v.contains(%v) = false, want true", r, i) + } + } +} + +func TestRangesetAdd(t *testing.T) { + for _, test := range []struct { + desc string + set rangeset + add i64range + want rangeset + }{{ + desc: "add to empty set", + set: rangeset{}, + add: i64range{0, 100}, + want: rangeset{{0, 100}}, + }, { + desc: "add empty range", + set: rangeset{}, + add: i64range{100, 100}, + want: rangeset{}, + }, { + desc: "append nonadjacent range", + set: rangeset{{100, 200}}, + add: i64range{300, 400}, + want: rangeset{{100, 200}, {300, 400}}, + }, { + desc: "prepend nonadjacent range", + set: rangeset{{100, 200}}, + add: i64range{0, 50}, + want: rangeset{{0, 50}, {100, 200}}, + }, { + desc: "insert nonadjacent range", + set: rangeset{{100, 200}, {500, 600}}, + add: i64range{300, 400}, + want: rangeset{{100, 200}, {300, 400}, {500, 600}}, + }, { + desc: "prepend adjacent range", + set: rangeset{{100, 200}}, + add: i64range{50, 100}, + want: rangeset{{50, 200}}, + }, { + desc: "append adjacent range", + set: rangeset{{100, 200}}, + add: i64range{200, 250}, + want: rangeset{{100, 250}}, + }, { + desc: "prepend overlapping range", + set: rangeset{{100, 200}}, + add: i64range{50, 150}, + want: rangeset{{50, 200}}, + }, { + desc: "append overlapping range", + set: rangeset{{100, 200}}, + add: i64range{150, 250}, + want: rangeset{{100, 250}}, + }, { + desc: "replace range", + set: rangeset{{100, 200}}, + add: i64range{50, 250}, + want: rangeset{{50, 250}}, + }, { + desc: "prepend and combine", + set: rangeset{{100, 200}, {300, 400}, {500, 600}}, + add: i64range{50, 300}, + want: rangeset{{50, 400}, {500, 600}}, + }, { + desc: "combine several ranges", + set: rangeset{{100, 200}, {300, 400}, {500, 600}, {700, 800}, {900, 1000}}, + add: i64range{300, 850}, + want: rangeset{{100, 200}, {300, 850}, {900, 1000}}, + }} { + test := test + t.Run(test.desc, func(t *testing.T) { + got := test.set + got.add(test.add.start, test.add.end) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("add [%v,%v) to %v", test.add.start, test.add.end, test.set) + t.Errorf(" got: %v", got) + t.Errorf(" want: %v", test.want) + } + }) + } +} + +func TestRangesetSub(t *testing.T) { + for _, test := range []struct { + desc string + set rangeset + sub i64range + want rangeset + }{{ + desc: "subtract from empty set", + set: rangeset{}, + sub: i64range{0, 100}, + want: rangeset{}, + }, { + desc: "subtract empty range", + set: rangeset{{0, 100}}, + sub: i64range{0, 0}, + want: rangeset{{0, 100}}, + }, { + desc: "subtract not present in set", + set: rangeset{{0, 100}, {200, 300}}, + sub: i64range{100, 200}, + want: rangeset{{0, 100}, {200, 300}}, + }, { + desc: "subtract prefix", + set: rangeset{{100, 200}}, + sub: i64range{0, 150}, + want: rangeset{{150, 200}}, + }, { + desc: "subtract suffix", + set: rangeset{{100, 200}}, + sub: i64range{150, 300}, + want: rangeset{{100, 150}}, + }, { + desc: "subtract middle", + set: rangeset{{0, 100}}, + sub: i64range{40, 60}, + want: rangeset{{0, 40}, {60, 100}}, + }, { + desc: "subtract from two ranges", + set: rangeset{{0, 100}, {200, 300}}, + sub: i64range{50, 250}, + want: rangeset{{0, 50}, {250, 300}}, + }, { + desc: "subtract removes range", + set: rangeset{{0, 100}, {200, 300}, {400, 500}}, + sub: i64range{200, 300}, + want: rangeset{{0, 100}, {400, 500}}, + }, { + desc: "subtract removes multiple ranges", + set: rangeset{{0, 100}, {200, 300}, {400, 500}, {600, 700}}, + sub: i64range{50, 650}, + want: rangeset{{0, 50}, {650, 700}}, + }, { + desc: "subtract only range", + set: rangeset{{0, 100}}, + sub: i64range{0, 100}, + want: rangeset{}, + }} { + test := test + t.Run(test.desc, func(t *testing.T) { + got := test.set + got.sub(test.sub.start, test.sub.end) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("sub [%v,%v) from %v", test.sub.start, test.sub.end, test.set) + t.Errorf(" got: %v", got) + t.Errorf(" want: %v", test.want) + } + }) + } +} + +func TestRangesetContains(t *testing.T) { + var s rangeset + s.add(10, 20) + s.add(30, 40) + for i := int64(0); i < 50; i++ { + want := (i >= 10 && i < 20) || (i >= 30 && i < 40) + if got := s.contains(i); got != want { + t.Errorf("%v.contains(%v) = %v, want %v", s, i, got, want) + } + } +} + +func TestRangesetRangeContaining(t *testing.T) { + var s rangeset + s.add(10, 20) + s.add(30, 40) + for _, test := range []struct { + v int64 + want i64range + }{ + {0, i64range{0, 0}}, + {9, i64range{0, 0}}, + {10, i64range{10, 20}}, + {15, i64range{10, 20}}, + {19, i64range{10, 20}}, + {20, i64range{0, 0}}, + {29, i64range{0, 0}}, + {30, i64range{30, 40}}, + {39, i64range{30, 40}}, + {40, i64range{0, 0}}, + } { + got := s.rangeContaining(test.v) + if got != test.want { + t.Errorf("%v.rangeContaining(%v) = %v, want %v", s, test.v, got, test.want) + } + } +} + +func TestRangesetLimits(t *testing.T) { + for _, test := range []struct { + s rangeset + wantMin int64 + wantMax int64 + wantEnd int64 + }{{ + s: rangeset{}, + wantMin: 0, + wantMax: 0, + wantEnd: 0, + }, { + s: rangeset{{10, 20}}, + wantMin: 10, + wantMax: 19, + wantEnd: 20, + }, { + s: rangeset{{10, 20}, {30, 40}, {50, 60}}, + wantMin: 10, + wantMax: 59, + wantEnd: 60, + }} { + if got, want := test.s.min(), test.wantMin; got != want { + t.Errorf("%+v.min() = %v, want %v", test.s, got, want) + } + if got, want := test.s.max(), test.wantMax; got != want { + t.Errorf("%+v.max() = %v, want %v", test.s, got, want) + } + if got, want := test.s.end(), test.wantEnd; got != want { + t.Errorf("%+v.end() = %v, want %v", test.s, got, want) + } + } +} + +func TestRangesetIsRange(t *testing.T) { + for _, test := range []struct { + s rangeset + r i64range + want bool + }{{ + s: rangeset{{0, 100}}, + r: i64range{0, 100}, + want: true, + }, { + s: rangeset{{0, 100}}, + r: i64range{0, 101}, + want: false, + }, { + s: rangeset{{0, 10}, {11, 100}}, + r: i64range{0, 100}, + want: false, + }, { + s: rangeset{}, + r: i64range{0, 0}, + want: true, + }, { + s: rangeset{}, + r: i64range{0, 1}, + want: false, + }} { + if got := test.s.isrange(test.r.start, test.r.end); got != test.want { + t.Errorf("%+v.isrange(%v, %v) = %v, want %v", test.s, test.r.start, test.r.end, got, test.want) + } + } +} From d40f1541f796a58c388202df586d24c84dd97f4d Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 13 Oct 2022 21:37:09 -0700 Subject: [PATCH 16/28] quic: varint encoding and decoding Functions to encode and decode QUIC variable-length integers (RFC 9000, Section 16), as well as a few other common operations. For golang/go#58547 Change-Id: I2a738e8798b8013a7b13d7c1e1385bf846c6c2cd Reviewed-on: https://go-review.googlesource.com/c/net/+/478258 Run-TryBot: Damien Neil TryBot-Result: Gopher Robot Reviewed-by: Jonathan Amsterdam --- internal/quic/wire.go | 145 ++++++++++++++++++++++++ internal/quic/wire_test.go | 223 +++++++++++++++++++++++++++++++++++++ 2 files changed, 368 insertions(+) create mode 100644 internal/quic/wire.go create mode 100644 internal/quic/wire_test.go diff --git a/internal/quic/wire.go b/internal/quic/wire.go new file mode 100644 index 000000000..2494ad031 --- /dev/null +++ b/internal/quic/wire.go @@ -0,0 +1,145 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "encoding/binary" + +const maxVarintSize = 8 + +// consumeVarint parses a variable-length integer, reporting its length. +// It returns a negative length upon an error. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-16 +func consumeVarint(b []byte) (v uint64, n int) { + if len(b) < 1 { + return 0, -1 + } + b0 := b[0] & 0x3f + switch b[0] >> 6 { + case 0: + return uint64(b0), 1 + case 1: + if len(b) < 2 { + return 0, -1 + } + return uint64(b0)<<8 | uint64(b[1]), 2 + case 2: + if len(b) < 4 { + return 0, -1 + } + return uint64(b0)<<24 | uint64(b[1])<<16 | uint64(b[2])<<8 | uint64(b[3]), 4 + case 3: + if len(b) < 8 { + return 0, -1 + } + return uint64(b0)<<56 | uint64(b[1])<<48 | uint64(b[2])<<40 | uint64(b[3])<<32 | uint64(b[4])<<24 | uint64(b[5])<<16 | uint64(b[6])<<8 | uint64(b[7]), 8 + } + return 0, -1 +} + +// consumeVarint64 parses a variable-length integer as an int64. +func consumeVarintInt64(b []byte) (v int64, n int) { + u, n := consumeVarint(b) + // QUIC varints are 62-bits large, so this conversion can never overflow. + return int64(u), n +} + +// appendVarint appends a variable-length integer to b. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-16 +func appendVarint(b []byte, v uint64) []byte { + switch { + case v <= 63: + return append(b, byte(v)) + case v <= 16383: + return append(b, (1<<6)|byte(v>>8), byte(v)) + case v <= 1073741823: + return append(b, (2<<6)|byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) + case v <= 4611686018427387903: + return append(b, (3<<6)|byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) + default: + panic("varint too large") + } +} + +// sizeVarint returns the size of the variable-length integer encoding of f. +func sizeVarint(v uint64) int { + switch { + case v <= 63: + return 1 + case v <= 16383: + return 2 + case v <= 1073741823: + return 4 + case v <= 4611686018427387903: + return 8 + default: + panic("varint too large") + } +} + +// consumeUint32 parses a 32-bit fixed-length, big-endian integer, reporting its length. +// It returns a negative length upon an error. +func consumeUint32(b []byte) (uint32, int) { + if len(b) < 4 { + return 0, -1 + } + return binary.BigEndian.Uint32(b), 4 +} + +// consumeUint64 parses a 64-bit fixed-length, big-endian integer, reporting its length. +// It returns a negative length upon an error. +func consumeUint64(b []byte) (uint64, int) { + if len(b) < 8 { + return 0, -1 + } + return binary.BigEndian.Uint64(b), 8 +} + +// consumeUint8Bytes parses a sequence of bytes prefixed with an 8-bit length, +// reporting the total number of bytes consumed. +// It returns a negative length upon an error. +func consumeUint8Bytes(b []byte) ([]byte, int) { + if len(b) < 1 { + return nil, -1 + } + size := int(b[0]) + const n = 1 + if size > len(b[n:]) { + return nil, -1 + } + return b[n:][:size], size + n +} + +// appendUint8Bytes appends a sequence of bytes prefixed by an 8-bit length. +func appendUint8Bytes(b, v []byte) []byte { + if len(v) > 0xff { + panic("uint8-prefixed bytes too large") + } + b = append(b, uint8(len(v))) + b = append(b, v...) + return b +} + +// consumeVarintBytes parses a sequence of bytes preceded by a variable-length integer length, +// reporting the total number of bytes consumed. +// It returns a negative length upon an error. +func consumeVarintBytes(b []byte) ([]byte, int) { + size, n := consumeVarint(b) + if n < 0 { + return nil, -1 + } + if size > uint64(len(b[n:])) { + return nil, -1 + } + return b[n:][:size], int(size) + n +} + +// appendVarintBytes appends a sequence of bytes prefixed by a variable-length integer length. +func appendVarintBytes(b, v []byte) []byte { + b = appendVarint(b, uint64(len(v))) + b = append(b, v...) + return b +} diff --git a/internal/quic/wire_test.go b/internal/quic/wire_test.go new file mode 100644 index 000000000..a5dd83661 --- /dev/null +++ b/internal/quic/wire_test.go @@ -0,0 +1,223 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "bytes" + "testing" +) + +func TestConsumeVarint(t *testing.T) { + for _, test := range []struct { + b []byte + want uint64 + wantLen int + }{ + {[]byte{0x00}, 0, 1}, + {[]byte{0x3f}, 63, 1}, + {[]byte{0x40, 0x00}, 0, 2}, + {[]byte{0x7f, 0xff}, 16383, 2}, + {[]byte{0x80, 0x00, 0x00, 0x00}, 0, 4}, + {[]byte{0xbf, 0xff, 0xff, 0xff}, 1073741823, 4}, + {[]byte{0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 0, 8}, + {[]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 4611686018427387903, 8}, + // Example cases from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.1 + {[]byte{0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c}, 151288809941952652, 8}, + {[]byte{0x9d, 0x7f, 0x3e, 0x7d}, 494878333, 4}, + {[]byte{0x7b, 0xbd}, 15293, 2}, + {[]byte{0x25}, 37, 1}, + {[]byte{0x40, 0x25}, 37, 2}, + } { + got, gotLen := consumeVarint(test.b) + if got != test.want || gotLen != test.wantLen { + t.Errorf("consumeVarint(%x) = %v, %v; want %v, %v", test.b, got, gotLen, test.want, test.wantLen) + } + // Extra data in the buffer is ignored. + b := append(test.b, 0) + got, gotLen = consumeVarint(b) + if got != test.want || gotLen != test.wantLen { + t.Errorf("consumeVarint(%x) = %v, %v; want %v, %v", b, got, gotLen, test.want, test.wantLen) + } + // Short buffer results in an error. + for i := 1; i <= len(test.b); i++ { + b = test.b[:len(test.b)-i] + got, gotLen = consumeVarint(b) + if got != 0 || gotLen >= 0 { + t.Errorf("consumeVarint(%x) = %v, %v; want 0, -1", b, got, gotLen) + } + } + } +} + +func TestAppendVarint(t *testing.T) { + for _, test := range []struct { + v uint64 + want []byte + }{ + {0, []byte{0x00}}, + {63, []byte{0x3f}}, + {16383, []byte{0x7f, 0xff}}, + {1073741823, []byte{0xbf, 0xff, 0xff, 0xff}}, + {4611686018427387903, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + // Example cases from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.1 + {151288809941952652, []byte{0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c}}, + {494878333, []byte{0x9d, 0x7f, 0x3e, 0x7d}}, + {15293, []byte{0x7b, 0xbd}}, + {37, []byte{0x25}}, + } { + got := appendVarint([]byte{}, test.v) + if !bytes.Equal(got, test.want) { + t.Errorf("AppendVarint(nil, %v) = %x, want %x", test.v, got, test.want) + } + if gotLen, wantLen := sizeVarint(test.v), len(got); gotLen != wantLen { + t.Errorf("SizeVarint(%v) = %v, want %v", test.v, gotLen, wantLen) + } + } +} + +func TestConsumeUint32(t *testing.T) { + for _, test := range []struct { + b []byte + want uint32 + wantLen int + }{ + {[]byte{0x01, 0x02, 0x03, 0x04}, 0x01020304, 4}, + {[]byte{0x01, 0x02, 0x03}, 0, -1}, + } { + if got, n := consumeUint32(test.b); got != test.want || n != test.wantLen { + t.Errorf("consumeUint32(%x) = %v, %v; want %v, %v", test.b, got, n, test.want, test.wantLen) + } + } +} + +func TestConsumeUint64(t *testing.T) { + for _, test := range []struct { + b []byte + want uint64 + wantLen int + }{ + {[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, 0x0102030405060708, 8}, + {[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 0, -1}, + } { + if got, n := consumeUint64(test.b); got != test.want || n != test.wantLen { + t.Errorf("consumeUint32(%x) = %v, %v; want %v, %v", test.b, got, n, test.want, test.wantLen) + } + } +} + +func TestConsumeVarintBytes(t *testing.T) { + for _, test := range []struct { + b []byte + want []byte + wantLen int + }{ + {[]byte{0x00}, []byte{}, 1}, + {[]byte{0x40, 0x00}, []byte{}, 2}, + {[]byte{0x04, 0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 5}, + {[]byte{0x40, 0x04, 0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 6}, + } { + got, gotLen := consumeVarintBytes(test.b) + if !bytes.Equal(got, test.want) || gotLen != test.wantLen { + t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {%x}, %v", test.b, got, gotLen, test.want, test.wantLen) + } + // Extra data in the buffer is ignored. + b := append(test.b, 0) + got, gotLen = consumeVarintBytes(b) + if !bytes.Equal(got, test.want) || gotLen != test.wantLen { + t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {%x}, %v", b, got, gotLen, test.want, test.wantLen) + } + // Short buffer results in an error. + for i := 1; i <= len(test.b); i++ { + b = test.b[:len(test.b)-i] + got, gotLen := consumeVarintBytes(b) + if len(got) > 0 || gotLen > 0 { + t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen) + } + } + + } +} + +func TestConsumeVarintBytesErrors(t *testing.T) { + for _, b := range [][]byte{ + {0x01}, + {0x40, 0x01}, + } { + got, gotLen := consumeVarintBytes(b) + if len(got) > 0 || gotLen > 0 { + t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen) + } + } +} + +func TestConsumeUint8Bytes(t *testing.T) { + for _, test := range []struct { + b []byte + want []byte + wantLen int + }{ + {[]byte{0x00}, []byte{}, 1}, + {[]byte{0x01, 0x00}, []byte{0x00}, 2}, + {[]byte{0x04, 0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 5}, + } { + got, gotLen := consumeUint8Bytes(test.b) + if !bytes.Equal(got, test.want) || gotLen != test.wantLen { + t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {%x}, %v", test.b, got, gotLen, test.want, test.wantLen) + } + // Extra data in the buffer is ignored. + b := append(test.b, 0) + got, gotLen = consumeUint8Bytes(b) + if !bytes.Equal(got, test.want) || gotLen != test.wantLen { + t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {%x}, %v", b, got, gotLen, test.want, test.wantLen) + } + // Short buffer results in an error. + for i := 1; i <= len(test.b); i++ { + b = test.b[:len(test.b)-i] + got, gotLen := consumeUint8Bytes(b) + if len(got) > 0 || gotLen > 0 { + t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen) + } + } + + } +} + +func TestConsumeUint8BytesErrors(t *testing.T) { + for _, b := range [][]byte{ + {0x01}, + {0x04, 0x01, 0x02, 0x03}, + } { + got, gotLen := consumeUint8Bytes(b) + if len(got) > 0 || gotLen > 0 { + t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen) + } + } +} + +func TestAppendUint8Bytes(t *testing.T) { + var got []byte + got = appendUint8Bytes(got, []byte{}) + got = appendUint8Bytes(got, []byte{0xaa, 0xbb}) + want := []byte{ + 0x00, + 0x02, 0xaa, 0xbb, + } + if !bytes.Equal(got, want) { + t.Errorf("appendUint8Bytes {}, {aabb} = {%x}; want {%x}", got, want) + } +} + +func TestAppendVarintBytes(t *testing.T) { + var got []byte + got = appendVarintBytes(got, []byte{}) + got = appendVarintBytes(got, []byte{0xaa, 0xbb}) + want := []byte{ + 0x00, + 0x02, 0xaa, 0xbb, + } + if !bytes.Equal(got, want) { + t.Errorf("appendVarintBytes {}, {aabb} = {%x}; want {%x}", got, want) + } +} From 61d852e7b042c17b30770633d4d2a8f7b739e7b2 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 4 Jan 2023 09:11:34 -0800 Subject: [PATCH 17/28] quic: error codes and types Constants for the transport error codes in RFC 9000 Section 20, types representing transport errors sent to or received from the peer, and a type representing application protocol errors. For golang/go#58547 Change-Id: Ib4325e1272f6e0984f233ef494827a1799d7dc26 Reviewed-on: https://go-review.googlesource.com/c/net/+/495235 Reviewed-by: Jonathan Amsterdam TryBot-Result: Gopher Robot Run-TryBot: Damien Neil --- internal/quic/errors.go | 110 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 internal/quic/errors.go diff --git a/internal/quic/errors.go b/internal/quic/errors.go new file mode 100644 index 000000000..725a7daaa --- /dev/null +++ b/internal/quic/errors.go @@ -0,0 +1,110 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "fmt" +) + +// A transportError is an transport error code from RFC 9000 Section 20.1. +// +// The transportError type doesn't implement the error interface to ensure we always +// distinguish between errors sent to and received from the peer. +// See the localTransportError and peerTransportError types below. +type transportError uint64 + +// https://www.rfc-editor.org/rfc/rfc9000.html#section-20.1 +const ( + errNo = transportError(0x00) + errInternal = transportError(0x01) + errConnectionRefused = transportError(0x02) + errFlowControl = transportError(0x03) + errStreamLimit = transportError(0x04) + errStreamState = transportError(0x05) + errFinalSize = transportError(0x06) + errFrameEncoding = transportError(0x07) + errTransportParameter = transportError(0x08) + errConnectionIDLimit = transportError(0x09) + errProtocolViolation = transportError(0x0a) + errInvalidToken = transportError(0x0b) + errApplicationError = transportError(0x0c) + errCryptoBufferExceeded = transportError(0x0d) + errKeyUpdateError = transportError(0x0e) + errAEADLimitReached = transportError(0x0f) + errNoViablePath = transportError(0x10) + errTLSBase = transportError(0x0100) // 0x0100-0x01ff; base + TLS code +) + +func (e transportError) String() string { + switch e { + case errNo: + return "NO_ERROR" + case errInternal: + return "INTERNAL_ERROR" + case errConnectionRefused: + return "CONNECTION_REFUSED" + case errFlowControl: + return "FLOW_CONTROL_ERROR" + case errStreamLimit: + return "STREAM_LIMIT_ERROR" + case errStreamState: + return "STREAM_STATE_ERROR" + case errFinalSize: + return "FINAL_SIZE_ERROR" + case errFrameEncoding: + return "FRAME_ENCODING_ERROR" + case errTransportParameter: + return "TRANSPORT_PARAMETER_ERROR" + case errConnectionIDLimit: + return "CONNECTION_ID_LIMIT_ERROR" + case errProtocolViolation: + return "PROTOCOL_VIOLATION" + case errInvalidToken: + return "INVALID_TOKEN" + case errApplicationError: + return "APPLICATION_ERROR" + case errCryptoBufferExceeded: + return "CRYPTO_BUFFER_EXCEEDED" + case errKeyUpdateError: + return "KEY_UPDATE_ERROR" + case errAEADLimitReached: + return "AEAD_LIMIT_REACHED" + case errNoViablePath: + return "NO_VIABLE_PATH" + } + if e >= 0x0100 && e <= 0x01ff { + return fmt.Sprintf("CRYPTO_ERROR(%v)", uint64(e)&0xff) + } + return fmt.Sprintf("ERROR %d", uint64(e)) +} + +// A localTransportError is an error sent to the peer. +type localTransportError transportError + +func (e localTransportError) Error() string { + return "closed connection: " + transportError(e).String() +} + +// A peerTransportError is an error received from the peer. +type peerTransportError struct { + code transportError + reason string +} + +func (e peerTransportError) Error() string { + return fmt.Sprintf("peer closed connection: %v: %q", e.code, e.reason) +} + +// An ApplicationError is an application protocol error code (RFC 9000, Section 20.2). +// Application protocol errors may be sent when terminating a stream or connection. +type ApplicationError struct { + Code uint64 + Reason string +} + +func (e ApplicationError) Error() string { + // TODO: Include the Reason string here, but sanitize it first. + return fmt.Sprintf("AppError %v", e.Code) +} From a233290d3062ca7801ab3b804a4d7ee5d0e14253 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 13 Oct 2022 21:38:10 -0700 Subject: [PATCH 18/28] quic: add a data structure for tracking sent packets When we send a packet, we need to remember its contents until it has been acked or detected as lost. For golang/go#58547 Change-Id: I8c18f7ca1730a3ce460cd562d060dd6c7cfa9ffb Reviewed-on: https://go-review.googlesource.com/c/net/+/495236 Reviewed-by: Jonathan Amsterdam Reviewed-by: Cuong Manh Le Run-TryBot: Damien Neil TryBot-Result: Gopher Robot --- internal/quic/sent_packet.go | 100 ++++++++++++++++++++++++++++++ internal/quic/sent_packet_test.go | 51 +++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 internal/quic/sent_packet.go create mode 100644 internal/quic/sent_packet_test.go diff --git a/internal/quic/sent_packet.go b/internal/quic/sent_packet.go new file mode 100644 index 000000000..03d5e53d1 --- /dev/null +++ b/internal/quic/sent_packet.go @@ -0,0 +1,100 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "sync" + "time" +) + +// A sentPacket tracks state related to an in-flight packet we sent, +// to be committed when the peer acks it or resent if the packet is lost. +type sentPacket struct { + num packetNumber + size int // size in bytes + time time.Time // time sent + + ackEliciting bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.4.1 + inFlight bool // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.6.1 + acked bool // ack has been received + lost bool // packet is presumed lost + + // Frames sent in the packet. + // + // This is an abbreviated version of the packet payload, containing only the information + // we need to process an ack for or loss of this packet. + // For example, a CRYPTO frame is recorded as the frame type (0x06), offset, and length, + // but does not include the sent data. + b []byte + n int // read offset into b +} + +var sentPool = sync.Pool{ + New: func() interface{} { + return &sentPacket{} + }, +} + +func newSentPacket() *sentPacket { + sent := sentPool.Get().(*sentPacket) + sent.reset() + return sent +} + +// recycle returns a sentPacket to the pool. +func (sent *sentPacket) recycle() { + sentPool.Put(sent) +} + +func (sent *sentPacket) reset() { + *sent = sentPacket{ + b: sent.b[:0], + } +} + +// The append* methods record information about frames in the packet. + +func (sent *sentPacket) appendNonAckElicitingFrame(frameType byte) { + sent.b = append(sent.b, frameType) +} + +func (sent *sentPacket) appendAckElicitingFrame(frameType byte) { + sent.ackEliciting = true + sent.inFlight = true + sent.b = append(sent.b, frameType) +} + +func (sent *sentPacket) appendInt(v uint64) { + sent.b = appendVarint(sent.b, v) +} + +func (sent *sentPacket) appendOffAndSize(start int64, size int) { + sent.b = appendVarint(sent.b, uint64(start)) + sent.b = appendVarint(sent.b, uint64(size)) +} + +// The next* methods read back information about frames in the packet. + +func (sent *sentPacket) next() (frameType byte) { + f := sent.b[sent.n] + sent.n++ + return f +} + +func (sent *sentPacket) nextInt() uint64 { + v, n := consumeVarint(sent.b[sent.n:]) + sent.n += n + return v +} + +func (sent *sentPacket) nextRange() (start, end int64) { + start = int64(sent.nextInt()) + end = start + int64(sent.nextInt()) + return start, end +} + +func (sent *sentPacket) done() bool { + return sent.n == len(sent.b) +} diff --git a/internal/quic/sent_packet_test.go b/internal/quic/sent_packet_test.go new file mode 100644 index 000000000..08a3d8ff0 --- /dev/null +++ b/internal/quic/sent_packet_test.go @@ -0,0 +1,51 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "testing" + +func TestSentPacket(t *testing.T) { + frames := []interface{}{ + byte(frameTypePing), + byte(frameTypeStreamBase), + uint64(1), + i64range{1 << 20, 1<<20 + 1024}, + } + // Record sent frames. + sent := newSentPacket() + for _, f := range frames { + switch f := f.(type) { + case byte: + sent.appendAckElicitingFrame(f) + case uint64: + sent.appendInt(f) + case i64range: + sent.appendOffAndSize(f.start, int(f.size())) + } + } + // Read the record. + for i, want := range frames { + if done := sent.done(); done { + t.Fatalf("before consuming contents, sent.done() = true, want false") + } + switch want := want.(type) { + case byte: + if got := sent.next(); got != want { + t.Fatalf("%v: sent.next() = %v, want %v", i, got, want) + } + case uint64: + if got := sent.nextInt(); got != want { + t.Fatalf("%v: sent.nextInt() = %v, want %v", i, got, want) + } + case i64range: + if start, end := sent.nextRange(); start != want.start || end != want.end { + t.Fatalf("%v: sent.nextRange() = [%v,%v), want %v", i, start, end, want) + } + } + } + if done := sent.done(); !done { + t.Fatalf("after consuming contents, sent.done() = false, want true") + } +} From 1b5a2d8538e23b99860872499f7a5c65b8b8c4b7 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 14 Oct 2022 10:16:44 -0700 Subject: [PATCH 19/28] quic: packet encoding/decoding Frame encoding is handled by the packetWriter type. The packetWriter also takes responsibility for recording the contents of constructed packets in a sentPacket structure. Frame decoding is handled by consume*Frame functions, which generally return the frame contents. ACK frames, which have complex contents, are provided to the caller via callback function. In addition to the above functions, used in the serving path, this CL includes per-frame types that implement a common debugFrame interface. These types are used for tests and debug logging, but not in the serving path where we want to avoid allocations from storing values in an interface. For golang/go#58547 Change-Id: I03ce11210aa9aa6ac749a5273b2ba9dd9c6989cf Reviewed-on: https://go-review.googlesource.com/c/net/+/495355 Reviewed-by: Jonathan Amsterdam Run-TryBot: Damien Neil TryBot-Result: Gopher Robot --- internal/quic/frame_debug.go | 506 ++++++++++++++++++++ internal/quic/packet_codec_test.go | 712 +++++++++++++++++++++++++++++ internal/quic/packet_parser.go | 524 +++++++++++++++++++++ internal/quic/packet_writer.go | 548 ++++++++++++++++++++++ 4 files changed, 2290 insertions(+) create mode 100644 internal/quic/frame_debug.go create mode 100644 internal/quic/packet_codec_test.go create mode 100644 internal/quic/packet_parser.go create mode 100644 internal/quic/packet_writer.go diff --git a/internal/quic/frame_debug.go b/internal/quic/frame_debug.go new file mode 100644 index 000000000..fa9bdca06 --- /dev/null +++ b/internal/quic/frame_debug.go @@ -0,0 +1,506 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "fmt" + "time" +) + +// A debugFrame is a representation of the contents of a QUIC frame, +// used for debug logs and testing but not the primary serving path. +type debugFrame interface { + String() string + write(w *packetWriter) bool +} + +func parseDebugFrame(b []byte) (f debugFrame, n int) { + if len(b) == 0 { + return nil, -1 + } + switch b[0] { + case frameTypePadding: + f, n = parseDebugFramePadding(b) + case frameTypePing: + f, n = parseDebugFramePing(b) + case frameTypeAck, frameTypeAckECN: + f, n = parseDebugFrameAck(b) + case frameTypeResetStream: + f, n = parseDebugFrameResetStream(b) + case frameTypeStopSending: + f, n = parseDebugFrameStopSending(b) + case frameTypeCrypto: + f, n = parseDebugFrameCrypto(b) + case frameTypeNewToken: + f, n = parseDebugFrameNewToken(b) + case frameTypeStreamBase, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f: + f, n = parseDebugFrameStream(b) + case frameTypeMaxData: + f, n = parseDebugFrameMaxData(b) + case frameTypeMaxStreamData: + f, n = parseDebugFrameMaxStreamData(b) + case frameTypeMaxStreamsBidi, frameTypeMaxStreamsUni: + f, n = parseDebugFrameMaxStreams(b) + case frameTypeDataBlocked: + f, n = parseDebugFrameDataBlocked(b) + case frameTypeStreamDataBlocked: + f, n = parseDebugFrameStreamDataBlocked(b) + case frameTypeStreamsBlockedBidi, frameTypeStreamsBlockedUni: + f, n = parseDebugFrameStreamsBlocked(b) + case frameTypeNewConnectionID: + f, n = parseDebugFrameNewConnectionID(b) + case frameTypeRetireConnectionID: + f, n = parseDebugFrameRetireConnectionID(b) + case frameTypePathChallenge: + f, n = parseDebugFramePathChallenge(b) + case frameTypePathResponse: + f, n = parseDebugFramePathResponse(b) + case frameTypeConnectionCloseTransport: + f, n = parseDebugFrameConnectionCloseTransport(b) + case frameTypeConnectionCloseApplication: + f, n = parseDebugFrameConnectionCloseApplication(b) + case frameTypeHandshakeDone: + f, n = parseDebugFrameHandshakeDone(b) + default: + return nil, -1 + } + return f, n +} + +// debugFramePadding is a sequence of PADDING frames. +type debugFramePadding struct { + size int +} + +func parseDebugFramePadding(b []byte) (f debugFramePadding, n int) { + for n < len(b) && b[n] == frameTypePadding { + n++ + } + f.size = n + return f, n +} + +func (f debugFramePadding) String() string { + return fmt.Sprintf("PADDING*%v", f.size) +} + +func (f debugFramePadding) write(w *packetWriter) bool { + if w.avail() == 0 { + return false + } + for i := 0; i < f.size && w.avail() > 0; i++ { + w.b = append(w.b, frameTypePadding) + } + return true +} + +// debugFramePing is a PING frame. +type debugFramePing struct{} + +func parseDebugFramePing(b []byte) (f debugFramePing, n int) { + return f, 1 +} + +func (f debugFramePing) String() string { + return "PING" +} + +func (f debugFramePing) write(w *packetWriter) bool { + return w.appendPingFrame() +} + +// debugFrameAck is an ACK frame. +type debugFrameAck struct { + ackDelay time.Duration + ranges []i64range +} + +func parseDebugFrameAck(b []byte) (f debugFrameAck, n int) { + f.ranges = nil + _, f.ackDelay, n = consumeAckFrame(b, ackDelayExponent, func(start, end packetNumber) { + f.ranges = append(f.ranges, i64range{ + start: int64(start), + end: int64(end), + }) + }) + // Ranges are parsed smallest to highest; reverse ranges slice to order them high to low. + for i := 0; i < len(f.ranges)/2; i++ { + j := len(f.ranges) - 1 + f.ranges[i], f.ranges[j] = f.ranges[j], f.ranges[i] + } + return f, n +} + +func (f debugFrameAck) String() string { + s := fmt.Sprintf("ACK Delay=%v", f.ackDelay) + for _, r := range f.ranges { + s += fmt.Sprintf(" [%v,%v)", r.start, r.end) + } + return s +} + +func (f debugFrameAck) write(w *packetWriter) bool { + return w.appendAckFrame(rangeset(f.ranges), ackDelayExponent, f.ackDelay) +} + +// debugFrameResetStream is a RESET_STREAM frame. +type debugFrameResetStream struct { + id streamID + code uint64 + finalSize int64 +} + +func parseDebugFrameResetStream(b []byte) (f debugFrameResetStream, n int) { + f.id, f.code, f.finalSize, n = consumeResetStreamFrame(b) + return f, n +} + +func (f debugFrameResetStream) String() string { + return fmt.Sprintf("RESET_STREAM ID=%v Code=%v FinalSize=%v", f.id, f.code, f.finalSize) +} + +func (f debugFrameResetStream) write(w *packetWriter) bool { + return w.appendResetStreamFrame(f.id, f.code, f.finalSize) +} + +// debugFrameStopSending is a STOP_SENDING frame. +type debugFrameStopSending struct { + id streamID + code uint64 +} + +func parseDebugFrameStopSending(b []byte) (f debugFrameStopSending, n int) { + f.id, f.code, n = consumeStopSendingFrame(b) + return f, n +} + +func (f debugFrameStopSending) String() string { + return fmt.Sprintf("STOP_SENDING ID=%v Code=%v", f.id, f.code) +} + +func (f debugFrameStopSending) write(w *packetWriter) bool { + return w.appendStopSendingFrame(f.id, f.code) +} + +// debugFrameCrypto is a CRYPTO frame. +type debugFrameCrypto struct { + off int64 + data []byte +} + +func parseDebugFrameCrypto(b []byte) (f debugFrameCrypto, n int) { + f.off, f.data, n = consumeCryptoFrame(b) + return f, n +} + +func (f debugFrameCrypto) String() string { + return fmt.Sprintf("CRYPTO Offset=%v Length=%v", f.off, len(f.data)) +} + +func (f debugFrameCrypto) write(w *packetWriter) bool { + b, added := w.appendCryptoFrame(f.off, len(f.data)) + copy(b, f.data) + return added +} + +// debugFrameNewToken is a NEW_TOKEN frame. +type debugFrameNewToken struct { + token []byte +} + +func parseDebugFrameNewToken(b []byte) (f debugFrameNewToken, n int) { + f.token, n = consumeNewTokenFrame(b) + return f, n +} + +func (f debugFrameNewToken) String() string { + return fmt.Sprintf("NEW_TOKEN Token=%x", f.token) +} + +func (f debugFrameNewToken) write(w *packetWriter) bool { + return w.appendNewTokenFrame(f.token) +} + +// debugFrameStream is a STREAM frame. +type debugFrameStream struct { + id streamID + fin bool + off int64 + data []byte +} + +func parseDebugFrameStream(b []byte) (f debugFrameStream, n int) { + f.id, f.off, f.fin, f.data, n = consumeStreamFrame(b) + return f, n +} + +func (f debugFrameStream) String() string { + fin := "" + if f.fin { + fin = " FIN" + } + return fmt.Sprintf("STREAM ID=%v%v Offset=%v Length=%v", f.id, fin, f.off, len(f.data)) +} + +func (f debugFrameStream) write(w *packetWriter) bool { + b, added := w.appendStreamFrame(f.id, f.off, len(f.data), f.fin) + copy(b, f.data) + return added +} + +// debugFrameMaxData is a MAX_DATA frame. +type debugFrameMaxData struct { + max int64 +} + +func parseDebugFrameMaxData(b []byte) (f debugFrameMaxData, n int) { + f.max, n = consumeMaxDataFrame(b) + return f, n +} + +func (f debugFrameMaxData) String() string { + return fmt.Sprintf("MAX_DATA Max=%v", f.max) +} + +func (f debugFrameMaxData) write(w *packetWriter) bool { + return w.appendMaxDataFrame(f.max) +} + +// debugFrameMaxStreamData is a MAX_STREAM_DATA frame. +type debugFrameMaxStreamData struct { + id streamID + max int64 +} + +func parseDebugFrameMaxStreamData(b []byte) (f debugFrameMaxStreamData, n int) { + f.id, f.max, n = consumeMaxStreamDataFrame(b) + return f, n +} + +func (f debugFrameMaxStreamData) String() string { + return fmt.Sprintf("MAX_STREAM_DATA ID=%v Max=%v", f.id, f.max) +} + +func (f debugFrameMaxStreamData) write(w *packetWriter) bool { + return w.appendMaxStreamDataFrame(f.id, f.max) +} + +// debugFrameMaxStreams is a MAX_STREAMS frame. +type debugFrameMaxStreams struct { + streamType streamType + max int64 +} + +func parseDebugFrameMaxStreams(b []byte) (f debugFrameMaxStreams, n int) { + f.streamType, f.max, n = consumeMaxStreamsFrame(b) + return f, n +} + +func (f debugFrameMaxStreams) String() string { + return fmt.Sprintf("MAX_STREAMS Type=%v Max=%v", f.streamType, f.max) +} + +func (f debugFrameMaxStreams) write(w *packetWriter) bool { + return w.appendMaxStreamsFrame(f.streamType, f.max) +} + +// debugFrameDataBlocked is a DATA_BLOCKED frame. +type debugFrameDataBlocked struct { + max int64 +} + +func parseDebugFrameDataBlocked(b []byte) (f debugFrameDataBlocked, n int) { + f.max, n = consumeDataBlockedFrame(b) + return f, n +} + +func (f debugFrameDataBlocked) String() string { + return fmt.Sprintf("DATA_BLOCKED Max=%v", f.max) +} + +func (f debugFrameDataBlocked) write(w *packetWriter) bool { + return w.appendDataBlockedFrame(f.max) +} + +// debugFrameStreamDataBlocked is a STREAM_DATA_BLOCKED frame. +type debugFrameStreamDataBlocked struct { + id streamID + max int64 +} + +func parseDebugFrameStreamDataBlocked(b []byte) (f debugFrameStreamDataBlocked, n int) { + f.id, f.max, n = consumeStreamDataBlockedFrame(b) + return f, n +} + +func (f debugFrameStreamDataBlocked) String() string { + return fmt.Sprintf("STREAM_DATA_BLOCKED ID=%v Max=%v", f.id, f.max) +} + +func (f debugFrameStreamDataBlocked) write(w *packetWriter) bool { + return w.appendStreamDataBlockedFrame(f.id, f.max) +} + +// debugFrameStreamsBlocked is a STREAMS_BLOCKED frame. +type debugFrameStreamsBlocked struct { + streamType streamType + max int64 +} + +func parseDebugFrameStreamsBlocked(b []byte) (f debugFrameStreamsBlocked, n int) { + f.streamType, f.max, n = consumeStreamsBlockedFrame(b) + return f, n +} + +func (f debugFrameStreamsBlocked) String() string { + return fmt.Sprintf("STREAMS_BLOCKED Type=%v Max=%v", f.streamType, f.max) +} + +func (f debugFrameStreamsBlocked) write(w *packetWriter) bool { + return w.appendStreamsBlockedFrame(f.streamType, f.max) +} + +// debugFrameNewConnectionID is a NEW_CONNECTION_ID frame. +type debugFrameNewConnectionID struct { + seq int64 + retirePriorTo int64 + connID []byte + token [16]byte +} + +func parseDebugFrameNewConnectionID(b []byte) (f debugFrameNewConnectionID, n int) { + f.seq, f.retirePriorTo, f.connID, f.token, n = consumeNewConnectionIDFrame(b) + return f, n +} + +func (f debugFrameNewConnectionID) String() string { + return fmt.Sprintf("NEW_CONNECTION_ID Seq=%v Retire=%v ID=%x Token=%x", f.seq, f.retirePriorTo, f.connID, f.token[:]) +} + +func (f debugFrameNewConnectionID) write(w *packetWriter) bool { + return w.appendNewConnectionIDFrame(f.seq, f.retirePriorTo, f.connID, f.token) +} + +// debugFrameRetireConnectionID is a NEW_CONNECTION_ID frame. +type debugFrameRetireConnectionID struct { + seq uint64 + retirePriorTo uint64 + connID []byte + token [16]byte +} + +func parseDebugFrameRetireConnectionID(b []byte) (f debugFrameRetireConnectionID, n int) { + f.seq, n = consumeRetireConnectionIDFrame(b) + return f, n +} + +func (f debugFrameRetireConnectionID) String() string { + return fmt.Sprintf("RETIRE_CONNECTION_ID Seq=%v", f.seq) +} + +func (f debugFrameRetireConnectionID) write(w *packetWriter) bool { + return w.appendRetireConnectionIDFrame(f.seq) +} + +// debugFramePathChallenge is a PATH_CHALLENGE frame. +type debugFramePathChallenge struct { + data uint64 +} + +func parseDebugFramePathChallenge(b []byte) (f debugFramePathChallenge, n int) { + f.data, n = consumePathChallengeFrame(b) + return f, n +} + +func (f debugFramePathChallenge) String() string { + return fmt.Sprintf("PATH_CHALLENGE Data=%016x", f.data) +} + +func (f debugFramePathChallenge) write(w *packetWriter) bool { + return w.appendPathChallengeFrame(f.data) +} + +// debugFramePathResponse is a PATH_RESPONSE frame. +type debugFramePathResponse struct { + data uint64 +} + +func parseDebugFramePathResponse(b []byte) (f debugFramePathResponse, n int) { + f.data, n = consumePathResponseFrame(b) + return f, n +} + +func (f debugFramePathResponse) String() string { + return fmt.Sprintf("PATH_RESPONSE Data=%016x", f.data) +} + +func (f debugFramePathResponse) write(w *packetWriter) bool { + return w.appendPathResponseFrame(f.data) +} + +// debugFrameConnectionCloseTransport is a CONNECTION_CLOSE frame carrying a transport error. +type debugFrameConnectionCloseTransport struct { + code transportError + frameType uint64 + reason string +} + +func parseDebugFrameConnectionCloseTransport(b []byte) (f debugFrameConnectionCloseTransport, n int) { + f.code, f.frameType, f.reason, n = consumeConnectionCloseTransportFrame(b) + return f, n +} + +func (f debugFrameConnectionCloseTransport) String() string { + s := fmt.Sprintf("CONNECTION_CLOSE Code=%v", f.code) + if f.frameType != 0 { + s += fmt.Sprintf(" FrameType=%v", f.frameType) + } + if f.reason != "" { + s += fmt.Sprintf(" Reason=%q", f.reason) + } + return s +} + +func (f debugFrameConnectionCloseTransport) write(w *packetWriter) bool { + return w.appendConnectionCloseTransportFrame(f.code, f.frameType, f.reason) +} + +// debugFrameConnectionCloseApplication is a CONNECTION_CLOSE frame carrying an application error. +type debugFrameConnectionCloseApplication struct { + code uint64 + reason string +} + +func parseDebugFrameConnectionCloseApplication(b []byte) (f debugFrameConnectionCloseApplication, n int) { + f.code, f.reason, n = consumeConnectionCloseApplicationFrame(b) + return f, n +} + +func (f debugFrameConnectionCloseApplication) String() string { + s := fmt.Sprintf("CONNECTION_CLOSE AppCode=%v", f.code) + if f.reason != "" { + s += fmt.Sprintf(" Reason=%q", f.reason) + } + return s +} + +func (f debugFrameConnectionCloseApplication) write(w *packetWriter) bool { + return w.appendConnectionCloseApplicationFrame(f.code, f.reason) +} + +// debugFrameHandshakeDone is a HANDSHAKE_DONE frame. +type debugFrameHandshakeDone struct{} + +func parseDebugFrameHandshakeDone(b []byte) (f debugFrameHandshakeDone, n int) { + return f, 1 +} + +func (f debugFrameHandshakeDone) String() string { + return "HANDSHAKE_DONE" +} + +func (f debugFrameHandshakeDone) write(w *packetWriter) bool { + return w.appendHandshakeDoneFrame() +} diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go new file mode 100644 index 000000000..ee533c8ab --- /dev/null +++ b/internal/quic/packet_codec_test.go @@ -0,0 +1,712 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "bytes" + "crypto/tls" + "reflect" + "testing" + "time" +) + +func TestParseLongHeaderPacket(t *testing.T) { + // Example Initial packet from: + // https://www.rfc-editor.org/rfc/rfc9001.html#section-a.3 + cid := unhex(`8394c8f03e515708`) + _, initialServerKeys := initialKeys(cid) + pkt := unhex(` + cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a + 5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3 + dbcba3f6ea46c5b7684df3548e7ddeb9 c3bf9c73cc3f3bded74b562bfb19fb84 + 022f8ef4cdd93795d77d06edbb7aaf2f 58891850abbdca3d20398c276456cbc4 + 2158407dd074ee + `) + want := longPacket{ + ptype: packetTypeInitial, + version: 1, + num: 1, + dstConnID: []byte{}, + srcConnID: unhex(`f067a5502a4262b5`), + payload: unhex(` + 02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 + 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 + 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 + 020304 + `), + extra: []byte{}, + } + + // Parse the packet. + got, n := parseLongHeaderPacket(pkt, initialServerKeys, 0) + if n != len(pkt) { + t.Errorf("parseLongHeaderPacket: n=%v, want %v", n, len(pkt)) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("parseLongHeaderPacket:\n got: %+v\nwant: %+v", got, want) + } + + // Skip the packet. + if got, want := skipLongHeaderPacket(pkt), len(pkt); got != want { + t.Errorf("skipLongHeaderPacket: n=%v, want %v", got, want) + } + + // Parse truncated versions of the packet; every attempt should fail. + for i := 0; i < len(pkt); i++ { + if _, n := parseLongHeaderPacket(pkt[:i], initialServerKeys, 0); n != -1 { + t.Fatalf("parse truncated long header packet: n=%v, want -1\ninput: %x", n, pkt[:i]) + } + if n := skipLongHeaderPacket(pkt[:i]); n != -1 { + t.Errorf("skip truncated long header packet: n=%v, want -1", n) + } + } + + // Parse with the wrong keys. + _, invalidKeys := initialKeys([]byte{}) + if _, n := parseLongHeaderPacket(pkt, invalidKeys, 0); n != -1 { + t.Fatalf("parse long header packet with wrong keys: n=%v, want -1", n) + } +} + +func TestRoundtripEncodeLongPacket(t *testing.T) { + aes128Keys, _ := newKeys(tls.TLS_AES_128_GCM_SHA256, []byte("secret")) + aes256Keys, _ := newKeys(tls.TLS_AES_256_GCM_SHA384, []byte("secret")) + chachaKeys, _ := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret")) + for _, test := range []struct { + desc string + p longPacket + k keys + }{{ + desc: "Initial, 1-byte number, AES128", + p: longPacket{ + ptype: packetTypeInitial, + version: 0x11223344, + num: 0, // 1-byte encodeing + dstConnID: []byte{1, 2, 3, 4}, + srcConnID: []byte{5, 6, 7, 8}, + payload: []byte("payload"), + extra: []byte("token"), + }, + k: aes128Keys, + }, { + desc: "0-RTT, 2-byte number, AES256", + p: longPacket{ + ptype: packetType0RTT, + version: 0x11223344, + num: 0x100, // 2-byte encoding + dstConnID: []byte{1, 2, 3, 4}, + srcConnID: []byte{5, 6, 7, 8}, + payload: []byte("payload"), + }, + k: aes256Keys, + }, { + desc: "0-RTT, 3-byte number, AES256", + p: longPacket{ + ptype: packetType0RTT, + version: 0x11223344, + num: 0x10000, // 2-byte encoding + dstConnID: []byte{1, 2, 3, 4}, + srcConnID: []byte{5, 6, 7, 8}, + payload: []byte{0}, + }, + k: aes256Keys, + }, { + desc: "Handshake, 4-byte number, ChaCha20Poly1305", + p: longPacket{ + ptype: packetTypeHandshake, + version: 0x11223344, + num: 0x1000000, // 4-byte encoding + dstConnID: []byte{1, 2, 3, 4}, + srcConnID: []byte{5, 6, 7, 8}, + payload: []byte("payload"), + }, + k: chachaKeys, + }} { + t.Run(test.desc, func(t *testing.T) { + var w packetWriter + w.reset(1200) + w.startProtectedLongHeaderPacket(0, test.p) + w.b = append(w.b, test.p.payload...) + w.finishProtectedLongHeaderPacket(0, test.k, test.p) + pkt := w.datagram() + + got, n := parseLongHeaderPacket(pkt, test.k, 0) + if n != len(pkt) { + t.Errorf("parseLongHeaderPacket: n=%v, want %v", n, len(pkt)) + } + if !reflect.DeepEqual(got, test.p) { + t.Errorf("Round-trip encode/decode did not preserve packet.\nsent: %+v\n got: %+v\nwire: %x", test.p, got, pkt) + } + }) + } +} + +func TestRoundtripEncodeShortPacket(t *testing.T) { + aes128Keys, _ := newKeys(tls.TLS_AES_128_GCM_SHA256, []byte("secret")) + aes256Keys, _ := newKeys(tls.TLS_AES_256_GCM_SHA384, []byte("secret")) + chachaKeys, _ := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret")) + connID := make([]byte, connIDLen) + for i := range connID { + connID[i] = byte(i) + } + for _, test := range []struct { + desc string + num packetNumber + payload []byte + k keys + }{{ + desc: "1-byte number, AES128", + num: 0, // 1-byte encoding, + payload: []byte("payload"), + k: aes128Keys, + }, { + desc: "2-byte number, AES256", + num: 0x100, // 2-byte encoding + payload: []byte("payload"), + k: aes256Keys, + }, { + desc: "3-byte number, ChaCha20Poly1305", + num: 0x10000, // 3-byte encoding + payload: []byte("payload"), + k: chachaKeys, + }, { + desc: "4-byte number, ChaCha20Poly1305", + num: 0x1000000, // 4-byte encoding + payload: []byte{0}, + k: chachaKeys, + }} { + t.Run(test.desc, func(t *testing.T) { + var w packetWriter + w.reset(1200) + w.start1RTTPacket(test.num, 0, connID) + w.b = append(w.b, test.payload...) + w.finish1RTTPacket(test.num, 0, connID, test.k) + pkt := w.datagram() + p, n := parse1RTTPacket(pkt, test.k, 0) + if n != len(pkt) { + t.Errorf("parse1RTTPacket: n=%v, want %v", n, len(pkt)) + } + if p.num != test.num || !bytes.Equal(p.payload, test.payload) { + t.Errorf("Round-trip encode/decode did not preserve packet.\nsent: num=%v, payload={%x}\ngot: num=%v, payload={%x}", test.num, test.payload, p.num, p.payload) + } + }) + } +} + +func TestFrameEncodeDecode(t *testing.T) { + for _, test := range []struct { + s string + f debugFrame + b []byte + truncated []byte + }{{ + s: "PADDING*1", + f: debugFramePadding{ + size: 1, + }, + b: []byte{ + 0x00, // Type (i) = 0x00, + + }, + }, { + s: "PING", + f: debugFramePing{}, + b: []byte{ + 0x01, // TYPE(i) = 0x01 + }, + }, { + s: "ACK Delay=80µs [0,16) [17,32) [48,64)", + f: debugFrameAck{ + ackDelay: (10 << ackDelayExponent) * time.Microsecond, + ranges: []i64range{ + {0x00, 0x10}, + {0x11, 0x20}, + {0x30, 0x40}, + }, + }, + b: []byte{ + 0x02, // TYPE (i) = 0x02 + 0x3f, // Largest Acknowledged (i) + 10, // ACK Delay (i) + 0x02, // ACK Range Count (i) + 0x0f, // First ACK Range (i) + 0x0f, // Gap (i) + 0x0e, // ACK Range Length (i) + 0x00, // Gap (i) + 0x0f, // ACK Range Length (i) + }, + truncated: []byte{ + 0x02, // TYPE (i) = 0x02 + 0x3f, // Largest Acknowledged (i) + 10, // ACK Delay (i) + 0x01, // ACK Range Count (i) + 0x0f, // First ACK Range (i) + 0x0f, // Gap (i) + 0x0e, // ACK Range Length (i) + }, + }, { + s: "RESET_STREAM ID=1 Code=2 FinalSize=3", + f: debugFrameResetStream{ + id: 1, + code: 2, + finalSize: 3, + }, + b: []byte{ + 0x04, // TYPE(i) = 0x04 + 0x01, // Stream ID (i), + 0x02, // Application Protocol Error Code (i), + 0x03, // Final Size (i), + }, + }, { + s: "STOP_SENDING ID=1 Code=2", + f: debugFrameStopSending{ + id: 1, + code: 2, + }, + b: []byte{ + 0x05, // TYPE(i) = 0x05 + 0x01, // Stream ID (i), + 0x02, // Application Protocol Error Code (i), + }, + }, { + s: "CRYPTO Offset=1 Length=2", + f: debugFrameCrypto{ + off: 1, + data: []byte{3, 4}, + }, + b: []byte{ + 0x06, // Type (i) = 0x06, + 0x01, // Offset (i), + 0x02, // Length (i), + 0x03, 0x04, // Crypto Data (..), + }, + truncated: []byte{ + 0x06, // Type (i) = 0x06, + 0x01, // Offset (i), + 0x01, // Length (i), + 0x03, + }, + }, { + s: "NEW_TOKEN Token=0304", + f: debugFrameNewToken{ + token: []byte{3, 4}, + }, + b: []byte{ + 0x07, // Type (i) = 0x07, + 0x02, // Token Length (i), + 0x03, 0x04, // Token (..), + }, + }, { + s: "STREAM ID=1 Offset=0 Length=0", + f: debugFrameStream{ + id: 1, + fin: false, + off: 0, + data: []byte{}, + }, + b: []byte{ + 0x0a, // Type (i) = 0x08..0x0f, + 0x01, // Stream ID (i), + // [Offset (i)], + 0x00, // [Length (i)], + // Stream Data (..), + }, + }, { + s: "STREAM ID=100 Offset=4 Length=3", + f: debugFrameStream{ + id: 100, + fin: false, + off: 4, + data: []byte{0xa0, 0xa1, 0xa2}, + }, + b: []byte{ + 0x0e, // Type (i) = 0x08..0x0f, + 0x40, 0x64, // Stream ID (i), + 0x04, // [Offset (i)], + 0x03, // [Length (i)], + 0xa0, 0xa1, 0xa2, // Stream Data (..), + }, + truncated: []byte{ + 0x0e, // Type (i) = 0x08..0x0f, + 0x40, 0x64, // Stream ID (i), + 0x04, // [Offset (i)], + 0x02, // [Length (i)], + 0xa0, 0xa1, // Stream Data (..), + }, + }, { + s: "STREAM ID=100 FIN Offset=4 Length=3", + f: debugFrameStream{ + id: 100, + fin: true, + off: 4, + data: []byte{0xa0, 0xa1, 0xa2}, + }, + b: []byte{ + 0x0f, // Type (i) = 0x08..0x0f, + 0x40, 0x64, // Stream ID (i), + 0x04, // [Offset (i)], + 0x03, // [Length (i)], + 0xa0, 0xa1, 0xa2, // Stream Data (..), + }, + truncated: []byte{ + 0x0e, // Type (i) = 0x08..0x0f, + 0x40, 0x64, // Stream ID (i), + 0x04, // [Offset (i)], + 0x02, // [Length (i)], + 0xa0, 0xa1, // Stream Data (..), + }, + }, { + s: "STREAM ID=1 FIN Offset=100 Length=0", + f: debugFrameStream{ + id: 1, + fin: true, + off: 100, + data: []byte{}, + }, + b: []byte{ + 0x0f, // Type (i) = 0x08..0x0f, + 0x01, // Stream ID (i), + 0x40, 0x64, // [Offset (i)], + 0x00, // [Length (i)], + // Stream Data (..), + }, + }, { + s: "MAX_DATA Max=10", + f: debugFrameMaxData{ + max: 10, + }, + b: []byte{ + 0x10, // Type (i) = 0x10, + 0x0a, // Maximum Data (i), + }, + }, { + s: "MAX_STREAM_DATA ID=1 Max=10", + f: debugFrameMaxStreamData{ + id: 1, + max: 10, + }, + b: []byte{ + 0x11, // Type (i) = 0x11, + 0x01, // Stream ID (i), + 0x0a, // Maximum Stream Data (i), + }, + }, { + s: "MAX_STREAMS Type=bidi Max=1", + f: debugFrameMaxStreams{ + streamType: bidiStream, + max: 1, + }, + b: []byte{ + 0x12, // Type (i) = 0x12..0x13, + 0x01, // Maximum Streams (i), + }, + }, { + s: "MAX_STREAMS Type=uni Max=1", + f: debugFrameMaxStreams{ + streamType: uniStream, + max: 1, + }, + b: []byte{ + 0x13, // Type (i) = 0x12..0x13, + 0x01, // Maximum Streams (i), + }, + }, { + s: "DATA_BLOCKED Max=1", + f: debugFrameDataBlocked{ + max: 1, + }, + b: []byte{ + 0x14, // Type (i) = 0x14, + 0x01, // Maximum Data (i), + }, + }, { + s: "STREAM_DATA_BLOCKED ID=1 Max=2", + f: debugFrameStreamDataBlocked{ + id: 1, + max: 2, + }, + b: []byte{ + 0x15, // Type (i) = 0x15, + 0x01, // Stream ID (i), + 0x02, // Maximum Stream Data (i), + }, + }, { + s: "STREAMS_BLOCKED Type=bidi Max=1", + f: debugFrameStreamsBlocked{ + streamType: bidiStream, + max: 1, + }, + b: []byte{ + 0x16, // Type (i) = 0x16..0x17, + 0x01, // Maximum Streams (i), + }, + }, { + s: "STREAMS_BLOCKED Type=uni Max=1", + f: debugFrameStreamsBlocked{ + streamType: uniStream, + max: 1, + }, + b: []byte{ + 0x17, // Type (i) = 0x16..0x17, + 0x01, // Maximum Streams (i), + }, + }, { + s: "NEW_CONNECTION_ID Seq=3 Retire=2 ID=a0a1a2a3 Token=0102030405060708090a0b0c0d0e0f10", + f: debugFrameNewConnectionID{ + seq: 3, + retirePriorTo: 2, + connID: []byte{0xa0, 0xa1, 0xa2, 0xa3}, + token: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + }, + b: []byte{ + 0x18, // Type (i) = 0x18, + 0x03, // Sequence Number (i), + 0x02, // Retire Prior To (i), + 0x04, // Length (8), + 0xa0, 0xa1, 0xa2, 0xa3, // Connection ID (8..160), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // Stateless Reset Token (128), + }, + }, { + s: "RETIRE_CONNECTION_ID Seq=1", + f: debugFrameRetireConnectionID{ + seq: 1, + }, + b: []byte{ + 0x19, // Type (i) = 0x19, + 0x01, // Sequence Number (i), + }, + }, { + s: "PATH_CHALLENGE Data=0123456789abcdef", + f: debugFramePathChallenge{ + data: 0x0123456789abcdef, + }, + b: []byte{ + 0x1a, // Type (i) = 0x1a, + 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, // Data (64), + }, + }, { + s: "PATH_RESPONSE Data=0123456789abcdef", + f: debugFramePathResponse{ + data: 0x0123456789abcdef, + }, + b: []byte{ + 0x1b, // Type (i) = 0x1b, + 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, // Data (64), + }, + }, { + s: `CONNECTION_CLOSE Code=INTERNAL_ERROR FrameType=2 Reason="oops"`, + f: debugFrameConnectionCloseTransport{ + code: 1, + frameType: 2, + reason: "oops", + }, + b: []byte{ + 0x1c, // Type (i) = 0x1c..0x1d, + 0x01, // Error Code (i), + 0x02, // [Frame Type (i)], + 0x04, // Reason Phrase Length (i), + 'o', 'o', 'p', 's', // Reason Phrase (..), + }, + }, { + s: `CONNECTION_CLOSE AppCode=1 Reason="oops"`, + f: debugFrameConnectionCloseApplication{ + code: 1, + reason: "oops", + }, + b: []byte{ + 0x1d, // Type (i) = 0x1c..0x1d, + 0x01, // Error Code (i), + 0x04, // Reason Phrase Length (i), + 'o', 'o', 'p', 's', // Reason Phrase (..), + }, + }, { + s: "HANDSHAKE_DONE", + f: debugFrameHandshakeDone{}, + b: []byte{ + 0x1e, // Type (i) = 0x1e, + }, + }} { + var w packetWriter + w.reset(1200) + w.start1RTTPacket(0, 0, nil) + w.pktLim = w.payOff + len(test.b) + if added := test.f.write(&w); !added { + t.Errorf("encoding %v with %v bytes available: write unexpectedly failed", test.f, len(test.b)) + } + if got, want := w.payload(), test.b; !bytes.Equal(got, want) { + t.Errorf("encoding %v:\ngot {%x}\nwant {%x}", test.f, got, want) + } + gotf, n := parseDebugFrame(test.b) + if n != len(test.b) || !reflect.DeepEqual(gotf, test.f) { + t.Errorf("decoding {%x}:\ndecoded %v bytes, want %v\ngot: %v\nwant: %v", test.b, n, len(test.b), gotf, test.f) + } + if got, want := test.f.String(), test.s; got != want { + t.Errorf("frame.String():\ngot %q\nwant %q", got, want) + } + + // Try encoding the frame into too little space. + // Most frames will result in an error; some (like STREAM frames) will truncate + // the data written. + w.reset(1200) + w.start1RTTPacket(0, 0, nil) + w.pktLim = w.payOff + len(test.b) - 1 + if added := test.f.write(&w); added { + if test.truncated == nil { + t.Errorf("encoding %v with %v-1 bytes available: write unexpectedly succeeded", test.f, len(test.b)) + } else if got, want := w.payload(), test.truncated; !bytes.Equal(got, want) { + t.Errorf("encoding %v with %v-1 bytes available:\ngot {%x}\nwant {%x}", test.f, len(test.b), got, want) + } + } + + // Try parsing truncated data. + for i := 0; i < len(test.b); i++ { + f, n := parseDebugFrame(test.b[:i]) + if n >= 0 { + t.Errorf("decoding truncated frame {%x}:\ngot: %v\nwant error", test.b[:i], f) + } + } + } +} + +func TestFrameDecode(t *testing.T) { + for _, test := range []struct { + desc string + want debugFrame + b []byte + }{{ + desc: "STREAM frame with LEN bit unset", + want: debugFrameStream{ + id: 1, + fin: false, + data: []byte{0x01, 0x02, 0x03}, + }, + b: []byte{ + 0x08, // Type (i) = 0x08..0x0f, + 0x01, // Stream ID (i), + // [Offset (i)], + // [Length (i)], + 0x01, 0x02, 0x03, // Stream Data (..), + }, + }, { + desc: "ACK frame with ECN counts", + want: debugFrameAck{ + ackDelay: (10 << ackDelayExponent) * time.Microsecond, + ranges: []i64range{ + {0, 1}, + }, + }, + b: []byte{ + 0x03, // TYPE (i) = 0x02..0x03 + 0x00, // Largest Acknowledged (i) + 10, // ACK Delay (i) + 0x00, // ACK Range Count (i) + 0x00, // First ACK Range (i) + 0x01, 0x02, 0x03, // [ECN Counts (..)], + }, + }} { + got, n := parseDebugFrame(test.b) + if n != len(test.b) || !reflect.DeepEqual(got, test.want) { + t.Errorf("decoding {%x}:\ndecoded %v bytes, want %v\ngot: %v\nwant: %v", test.b, n, len(test.b), got, test.want) + } + } +} + +func TestFrameDecodeErrors(t *testing.T) { + for _, test := range []struct { + name string + b []byte + }{{ + name: "ACK [-1,0]", + b: []byte{ + 0x02, // TYPE (i) = 0x02 + 0x00, // Largest Acknowledged (i) + 0x00, // ACK Delay (i) + 0x00, // ACK Range Count (i) + 0x01, // First ACK Range (i) + }, + }, { + name: "ACK [-1,16]", + b: []byte{ + 0x02, // TYPE (i) = 0x02 + 0x10, // Largest Acknowledged (i) + 0x00, // ACK Delay (i) + 0x00, // ACK Range Count (i) + 0x11, // First ACK Range (i) + }, + }, { + name: "ACK [-1,0],[1,2]", + b: []byte{ + 0x02, // TYPE (i) = 0x02 + 0x02, // Largest Acknowledged (i) + 0x00, // ACK Delay (i) + 0x01, // ACK Range Count (i) + 0x00, // First ACK Range (i) + 0x01, // Gap (i) + 0x01, // ACK Range Length (i) + }, + }, { + name: "NEW_TOKEN with zero-length token", + b: []byte{ + 0x07, // Type (i) = 0x07, + 0x00, // Token Length (i), + }, + }, { + name: "MAX_STREAMS with too many streams", + b: func() []byte { + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.11-5.2.1 + return appendVarint([]byte{frameTypeMaxStreamsBidi}, (1<<60)+1) + }(), + }, { + name: "NEW_CONNECTION_ID too small", + b: []byte{ + 0x18, // Type (i) = 0x18, + 0x03, // Sequence Number (i), + 0x02, // Retire Prior To (i), + 0x00, // Length (8), + // Connection ID (8..160), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // Stateless Reset Token (128), + }, + }, { + name: "NEW_CONNECTION_ID too large", + b: []byte{ + 0x18, // Type (i) = 0x18, + 0x03, // Sequence Number (i), + 0x02, // Retire Prior To (i), + 21, // Length (8), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, // Connection ID (8..160), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // Stateless Reset Token (128), + }, + }, { + name: "NEW_CONNECTION_ID sequence smaller than retire", + b: []byte{ + 0x18, // Type (i) = 0x18, + 0x02, // Sequence Number (i), + 0x03, // Retire Prior To (i), + 0x02, // Length (8), + 0xff, 0xff, // Connection ID (8..160), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // Stateless Reset Token (128), + }, + }} { + f, n := parseDebugFrame(test.b) + if n >= 0 { + t.Errorf("%v: no error when parsing invalid frame {%x}\ngot: %v", test.name, test.b, f) + } + } +} + +func FuzzParseLongHeaderPacket(f *testing.F) { + cid := unhex(`0000000000000000`) + _, initialServerKeys := initialKeys(cid) + f.Fuzz(func(t *testing.T, in []byte) { + parseLongHeaderPacket(in, initialServerKeys, 0) + }) +} + +func FuzzFrameDecode(f *testing.F) { + f.Fuzz(func(t *testing.T, in []byte) { + parseDebugFrame(in) + }) +} diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go new file mode 100644 index 000000000..d06b601d8 --- /dev/null +++ b/internal/quic/packet_parser.go @@ -0,0 +1,524 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "time" +) + +// parseLongHeaderPacket parses a QUIC long header packet. +// +// It does not parse Version Negotiation packets. +// +// On input, pkt contains a long header packet (possibly followed by more packets), +// k the decryption keys for the packet, and pnumMax the largest packet number seen +// in the number space of this packet. +// +// parseLongHeaderPacket returns the parsed packet with protection removed +// and its length in bytes. +// +// It returns an empty packet and -1 if the packet could not be parsed. +func parseLongHeaderPacket(pkt []byte, k keys, pnumMax packetNumber) (p longPacket, n int) { + if len(pkt) < 5 || !isLongHeader(pkt[0]) { + return longPacket{}, -1 + } + + // Header Form (1) = 1, + // Fixed Bit (1) = 1, + // Long Packet Type (2), + // Type-Specific Bits (4), + b := pkt + p.ptype = getPacketType(b) + if p.ptype == packetTypeInvalid { + return longPacket{}, -1 + } + b = b[1:] + // Version (32), + p.version, n = consumeUint32(b) + if n < 0 { + return longPacket{}, -1 + } + b = b[n:] + if p.version == 0 { + // Version Negotiation packet; not handled here. + return longPacket{}, -1 + } + + // Destination Connection ID Length (8), + // Destination Connection ID (0..160), + p.dstConnID, n = consumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > 20 { + return longPacket{}, -1 + } + b = b[n:] + + // Source Connection ID Length (8), + // Source Connection ID (0..160), + p.srcConnID, n = consumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > 20 { + return longPacket{}, -1 + } + b = b[n:] + + switch p.ptype { + case packetTypeInitial: + // Token Length (i), + // Token (..), + p.extra, n = consumeVarintBytes(b) + if n < 0 { + return longPacket{}, -1 + } + b = b[n:] + case packetTypeRetry: + // Retry Token (..), + // Retry Integrity Tag (128), + p.extra = b + return p, len(pkt) + } + + // Length (i), + payLen, n := consumeVarint(b) + if n < 0 { + return longPacket{}, -1 + } + b = b[n:] + if uint64(len(b)) < payLen { + return longPacket{}, -1 + } + + // Packet Number (8..32), + // Packet Payload (..), + pnumOff := len(pkt) - len(b) + pkt = pkt[:pnumOff+int(payLen)] + + if k.initialized() { + var err error + p.payload, p.num, err = k.unprotect(pkt, pnumOff, pnumMax) + if err != nil { + return longPacket{}, -1 + } + // Reserved bits should always be zero, but this is handled + // as a protocol-level violation by the caller rather than a parse error. + p.reservedBits = pkt[0] & reservedBits + } + return p, len(pkt) +} + +// skipLongHeaderPacket returns the length of the long header packet at the start of pkt, +// or -1 if the buffer does not contain a valid packet. +func skipLongHeaderPacket(pkt []byte) int { + // Header byte, 4 bytes of version. + n := 5 + if len(pkt) <= n { + return -1 + } + // Destination connection ID length, destination connection ID. + n += 1 + int(pkt[n]) + if len(pkt) <= n { + return -1 + } + // Source connection ID length, source connection ID. + n += 1 + int(pkt[n]) + if len(pkt) <= n { + return -1 + } + if getPacketType(pkt) == packetTypeInitial { + // Token length, token. + _, nn := consumeVarintBytes(pkt[n:]) + if nn < 0 { + return -1 + } + n += nn + } + // Length, packet number, payload. + _, nn := consumeVarintBytes(pkt[n:]) + if nn < 0 { + return -1 + } + n += nn + if len(pkt) < n { + return -1 + } + return n +} + +// parse1RTTPacket parses a QUIC 1-RTT (short header) packet. +// +// On input, pkt contains a short header packet, k the decryption keys for the packet, +// and pnumMax the largest packet number seen in the number space of this packet. +func parse1RTTPacket(pkt []byte, k keys, pnumMax packetNumber) (p shortPacket, n int) { + var err error + p.payload, p.num, err = k.unprotect(pkt, 1+connIDLen, pnumMax) + if err != nil { + return shortPacket{}, -1 + } + // Reserved bits should always be zero, but this is handled + // as a protocol-level violation by the caller rather than a parse error. + p.reservedBits = pkt[0] & reservedBits + return p, len(pkt) +} + +// Consume functions return n=-1 on conditions which result in FRAME_ENCODING_ERROR, +// which includes both general parse failures and specific violations of frame +// constraints. + +func consumeAckFrame(frame []byte, ackDelayExponent uint8, f func(start, end packetNumber)) (largest packetNumber, ackDelay time.Duration, n int) { + b := frame[1:] // type + + largestAck, n := consumeVarint(b) + if n < 0 { + return 0, 0, -1 + } + b = b[n:] + + ackDelayScaled, n := consumeVarint(b) + if n < 0 { + return 0, 0, -1 + } + b = b[n:] + ackDelay = time.Duration(ackDelayScaled*(1< rangeMax { + return 0, 0, -1 + } + f(rangeMin, rangeMax+1) + + if i == ackRangeCount { + break + } + + gap, n := consumeVarint(b) + if n < 0 { + return 0, 0, -1 + } + b = b[n:] + + rangeMax = rangeMin - packetNumber(gap) - 2 + } + + if frame[0] != frameTypeAckECN { + return packetNumber(largestAck), ackDelay, len(frame) - len(b) + } + + ect0Count, n := consumeVarint(b) + if n < 0 { + return 0, 0, -1 + } + b = b[n:] + ect1Count, n := consumeVarint(b) + if n < 0 { + return 0, 0, -1 + } + b = b[n:] + ecnCECount, n := consumeVarint(b) + if n < 0 { + return 0, 0, -1 + } + b = b[n:] + + // TODO: Make use of ECN feedback. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.3.2 + _ = ect0Count + _ = ect1Count + _ = ecnCECount + + return packetNumber(largestAck), ackDelay, len(frame) - len(b) +} + +func consumeResetStreamFrame(b []byte) (id streamID, code uint64, finalSize int64, n int) { + n = 1 + idInt, nn := consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, 0, -1 + } + n += nn + code, nn = consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, 0, -1 + } + n += nn + v, nn := consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, 0, -1 + } + n += nn + finalSize = int64(v) + return streamID(idInt), code, finalSize, n +} + +func consumeStopSendingFrame(b []byte) (id streamID, code uint64, n int) { + n = 1 + idInt, nn := consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + code, nn = consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + return streamID(idInt), code, n +} + +func consumeCryptoFrame(b []byte) (off int64, data []byte, n int) { + n = 1 + v, nn := consumeVarint(b[n:]) + if nn < 0 { + return 0, nil, -1 + } + off = int64(v) + n += nn + data, nn = consumeVarintBytes(b[n:]) + if nn < 0 { + return 0, nil, -1 + } + n += nn + return off, data, n +} + +func consumeNewTokenFrame(b []byte) (token []byte, n int) { + n = 1 + data, nn := consumeVarintBytes(b[n:]) + if nn < 0 { + return nil, -1 + } + if len(data) == 0 { + return nil, -1 + } + n += nn + return data, n +} + +func consumeStreamFrame(b []byte) (id streamID, off int64, fin bool, data []byte, n int) { + fin = (b[0] & 0x01) != 0 + n = 1 + idInt, nn := consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, false, nil, -1 + } + n += nn + if b[0]&0x04 != 0 { + v, nn := consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, false, nil, -1 + } + n += nn + off = int64(v) + } + if b[0]&0x02 != 0 { + data, nn = consumeVarintBytes(b[n:]) + if nn < 0 { + return 0, 0, false, nil, -1 + } + n += nn + } else { + data = b[n:] + n += len(data) + } + return streamID(idInt), off, fin, data, n +} + +func consumeMaxDataFrame(b []byte) (max int64, n int) { + n = 1 + v, nn := consumeVarint(b[n:]) + if nn < 0 { + return 0, -1 + } + n += nn + return int64(v), n +} + +func consumeMaxStreamDataFrame(b []byte) (id streamID, max int64, n int) { + n = 1 + v, nn := consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + id = streamID(v) + v, nn = consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + max = int64(v) + return id, max, n +} + +func consumeMaxStreamsFrame(b []byte) (typ streamType, max int64, n int) { + switch b[0] { + case frameTypeMaxStreamsBidi: + typ = bidiStream + case frameTypeMaxStreamsUni: + typ = uniStream + default: + return 0, 0, -1 + } + n = 1 + v, nn := consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + if v > 1<<60 { + return 0, 0, -1 + } + return typ, int64(v), n +} + +func consumeStreamDataBlockedFrame(b []byte) (id streamID, max int64, n int) { + n = 1 + v, nn := consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + id = streamID(v) + max, nn = consumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + return id, max, n +} + +func consumeDataBlockedFrame(b []byte) (max int64, n int) { + n = 1 + max, nn := consumeVarintInt64(b[n:]) + if nn < 0 { + return 0, -1 + } + n += nn + return max, n +} + +func consumeStreamsBlockedFrame(b []byte) (typ streamType, max int64, n int) { + if b[0] == frameTypeStreamsBlockedBidi { + typ = bidiStream + } else { + typ = uniStream + } + n = 1 + max, nn := consumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, -1 + } + n += nn + return typ, max, n +} + +func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, resetToken [16]byte, n int) { + n = 1 + var nn int + seq, nn = consumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, nil, [16]byte{}, -1 + } + n += nn + retire, nn = consumeVarintInt64(b[n:]) + if nn < 0 { + return 0, 0, nil, [16]byte{}, -1 + } + n += nn + if seq < retire { + return 0, 0, nil, [16]byte{}, -1 + } + connID, nn = consumeVarintBytes(b[n:]) + if nn < 0 { + return 0, 0, nil, [16]byte{}, -1 + } + if len(connID) < 1 || len(connID) > 20 { + return 0, 0, nil, [16]byte{}, -1 + } + n += nn + if len(b[n:]) < len(resetToken) { + return 0, 0, nil, [16]byte{}, -1 + } + copy(resetToken[:], b[n:]) + n += len(resetToken) + return seq, retire, connID, resetToken, n +} + +func consumeRetireConnectionIDFrame(b []byte) (seq uint64, n int) { + n = 1 + var nn int + seq, nn = consumeVarint(b[n:]) + if nn < 0 { + return 0, -1 + } + n += nn + return seq, n +} + +func consumePathChallengeFrame(b []byte) (data uint64, n int) { + n = 1 + var nn int + data, nn = consumeUint64(b[n:]) + if nn < 0 { + return 0, -1 + } + n += nn + return data, n +} + +func consumePathResponseFrame(b []byte) (data uint64, n int) { + return consumePathChallengeFrame(b) // identical frame format +} + +func consumeConnectionCloseTransportFrame(b []byte) (code transportError, frameType uint64, reason string, n int) { + n = 1 + var nn int + var codeInt uint64 + codeInt, nn = consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, "", -1 + } + code = transportError(codeInt) + n += nn + frameType, nn = consumeVarint(b[n:]) + if nn < 0 { + return 0, 0, "", -1 + } + n += nn + reasonb, nn := consumeVarintBytes(b[n:]) + if nn < 0 { + return 0, 0, "", -1 + } + n += nn + reason = string(reasonb) + return code, frameType, reason, n +} + +func consumeConnectionCloseApplicationFrame(b []byte) (code uint64, reason string, n int) { + n = 1 + var nn int + code, nn = consumeVarint(b[n:]) + if nn < 0 { + return 0, "", -1 + } + n += nn + reasonb, nn := consumeVarintBytes(b[n:]) + if nn < 0 { + return 0, "", -1 + } + n += nn + reason = string(reasonb) + return code, reason, n +} diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go new file mode 100644 index 000000000..1f9e30f6e --- /dev/null +++ b/internal/quic/packet_writer.go @@ -0,0 +1,548 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import ( + "encoding/binary" + "time" +) + +// A packetWriter constructs QUIC datagrams. +// +// A datagram consists of one or more packets. +// A packet consists of a header followed by one or more frames. +// +// Packets are written in three steps: +// - startProtectedLongHeaderPacket or start1RTT packet prepare the packet; +// - append*Frame appends frames to the payload; and +// - finishProtectedLongHeaderPacket or finish1RTT finalize the packet. +// +// The start functions are efficient, so we can start speculatively +// writing a packet before we know whether we have any frames to +// put in it. The finish functions will abandon the packet if the +// payload contains no data. +type packetWriter struct { + dgramLim int // max datagram size + pktLim int // max packet size + pktOff int // offset of the start of the current packet + payOff int // offset of the payload of the current packet + b []byte + sent *sentPacket +} + +// reset prepares to write a datagram of at most lim bytes. +func (w *packetWriter) reset(lim int) { + if cap(w.b) < lim { + w.b = make([]byte, 0, lim) + } + w.dgramLim = lim + w.b = w.b[:0] +} + +// datagram returns the current datagram. +func (w *packetWriter) datagram() []byte { + return w.b +} + +// payload returns the payload of the current packet. +func (w *packetWriter) payload() []byte { + return w.b[w.payOff:] +} + +func (w *packetWriter) abandonPacket() { + w.b = w.b[:w.payOff] + w.sent.reset() +} + +// startProtectedLongHeaderPacket starts writing an Initial, 0-RTT, or Handshake packet. +func (w *packetWriter) startProtectedLongHeaderPacket(pnumMaxAcked packetNumber, p longPacket) { + if w.sent == nil { + w.sent = newSentPacket() + } + w.pktOff = len(w.b) + hdrSize := 1 // packet type + hdrSize += 4 // version + hdrSize += 1 + len(p.dstConnID) + hdrSize += 1 + len(p.srcConnID) + switch p.ptype { + case packetTypeInitial: + hdrSize += sizeVarint(uint64(len(p.extra))) + len(p.extra) + } + hdrSize += 2 // length, hardcoded to a 2-byte varint + pnumOff := len(w.b) + hdrSize + hdrSize += packetNumberLength(p.num, pnumMaxAcked) + payOff := len(w.b) + hdrSize + // Check if we have enough space to hold the packet, including the header, + // header protection sample (RFC 9001, section 5.4.2), and encryption overhead. + if pnumOff+4+headerProtectionSampleSize+aeadOverhead >= w.dgramLim { + // Set the limit on the packet size to be the current write buffer length, + // ensuring that any writes to the payload fail. + w.payOff = len(w.b) + w.pktLim = len(w.b) + return + } + w.payOff = payOff + w.pktLim = w.dgramLim - aeadOverhead + // We hardcode the payload length field to be 2 bytes, which limits the payload + // (including the packet number) to 16383 bytes (the largest 2-byte QUIC varint). + // + // Most networks don't support datagrams over 1472 bytes, and even Ethernet + // jumbo frames are generally only about 9000 bytes. + if lim := pnumOff + 16383 - aeadOverhead; lim < w.pktLim { + w.pktLim = lim + } + w.b = w.b[:payOff] +} + +// finishProtectedLongHeaderPacket finishes writing an Initial, 0-RTT, or Handshake packet, +// canceling the packet if it contains no payload. +// It returns a sentPacket describing the packet, or nil if no packet was written. +func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber, k keys, p longPacket) *sentPacket { + if len(w.b) == w.payOff { + // The payload is empty, so just abandon the packet. + w.b = w.b[:w.pktOff] + return nil + } + pnumLen := packetNumberLength(p.num, pnumMaxAcked) + plen := w.padPacketLength(pnumLen) + hdr := w.b[:w.pktOff] + var typeBits byte + switch p.ptype { + case packetTypeInitial: + typeBits = longPacketTypeInitial + case packetType0RTT: + typeBits = longPacketType0RTT + case packetTypeHandshake: + typeBits = longPacketTypeHandshake + case packetTypeRetry: + typeBits = longPacketTypeRetry + } + hdr = append(hdr, headerFormLong|fixedBit|typeBits|byte(pnumLen-1)) + hdr = binary.BigEndian.AppendUint32(hdr, p.version) + hdr = appendUint8Bytes(hdr, p.dstConnID) + hdr = appendUint8Bytes(hdr, p.srcConnID) + switch p.ptype { + case packetTypeInitial: + hdr = appendVarintBytes(hdr, p.extra) // token + } + + // Packet length, always encoded as a 2-byte varint. + hdr = append(hdr, 0x40|byte(plen>>8), byte(plen)) + + pnumOff := len(hdr) + hdr = appendPacketNumber(hdr, p.num, pnumMaxAcked) + + return w.protect(hdr[w.pktOff:], p.num, pnumOff, k) +} + +// start1RTTPacket starts writing a 1-RTT (short header) packet. +func (w *packetWriter) start1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte) { + if w.sent == nil { + w.sent = newSentPacket() + } + w.pktOff = len(w.b) + hdrSize := 1 // packet type + hdrSize += len(dstConnID) + // Ensure we have enough space to hold the packet, including the header, + // header protection sample (RFC 9001, section 5.4.2), and encryption overhead. + if len(w.b)+hdrSize+4+headerProtectionSampleSize+aeadOverhead >= w.dgramLim { + w.payOff = len(w.b) + w.pktLim = len(w.b) + return + } + hdrSize += packetNumberLength(pnum, pnumMaxAcked) + w.payOff = len(w.b) + hdrSize + w.pktLim = w.dgramLim - aeadOverhead + w.b = w.b[:w.payOff] +} + +// finish1RTTPacket finishes writing a 1-RTT packet, +// canceling the packet if it contains no payload. +// It returns a sentPacket describing the packet, or nil if no packet was written. +func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k keys) *sentPacket { + if len(w.b) == w.payOff { + // The payload is empty, so just abandon the packet. + w.b = w.b[:w.pktOff] + return nil + } + // TODO: Spin + // TODO: Key phase + pnumLen := packetNumberLength(pnum, pnumMaxAcked) + hdr := w.b[:w.pktOff] + hdr = append(hdr, 0x40|byte(pnumLen-1)) + hdr = append(hdr, dstConnID...) + pnumOff := len(hdr) + hdr = appendPacketNumber(hdr, pnum, pnumMaxAcked) + w.padPacketLength(pnumLen) + return w.protect(hdr[w.pktOff:], pnum, pnumOff, k) +} + +// padPacketLength pads out the payload of the current packet to the minimum size, +// and returns the combined length of the packet number and payload (used for the Length +// field of long header packets). +func (w *packetWriter) padPacketLength(pnumLen int) int { + plen := len(w.b) - w.payOff + pnumLen + aeadOverhead + // "To ensure that sufficient data is available for sampling, packets are + // padded so that the combined lengths of the encoded packet number and + // protected payload is at least 4 bytes longer than the sample required + // for header protection." + // https://www.rfc-editor.org/rfc/rfc9001.html#section-5.4.2 + for plen < 4+headerProtectionSampleSize { + w.b = append(w.b, 0) + plen++ + } + return plen +} + +// protect applies packet protection and finishes the current packet. +func (w *packetWriter) protect(hdr []byte, pnum packetNumber, pnumOff int, k keys) *sentPacket { + k.protect(hdr, w.b[w.pktOff+len(hdr):], pnumOff-w.pktOff, pnum) + w.b = w.b[:len(w.b)+aeadOverhead] + w.sent.size = len(w.b) - w.pktOff + w.sent.num = pnum + sent := w.sent + w.sent = nil + return sent +} + +// avail reports how many more bytes may be written to the current packet. +func (w *packetWriter) avail() int { + return w.pktLim - len(w.b) +} + +// appendPaddingTo appends PADDING frames until the total datagram size +// (including AEAD overhead of the current packet) is n. +func (w *packetWriter) appendPaddingTo(n int) { + n -= aeadOverhead + lim := w.pktLim + if n < lim { + lim = n + } + if len(w.b) >= lim { + return + } + for len(w.b) < lim { + w.b = append(w.b, frameTypePadding) + } + // Packets are considered in flight when they contain a PADDING frame. + // https://www.rfc-editor.org/rfc/rfc9002.html#section-2-3.6.1 + w.sent.inFlight = true +} + +func (w *packetWriter) appendPingFrame() (added bool) { + if len(w.b) >= w.pktLim { + return false + } + w.b = append(w.b, frameTypePing) + w.sent.appendAckElicitingFrame(frameTypePing) + return true +} + +// appendAckFrame appends an ACK frame to the payload. +// It includes at least the most recent range in the rangeset +// (the range with the largest packet numbers), +// followed by as many additional ranges as fit within the packet. +// +// We always place ACK frames at the start of packets, +// we limit the number of ack ranges retained, and +// we set a minimum packet payload size. +// As a result, appendAckFrame will rarely if ever drop ranges +// in practice. +// +// In the event that ranges are dropped, the impact is limited +// to the peer potentially failing to receive an acknowledgement +// for an older packet during a period of high packet loss or +// reordering. This may result in unnecessary retransmissions. +func (w *packetWriter) appendAckFrame(seen rangeset, ackDelayExponent uint8, delay time.Duration) (added bool) { + if len(seen) == 0 { + return false + } + var ( + largest = uint64(seen.max()) + mdelay = uint64(delay.Microseconds() / (1 << ackDelayExponent)) + firstRange = uint64(seen[len(seen)-1].size() - 1) + ) + if w.avail() < 1+sizeVarint(largest)+sizeVarint(mdelay)+1+sizeVarint(firstRange) { + return false + } + w.b = append(w.b, frameTypeAck) + w.b = appendVarint(w.b, largest) + w.b = appendVarint(w.b, mdelay) + // The range count is technically a varint, but we'll reserve a single byte for it + // and never add more than 62 ranges (the maximum varint that fits in a byte). + rangeCountOff := len(w.b) + w.b = append(w.b, 0) + w.b = appendVarint(w.b, firstRange) + rangeCount := byte(0) + for i := len(seen) - 2; i >= 0; i-- { + gap := uint64(seen[i+1].start - seen[i].end - 1) + size := uint64(seen[i].size() - 1) + if w.avail() < sizeVarint(gap)+sizeVarint(size) || rangeCount > 62 { + break + } + w.b = appendVarint(w.b, gap) + w.b = appendVarint(w.b, size) + rangeCount++ + } + w.b[rangeCountOff] = rangeCount + w.sent.appendNonAckElicitingFrame(frameTypeAck) + w.sent.appendInt(uint64(seen.max())) + return true +} + +func (w *packetWriter) appendNewTokenFrame(token []byte) (added bool) { + if w.avail() < 1+sizeVarint(uint64(len(token)))+len(token) { + return false + } + w.b = append(w.b, frameTypeNewToken) + w.b = appendVarintBytes(w.b, token) + return true +} + +func (w *packetWriter) appendResetStreamFrame(id streamID, code uint64, finalSize int64) (added bool) { + if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(code)+sizeVarint(uint64(finalSize)) { + return false + } + w.b = append(w.b, frameTypeResetStream) + w.b = appendVarint(w.b, uint64(id)) + w.b = appendVarint(w.b, code) + w.b = appendVarint(w.b, uint64(finalSize)) + w.sent.appendAckElicitingFrame(frameTypeResetStream) + w.sent.appendInt(uint64(id)) + return true +} + +func (w *packetWriter) appendStopSendingFrame(id streamID, code uint64) (added bool) { + if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(code) { + return false + } + w.b = append(w.b, frameTypeStopSending) + w.b = appendVarint(w.b, uint64(id)) + w.b = appendVarint(w.b, code) + w.sent.appendAckElicitingFrame(frameTypeStopSending) + w.sent.appendInt(uint64(id)) + return true +} + +// appendCryptoFrame appends a CRYPTO frame. +// It returns a []byte into which the data should be written and whether a frame was added. +// The returned []byte may be smaller than size if the packet cannot hold all the data. +func (w *packetWriter) appendCryptoFrame(off int64, size int) (_ []byte, added bool) { + max := w.avail() + max -= 1 // frame type + max -= sizeVarint(uint64(off)) // offset + max -= sizeVarint(uint64(size)) // maximum length + if max <= 0 { + return nil, false + } + if max < size { + size = max + } + w.b = append(w.b, frameTypeCrypto) + w.b = appendVarint(w.b, uint64(off)) + w.b = appendVarint(w.b, uint64(size)) + start := len(w.b) + w.b = w.b[:start+size] + w.sent.appendAckElicitingFrame(frameTypeCrypto) + w.sent.appendOffAndSize(off, size) + return w.b[start:][:size], true +} + +// appendStreamFrame appends a STREAM frame. +// It returns a []byte into which the data should be written and whether a frame was added. +// The returned []byte may be smaller than size if the packet cannot hold all the data. +func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin bool) (_ []byte, added bool) { + typ := uint8(frameTypeStreamBase | streamLenBit) + max := w.avail() + max -= 1 // frame type + max -= sizeVarint(uint64(id)) + if off != 0 { + max -= sizeVarint(uint64(off)) + typ |= streamOffBit + } + max -= sizeVarint(uint64(size)) // maximum length + if max < 0 || (max == 0 && size > 0) { + return nil, false + } + if max < size { + size = max + } else if fin { + typ |= streamFinBit + } + w.b = append(w.b, typ) + w.b = appendVarint(w.b, uint64(id)) + if off != 0 { + w.b = appendVarint(w.b, uint64(off)) + } + w.b = appendVarint(w.b, uint64(size)) + start := len(w.b) + w.b = w.b[:start+size] + if fin { + w.sent.appendAckElicitingFrame(frameTypeStreamBase | streamFinBit) + } else { + w.sent.appendAckElicitingFrame(frameTypeStreamBase) + } + w.sent.appendInt(uint64(id)) + w.sent.appendOffAndSize(off, size) + return w.b[start:][:size], true +} + +func (w *packetWriter) appendMaxDataFrame(max int64) (added bool) { + if w.avail() < 1+sizeVarint(uint64(max)) { + return false + } + w.b = append(w.b, frameTypeMaxData) + w.b = appendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(frameTypeMaxData) + return true +} + +func (w *packetWriter) appendMaxStreamDataFrame(id streamID, max int64) (added bool) { + if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(uint64(max)) { + return false + } + w.b = append(w.b, frameTypeMaxStreamData) + w.b = appendVarint(w.b, uint64(id)) + w.b = appendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(frameTypeMaxStreamData) + w.sent.appendInt(uint64(id)) + return true +} + +func (w *packetWriter) appendMaxStreamsFrame(streamType streamType, max int64) (added bool) { + if w.avail() < 1+sizeVarint(uint64(max)) { + return false + } + var typ byte + if streamType == bidiStream { + typ = frameTypeMaxStreamsBidi + } else { + typ = frameTypeMaxStreamsUni + } + w.b = append(w.b, typ) + w.b = appendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(typ) + return true +} + +func (w *packetWriter) appendDataBlockedFrame(max int64) (added bool) { + if w.avail() < 1+sizeVarint(uint64(max)) { + return false + } + w.b = append(w.b, frameTypeDataBlocked) + w.b = appendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(frameTypeDataBlocked) + return true +} + +func (w *packetWriter) appendStreamDataBlockedFrame(id streamID, max int64) (added bool) { + if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(uint64(max)) { + return false + } + w.b = append(w.b, frameTypeStreamDataBlocked) + w.b = appendVarint(w.b, uint64(id)) + w.b = appendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(frameTypeStreamDataBlocked) + w.sent.appendInt(uint64(id)) + return true +} + +func (w *packetWriter) appendStreamsBlockedFrame(typ streamType, max int64) (added bool) { + if w.avail() < 1+sizeVarint(uint64(max)) { + return false + } + var ftype byte + if typ == bidiStream { + ftype = frameTypeStreamsBlockedBidi + } else { + ftype = frameTypeStreamsBlockedUni + } + w.b = append(w.b, ftype) + w.b = appendVarint(w.b, uint64(max)) + w.sent.appendAckElicitingFrame(ftype) + return true +} + +func (w *packetWriter) appendNewConnectionIDFrame(seq, retirePriorTo int64, connID []byte, token [16]byte) (added bool) { + if w.avail() < 1+sizeVarint(uint64(seq))+sizeVarint(uint64(retirePriorTo))+1+len(connID)+len(token) { + return false + } + w.b = append(w.b, frameTypeNewConnectionID) + w.b = appendVarint(w.b, uint64(seq)) + w.b = appendVarint(w.b, uint64(retirePriorTo)) + w.b = appendUint8Bytes(w.b, connID) + w.b = append(w.b, token[:]...) + w.sent.appendAckElicitingFrame(frameTypeNewConnectionID) + w.sent.appendInt(uint64(seq)) + return true +} + +func (w *packetWriter) appendRetireConnectionIDFrame(seq uint64) (added bool) { + if w.avail() < 1+sizeVarint(seq) { + return false + } + w.b = append(w.b, frameTypeRetireConnectionID) + w.b = appendVarint(w.b, seq) + w.sent.appendAckElicitingFrame(frameTypeRetireConnectionID) + return true +} + +func (w *packetWriter) appendPathChallengeFrame(data uint64) (added bool) { + if w.avail() < 1+8 { + return false + } + w.b = append(w.b, frameTypePathChallenge) + w.b = binary.BigEndian.AppendUint64(w.b, data) + w.sent.appendAckElicitingFrame(frameTypePathChallenge) + return true +} + +func (w *packetWriter) appendPathResponseFrame(data uint64) (added bool) { + if w.avail() < 1+8 { + return false + } + w.b = append(w.b, frameTypePathResponse) + w.b = binary.BigEndian.AppendUint64(w.b, data) + w.sent.appendAckElicitingFrame(frameTypePathResponse) + return true +} + +// appendConnectionCloseTransportFrame appends a CONNECTION_CLOSE frame +// carrying a transport error code. +func (w *packetWriter) appendConnectionCloseTransportFrame(code transportError, frameType uint64, reason string) (added bool) { + if w.avail() < 1+sizeVarint(uint64(code))+sizeVarint(frameType)+sizeVarint(uint64(len(reason)))+len(reason) { + return false + } + w.b = append(w.b, frameTypeConnectionCloseTransport) + w.b = appendVarint(w.b, uint64(code)) + w.b = appendVarint(w.b, frameType) + w.b = appendVarintBytes(w.b, []byte(reason)) + // We don't record CONNECTION_CLOSE frames in w.sent, since they are never acked or + // detected as lost. + return true +} + +// appendConnectionCloseTransportFrame appends a CONNECTION_CLOSE frame +// carrying an application protocol error code. +func (w *packetWriter) appendConnectionCloseApplicationFrame(code uint64, reason string) (added bool) { + if w.avail() < 1+sizeVarint(code)+sizeVarint(uint64(len(reason)))+len(reason) { + return false + } + w.b = append(w.b, frameTypeConnectionCloseApplication) + w.b = appendVarint(w.b, code) + w.b = appendVarintBytes(w.b, []byte(reason)) + // We don't record CONNECTION_CLOSE frames in w.sent, since they are never acked or + // detected as lost. + return true +} + +func (w *packetWriter) appendHandshakeDoneFrame() (added bool) { + if w.avail() < 1 { + return false + } + w.b = append(w.b, frameTypeHandshakeDone) + w.sent.appendAckElicitingFrame(frameTypeHandshakeDone) + return true +} From f7250ea19d213699fb52bf6b0c0d2cb9cda49d7c Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 23 Dec 2022 08:37:30 -0800 Subject: [PATCH 20/28] quic: add a type tracking sent values Any given datum communicated to the peer follows a state machine: - We do not need to send the this datum. - We need to send it, but have not done so. - We have sent it, but the peer has not acknowledged it. - We have sent it and the peer has acknowledged it. Data transitions between states in a consistent fashion; for example, loss of the most recent packet containing a HANDSHAKE_DONE frame means we should resend the frame in a new packet. Add a sentVal type which tracks this state machine. For golang/go#58547 Change-Id: I9de0ef5e482534b8733ef66363bac8f6c0fd3173 Reviewed-on: https://go-review.googlesource.com/c/net/+/498295 Run-TryBot: Damien Neil Reviewed-by: Jonathan Amsterdam Auto-Submit: Damien Neil TryBot-Result: Gopher Robot --- internal/quic/sent_val.go | 103 ++++++++++++++++++++ internal/quic/sent_val_test.go | 166 +++++++++++++++++++++++++++++++++ 2 files changed, 269 insertions(+) create mode 100644 internal/quic/sent_val.go create mode 100644 internal/quic/sent_val_test.go diff --git a/internal/quic/sent_val.go b/internal/quic/sent_val.go new file mode 100644 index 000000000..f1a9c9fbc --- /dev/null +++ b/internal/quic/sent_val.go @@ -0,0 +1,103 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +// A sentVal tracks sending some piece of information to the peer. +// It tracks whether the information has been sent, acked, and +// (when in-flight) the most recent packet to carry it. +// +// For example, a sentVal can track sending of a RESET_STREAM frame. +// +// - unset: stream is active, no need to send RESET_STREAM +// - unsent: we should send a RESET_STREAM, but have not yet +// - sent: we have sent a RESET_STREAM, but have not received an ack +// - received: we have sent a RESET_STREAM, and the peer has acked the packet that contained it +// +// In the "sent" state, a sentVal also tracks the latest packet number to carry +// the information. (QUIC packet numbers are always at most 62 bits in size, +// so the sentVal keeps the number in the low 62 bits and the state in the high 2 bits.) +type sentVal uint64 + +const ( + sentValUnset = 0 // unset + sentValUnsent = 1 << 62 // set, not sent to the peer + sentValSent = 2 << 62 // set, sent to the peer but not yet acked; pnum is set + sentValReceived = 3 << 62 // set, peer acked receipt + + sentValStateMask = 3 << 62 +) + +// isSet reports whether the value is set. +func (s sentVal) isSet() bool { return s != 0 } + +// shouldSend reports whether the value is set and has not been sent to the peer. +func (s sentVal) shouldSend() bool { return s.state() == sentValUnsent } + +// shouldSend reports whether the the value needs to be sent to the peer. +// The value needs to be sent if it is set and has not been sent. +// If pto is true, indicating that we are sending a PTO probe, the value +// should also be sent if it is set and has not been acknowledged. +func (s sentVal) shouldSendPTO(pto bool) bool { + st := s.state() + return st == sentValUnsent || (pto && st == sentValSent) +} + +// isReceived reports whether the value has been received by the peer. +func (s sentVal) isReceived() bool { return s == sentValReceived } + +// set sets the value and records that it should be sent to the peer. +// If the value has already been sent, it is not resent. +func (s *sentVal) set() { + if *s == 0 { + *s = sentValUnsent + } +} + +// reset sets the value to the unsent state. +func (s *sentVal) setUnsent() { *s = sentValUnsent } + +// clear sets the value to the unset state. +func (s *sentVal) clear() { *s = sentValUnset } + +// setSent sets the value to the send state and records the number of the most recent +// packet containing the value. +func (s *sentVal) setSent(pnum packetNumber) { + *s = sentValSent | sentVal(pnum) +} + +// setReceived sets the value to the received state. +func (s *sentVal) setReceived() { *s = sentValReceived } + +// ackOrLoss reports that an acknowledgement has been received for the value, +// or (if acked is false) that the packet carrying the value has been lost. +func (s *sentVal) ackOrLoss(pnum packetNumber, acked bool) { + if acked { + *s = sentValReceived + } else if *s == sentVal(pnum)|sentValSent { + *s = sentValUnsent + } +} + +// ackLatestOrLoss reports that an acknowledgement has been received for the value, +// or (if acked is false) that the packet carrying the value has been lost. +// The value is set to the acked state only if pnum is the latest packet containing it. +// +// We use this to handle acks for data that varies every time it is sent. +// For example, if we send a MAX_DATA frame followed by an updated MAX_DATA value in a +// second packet, we consider the data sent only upon receiving an ack for the most +// recent value. +func (s *sentVal) ackLatestOrLoss(pnum packetNumber, acked bool) { + if acked { + if *s == sentVal(pnum)|sentValSent { + *s = sentValReceived + } + } else { + if *s == sentVal(pnum)|sentValSent { + *s = sentValUnsent + } + } +} + +func (s sentVal) state() uint64 { return uint64(s) & sentValStateMask } diff --git a/internal/quic/sent_val_test.go b/internal/quic/sent_val_test.go new file mode 100644 index 000000000..458b221c2 --- /dev/null +++ b/internal/quic/sent_val_test.go @@ -0,0 +1,166 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package quic + +import "testing" + +func TestSentVal(t *testing.T) { + for _, test := range []struct { + name string + f func(*sentVal) + wantIsSet bool + wantShouldSend bool + wantIsReceived bool + wantShouldSendPTO bool + }{{ + name: "zero value", + f: func(*sentVal) {}, + wantIsSet: false, + wantShouldSend: false, + wantShouldSendPTO: false, + wantIsReceived: false, + }, { + name: "v.set()", + f: (*sentVal).set, + wantIsSet: true, + wantShouldSend: true, + wantShouldSendPTO: true, + wantIsReceived: false, + }, { + name: "v.setSent(0)", + f: func(v *sentVal) { + v.setSent(0) + }, + wantIsSet: true, + wantShouldSend: false, + wantShouldSendPTO: true, + wantIsReceived: false, + }, { + name: "sent.set()", + f: func(v *sentVal) { + v.setSent(0) + v.set() + }, + wantIsSet: true, + wantShouldSend: false, + wantShouldSendPTO: true, + wantIsReceived: false, + }, { + name: "sent.setUnsent()", + f: func(v *sentVal) { + v.setSent(0) + v.setUnsent() + }, + wantIsSet: true, + wantShouldSend: true, + wantShouldSendPTO: true, + wantIsReceived: false, + }, { + name: "set.clear()", + f: func(v *sentVal) { + v.set() + v.clear() + }, + wantIsSet: false, + wantShouldSend: false, + wantShouldSendPTO: false, + wantIsReceived: false, + }, { + name: "v.setReceived()", + f: (*sentVal).setReceived, + wantIsSet: true, + wantShouldSend: false, + wantShouldSendPTO: false, + wantIsReceived: true, + }, { + name: "v.ackOrLoss(!pnum, true)", + f: func(v *sentVal) { + v.setSent(1) + v.ackOrLoss(0, true) // ack different packet containing the val + }, + wantIsSet: true, + wantShouldSend: false, + wantShouldSendPTO: false, + wantIsReceived: true, + }, { + name: "v.ackOrLoss(!pnum, false)", + f: func(v *sentVal) { + v.setSent(1) + v.ackOrLoss(0, false) // lose different packet containing the val + }, + wantIsSet: true, + wantShouldSend: false, + wantShouldSendPTO: true, + wantIsReceived: false, + }, { + name: "v.ackOrLoss(pnum, false)", + f: func(v *sentVal) { + v.setSent(1) + v.ackOrLoss(1, false) // lose same packet containing the val + }, + wantIsSet: true, + wantShouldSend: true, + wantShouldSendPTO: true, + wantIsReceived: false, + }, { + name: "v.ackLatestOrLoss(!pnum, true)", + f: func(v *sentVal) { + v.setSent(1) + v.ackLatestOrLoss(0, true) // ack different packet containing the val + }, + wantIsSet: true, + wantShouldSend: false, + wantShouldSendPTO: true, + wantIsReceived: false, + }, { + name: "v.ackLatestOrLoss(pnum, true)", + f: func(v *sentVal) { + v.setSent(1) + v.ackLatestOrLoss(1, true) // ack same packet containing the val + }, + wantIsSet: true, + wantShouldSend: false, + wantShouldSendPTO: false, + wantIsReceived: true, + }, { + name: "v.ackLatestOrLoss(!pnum, false)", + f: func(v *sentVal) { + v.setSent(1) + v.ackLatestOrLoss(0, false) // lose different packet containing the val + }, + wantIsSet: true, + wantShouldSend: false, + wantShouldSendPTO: true, + wantIsReceived: false, + }, { + name: "v.ackLatestOrLoss(pnum, false)", + f: func(v *sentVal) { + v.setSent(1) + v.ackLatestOrLoss(1, false) // lose same packet containing the val + }, + wantIsSet: true, + wantShouldSend: true, + wantShouldSendPTO: true, + wantIsReceived: false, + }} { + var v sentVal + test.f(&v) + if got, want := v.isSet(), test.wantIsSet; got != want { + t.Errorf("%v: v.isSet() = %v, want %v", test.name, got, want) + } + if got, want := v.shouldSend(), test.wantShouldSend; got != want { + t.Errorf("%v: v.shouldSend() = %v, want %v", test.name, got, want) + } + if got, want := v.shouldSendPTO(false), test.wantShouldSend; got != want { + t.Errorf("%v: v.shouldSendPTO(false) = %v, want %v", test.name, got, want) + } + if got, want := v.shouldSendPTO(true), test.wantShouldSendPTO; got != want { + t.Errorf("%v: v.shouldSendPTO(true) = %v, want %v", test.name, got, want) + } + if got, want := v.isReceived(), test.wantIsReceived; got != want { + t.Errorf("%v: v.isReceived() = %v, want %v", test.name, got, want) + } + } +} From f16447cf6cc0d406a38d3fd21ce0c5bab9fdc9c7 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 30 May 2023 15:14:23 -0700 Subject: [PATCH 21/28] quic: add go1.21 build constraint This package will add on crypto/tls features added in Go 1.21, so use a build constraint to restrict ourselves to that version. Unlocks the ability to use other features from Go versions more recent than what's in x/net's go.mod file. For golang/go#58547 Change-Id: I14011c7506b047e389d9b3e995c0bafcd5e74d44 Reviewed-on: https://go-review.googlesource.com/c/net/+/499283 TryBot-Result: Gopher Robot Run-TryBot: Damien Neil Reviewed-by: Jonathan Amsterdam --- internal/quic/errors.go | 2 + internal/quic/files_test.go | 51 +++++++++++++++++++++++++ internal/quic/frame_debug.go | 2 + internal/quic/packet.go | 2 + internal/quic/packet_codec_test.go | 2 + internal/quic/packet_number.go | 2 + internal/quic/packet_number_test.go | 2 + internal/quic/packet_parser.go | 2 + internal/quic/packet_protection.go | 2 + internal/quic/packet_protection_test.go | 2 + internal/quic/packet_test.go | 2 + internal/quic/packet_writer.go | 2 + internal/quic/quic.go | 2 + internal/quic/rangeset.go | 2 + internal/quic/rangeset_test.go | 2 + internal/quic/sent_packet.go | 4 +- internal/quic/sent_packet_test.go | 4 +- internal/quic/sent_val.go | 2 + internal/quic/sent_val_test.go | 2 + internal/quic/wire.go | 2 + internal/quic/wire_test.go | 2 + 21 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 internal/quic/files_test.go diff --git a/internal/quic/errors.go b/internal/quic/errors.go index 725a7daaa..a9ebbe4b7 100644 --- a/internal/quic/errors.go +++ b/internal/quic/errors.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/files_test.go b/internal/quic/files_test.go new file mode 100644 index 000000000..8113109e7 --- /dev/null +++ b/internal/quic/files_test.go @@ -0,0 +1,51 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "bytes" + "os" + "strings" + "testing" +) + +// TestFiles checks that every file in this package has a build constraint on Go 1.21. +// +// The QUIC implementation depends on crypto/tls features added in Go 1.21, +// so there's no point in trying to build on anything older. +// +// Drop this test when the x/net go.mod depends on 1.21 or newer. +func TestFiles(t *testing.T) { + f, err := os.Open(".") + if err != nil { + t.Fatal(err) + } + names, err := f.Readdirnames(-1) + if err != nil { + t.Fatal(err) + } + for _, name := range names { + if !strings.HasSuffix(name, ".go") { + continue + } + b, err := os.ReadFile(name) + if err != nil { + t.Fatal(err) + } + // Check for copyright header while we're in here. + if !bytes.Contains(b, []byte("The Go Authors.")) { + t.Errorf("%v: missing copyright", name) + } + // doc.go doesn't need a build constraint. + if name == "doc.go" { + continue + } + if !bytes.Contains(b, []byte("//go:build go1.21")) { + t.Errorf("%v: missing constraint on go1.21", name) + } + } +} diff --git a/internal/quic/frame_debug.go b/internal/quic/frame_debug.go index fa9bdca06..34a0039ba 100644 --- a/internal/quic/frame_debug.go +++ b/internal/quic/frame_debug.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/packet.go b/internal/quic/packet.go index 4645ae709..93a9102e8 100644 --- a/internal/quic/packet.go +++ b/internal/quic/packet.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic // packetType is a QUIC packet type. diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go index ee533c8ab..2a2043b99 100644 --- a/internal/quic/packet_codec_test.go +++ b/internal/quic/packet_codec_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/packet_number.go b/internal/quic/packet_number.go index 9e9f0ad00..206053e58 100644 --- a/internal/quic/packet_number.go +++ b/internal/quic/packet_number.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic // A packetNumber is a QUIC packet number. diff --git a/internal/quic/packet_number_test.go b/internal/quic/packet_number_test.go index 7450e3988..4d8516ae6 100644 --- a/internal/quic/packet_number_test.go +++ b/internal/quic/packet_number_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go index d06b601d8..cc025b6f3 100644 --- a/internal/quic/packet_parser.go +++ b/internal/quic/packet_parser.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/packet_protection.go b/internal/quic/packet_protection.go index 7d96d69cd..6969ad3a1 100644 --- a/internal/quic/packet_protection.go +++ b/internal/quic/packet_protection.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/packet_protection_test.go b/internal/quic/packet_protection_test.go index f1d353d8e..6495360a3 100644 --- a/internal/quic/packet_protection_test.go +++ b/internal/quic/packet_protection_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/packet_test.go b/internal/quic/packet_test.go index 3011dda1d..b13a587e5 100644 --- a/internal/quic/packet_test.go +++ b/internal/quic/packet_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go index 1f9e30f6e..494105eff 100644 --- a/internal/quic/packet_writer.go +++ b/internal/quic/packet_writer.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/quic.go b/internal/quic/quic.go index f7c1b765d..aae611a5e 100644 --- a/internal/quic/quic.go +++ b/internal/quic/quic.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/rangeset.go b/internal/quic/rangeset.go index 9d6b63a74..ea331ab9e 100644 --- a/internal/quic/rangeset.go +++ b/internal/quic/rangeset.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic // A rangeset is a set of int64s, stored as an ordered list of non-overlapping, diff --git a/internal/quic/rangeset_test.go b/internal/quic/rangeset_test.go index 292284813..082f9201c 100644 --- a/internal/quic/rangeset_test.go +++ b/internal/quic/rangeset_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( diff --git a/internal/quic/sent_packet.go b/internal/quic/sent_packet.go index 03d5e53d1..e5a80be3b 100644 --- a/internal/quic/sent_packet.go +++ b/internal/quic/sent_packet.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( @@ -32,7 +34,7 @@ type sentPacket struct { } var sentPool = sync.Pool{ - New: func() interface{} { + New: func() any { return &sentPacket{} }, } diff --git a/internal/quic/sent_packet_test.go b/internal/quic/sent_packet_test.go index 08a3d8ff0..c01a2fb33 100644 --- a/internal/quic/sent_packet_test.go +++ b/internal/quic/sent_packet_test.go @@ -2,12 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import "testing" func TestSentPacket(t *testing.T) { - frames := []interface{}{ + frames := []any{ byte(frameTypePing), byte(frameTypeStreamBase), uint64(1), diff --git a/internal/quic/sent_val.go b/internal/quic/sent_val.go index f1a9c9fbc..7ca5b70f3 100644 --- a/internal/quic/sent_val.go +++ b/internal/quic/sent_val.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic // A sentVal tracks sending some piece of information to the peer. diff --git a/internal/quic/sent_val_test.go b/internal/quic/sent_val_test.go index 458b221c2..7f9d67656 100644 --- a/internal/quic/sent_val_test.go +++ b/internal/quic/sent_val_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import "testing" diff --git a/internal/quic/wire.go b/internal/quic/wire.go index 2494ad031..f0643c922 100644 --- a/internal/quic/wire.go +++ b/internal/quic/wire.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import "encoding/binary" diff --git a/internal/quic/wire_test.go b/internal/quic/wire_test.go index a5dd83661..379da0d34 100644 --- a/internal/quic/wire_test.go +++ b/internal/quic/wire_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.21 + package quic import ( From ccc217c97ec0528131ef5dc058e57a2080cff953 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 30 May 2023 15:25:43 -0700 Subject: [PATCH 22/28] quic: parameterize rangeset Make the rangeset type parameterized, so it can be used for either packet number or byte ranges without type conversions. For golang/go#58547 Change-Id: I764913a33ba58222dcfd36f94de01c2249d73551 Reviewed-on: https://go-review.googlesource.com/c/net/+/499284 Run-TryBot: Damien Neil TryBot-Result: Gopher Robot Reviewed-by: Jonathan Amsterdam --- internal/quic/frame_debug.go | 10 +- internal/quic/packet_codec_test.go | 4 +- internal/quic/packet_writer.go | 2 +- internal/quic/rangeset.go | 41 +++--- internal/quic/rangeset_test.go | 210 ++++++++++++++--------------- internal/quic/sent_packet_test.go | 6 +- 6 files changed, 135 insertions(+), 138 deletions(-) diff --git a/internal/quic/frame_debug.go b/internal/quic/frame_debug.go index 34a0039ba..93ddf5513 100644 --- a/internal/quic/frame_debug.go +++ b/internal/quic/frame_debug.go @@ -116,15 +116,15 @@ func (f debugFramePing) write(w *packetWriter) bool { // debugFrameAck is an ACK frame. type debugFrameAck struct { ackDelay time.Duration - ranges []i64range + ranges []i64range[packetNumber] } func parseDebugFrameAck(b []byte) (f debugFrameAck, n int) { f.ranges = nil _, f.ackDelay, n = consumeAckFrame(b, ackDelayExponent, func(start, end packetNumber) { - f.ranges = append(f.ranges, i64range{ - start: int64(start), - end: int64(end), + f.ranges = append(f.ranges, i64range[packetNumber]{ + start: start, + end: end, }) }) // Ranges are parsed smallest to highest; reverse ranges slice to order them high to low. @@ -144,7 +144,7 @@ func (f debugFrameAck) String() string { } func (f debugFrameAck) write(w *packetWriter) bool { - return w.appendAckFrame(rangeset(f.ranges), ackDelayExponent, f.ackDelay) + return w.appendAckFrame(rangeset[packetNumber](f.ranges), ackDelayExponent, f.ackDelay) } // debugFrameResetStream is a RESET_STREAM frame. diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go index 2a2043b99..efd519b1f 100644 --- a/internal/quic/packet_codec_test.go +++ b/internal/quic/packet_codec_test.go @@ -222,7 +222,7 @@ func TestFrameEncodeDecode(t *testing.T) { s: "ACK Delay=80µs [0,16) [17,32) [48,64)", f: debugFrameAck{ ackDelay: (10 << ackDelayExponent) * time.Microsecond, - ranges: []i64range{ + ranges: []i64range[packetNumber]{ {0x00, 0x10}, {0x11, 0x20}, {0x30, 0x40}, @@ -595,7 +595,7 @@ func TestFrameDecode(t *testing.T) { desc: "ACK frame with ECN counts", want: debugFrameAck{ ackDelay: (10 << ackDelayExponent) * time.Microsecond, - ranges: []i64range{ + ranges: []i64range[packetNumber]{ {0, 1}, }, }, diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go index 494105eff..bfe9af796 100644 --- a/internal/quic/packet_writer.go +++ b/internal/quic/packet_writer.go @@ -257,7 +257,7 @@ func (w *packetWriter) appendPingFrame() (added bool) { // to the peer potentially failing to receive an acknowledgement // for an older packet during a period of high packet loss or // reordering. This may result in unnecessary retransmissions. -func (w *packetWriter) appendAckFrame(seen rangeset, ackDelayExponent uint8, delay time.Duration) (added bool) { +func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], ackDelayExponent uint8, delay time.Duration) (added bool) { if len(seen) == 0 { return false } diff --git a/internal/quic/rangeset.go b/internal/quic/rangeset.go index ea331ab9e..5339c5ac5 100644 --- a/internal/quic/rangeset.go +++ b/internal/quic/rangeset.go @@ -11,27 +11,24 @@ package quic // // Rangesets are efficient for small numbers of ranges, // which is expected to be the common case. -// -// Once we're willing to drop support for pre-generics versions of Go, this can -// be made into a parameterized type to permit use with packetNumber without casts. -type rangeset []i64range +type rangeset[T ~int64] []i64range[T] -type i64range struct { - start, end int64 // [start, end) +type i64range[T ~int64] struct { + start, end T // [start, end) } // size returns the size of the range. -func (r i64range) size() int64 { +func (r i64range[T]) size() T { return r.end - r.start } // contains reports whether v is in the range. -func (r i64range) contains(v int64) bool { +func (r i64range[T]) contains(v T) bool { return r.start <= v && v < r.end } // add adds [start, end) to the set, combining it with existing ranges if necessary. -func (s *rangeset) add(start, end int64) { +func (s *rangeset[T]) add(start, end T) { if start == end { return } @@ -65,11 +62,11 @@ func (s *rangeset) add(start, end int64) { s.removeranges(i+1, j) return } - *s = append(*s, i64range{start, end}) + *s = append(*s, i64range[T]{start, end}) } // sub removes [start, end) from the set. -func (s *rangeset) sub(start, end int64) { +func (s *rangeset[T]) sub(start, end T) { removefrom, removeto := -1, -1 for i := range *s { r := &(*s)[i] @@ -106,7 +103,7 @@ func (s *rangeset) sub(start, end int64) { } // contains reports whether s contains v. -func (s rangeset) contains(v int64) bool { +func (s rangeset[T]) contains(v T) bool { for _, r := range s { if v >= r.end { continue @@ -120,7 +117,7 @@ func (s rangeset) contains(v int64) bool { } // rangeContaining returns the range containing v, or the range [0,0) if v is not in s. -func (s rangeset) rangeContaining(v int64) i64range { +func (s rangeset[T]) rangeContaining(v T) i64range[T] { for _, r := range s { if v >= r.end { continue @@ -130,11 +127,11 @@ func (s rangeset) rangeContaining(v int64) i64range { } break } - return i64range{0, 0} + return i64range[T]{0, 0} } // min returns the minimum value in the set, or 0 if empty. -func (s rangeset) min() int64 { +func (s rangeset[T]) min() T { if len(s) == 0 { return 0 } @@ -142,7 +139,7 @@ func (s rangeset) min() int64 { } // max returns the maximum value in the set, or 0 if empty. -func (s rangeset) max() int64 { +func (s rangeset[T]) max() T { if len(s) == 0 { return 0 } @@ -150,7 +147,7 @@ func (s rangeset) max() int64 { } // end returns the end of the last range in the set, or 0 if empty. -func (s rangeset) end() int64 { +func (s rangeset[T]) end() T { if len(s) == 0 { return 0 } @@ -158,7 +155,7 @@ func (s rangeset) end() int64 { } // isrange reports if the rangeset covers exactly the range [start, end). -func (s rangeset) isrange(start, end int64) bool { +func (s rangeset[T]) isrange(start, end T) bool { switch len(s) { case 0: return start == 0 && end == 0 @@ -169,7 +166,7 @@ func (s rangeset) isrange(start, end int64) bool { } // removeranges removes ranges [i,j). -func (s *rangeset) removeranges(i, j int) { +func (s *rangeset[T]) removeranges(i, j int) { if i == j { return } @@ -178,8 +175,8 @@ func (s *rangeset) removeranges(i, j int) { } // insert adds a new range at index i. -func (s *rangeset) insertrange(i int, start, end int64) { - *s = append(*s, i64range{}) +func (s *rangeset[T]) insertrange(i int, start, end T) { + *s = append(*s, i64range[T]{}) copy((*s)[i+1:], (*s)[i:]) - (*s)[i] = i64range{start, end} + (*s)[i] = i64range[T]{start, end} } diff --git a/internal/quic/rangeset_test.go b/internal/quic/rangeset_test.go index 082f9201c..308046905 100644 --- a/internal/quic/rangeset_test.go +++ b/internal/quic/rangeset_test.go @@ -13,13 +13,13 @@ import ( func TestRangeSize(t *testing.T) { for _, test := range []struct { - r i64range + r i64range[int64] want int64 }{{ - r: i64range{0, 100}, + r: i64range[int64]{0, 100}, want: 100, }, { - r: i64range{10, 20}, + r: i64range[int64]{10, 20}, want: 10, }} { if got := test.r.size(); got != test.want { @@ -29,7 +29,7 @@ func TestRangeSize(t *testing.T) { } func TestRangeContains(t *testing.T) { - r := i64range{5, 10} + r := i64range[int64]{5, 10} for _, i := range []int64{0, 4, 10, 15} { if r.contains(i) { t.Errorf("%v.contains(%v) = true, want false", r, i) @@ -45,69 +45,69 @@ func TestRangeContains(t *testing.T) { func TestRangesetAdd(t *testing.T) { for _, test := range []struct { desc string - set rangeset - add i64range - want rangeset + set rangeset[int64] + add i64range[int64] + want rangeset[int64] }{{ desc: "add to empty set", - set: rangeset{}, - add: i64range{0, 100}, - want: rangeset{{0, 100}}, + set: rangeset[int64]{}, + add: i64range[int64]{0, 100}, + want: rangeset[int64]{{0, 100}}, }, { desc: "add empty range", - set: rangeset{}, - add: i64range{100, 100}, - want: rangeset{}, + set: rangeset[int64]{}, + add: i64range[int64]{100, 100}, + want: rangeset[int64]{}, }, { desc: "append nonadjacent range", - set: rangeset{{100, 200}}, - add: i64range{300, 400}, - want: rangeset{{100, 200}, {300, 400}}, + set: rangeset[int64]{{100, 200}}, + add: i64range[int64]{300, 400}, + want: rangeset[int64]{{100, 200}, {300, 400}}, }, { desc: "prepend nonadjacent range", - set: rangeset{{100, 200}}, - add: i64range{0, 50}, - want: rangeset{{0, 50}, {100, 200}}, + set: rangeset[int64]{{100, 200}}, + add: i64range[int64]{0, 50}, + want: rangeset[int64]{{0, 50}, {100, 200}}, }, { desc: "insert nonadjacent range", - set: rangeset{{100, 200}, {500, 600}}, - add: i64range{300, 400}, - want: rangeset{{100, 200}, {300, 400}, {500, 600}}, + set: rangeset[int64]{{100, 200}, {500, 600}}, + add: i64range[int64]{300, 400}, + want: rangeset[int64]{{100, 200}, {300, 400}, {500, 600}}, }, { desc: "prepend adjacent range", - set: rangeset{{100, 200}}, - add: i64range{50, 100}, - want: rangeset{{50, 200}}, + set: rangeset[int64]{{100, 200}}, + add: i64range[int64]{50, 100}, + want: rangeset[int64]{{50, 200}}, }, { desc: "append adjacent range", - set: rangeset{{100, 200}}, - add: i64range{200, 250}, - want: rangeset{{100, 250}}, + set: rangeset[int64]{{100, 200}}, + add: i64range[int64]{200, 250}, + want: rangeset[int64]{{100, 250}}, }, { desc: "prepend overlapping range", - set: rangeset{{100, 200}}, - add: i64range{50, 150}, - want: rangeset{{50, 200}}, + set: rangeset[int64]{{100, 200}}, + add: i64range[int64]{50, 150}, + want: rangeset[int64]{{50, 200}}, }, { desc: "append overlapping range", - set: rangeset{{100, 200}}, - add: i64range{150, 250}, - want: rangeset{{100, 250}}, + set: rangeset[int64]{{100, 200}}, + add: i64range[int64]{150, 250}, + want: rangeset[int64]{{100, 250}}, }, { desc: "replace range", - set: rangeset{{100, 200}}, - add: i64range{50, 250}, - want: rangeset{{50, 250}}, + set: rangeset[int64]{{100, 200}}, + add: i64range[int64]{50, 250}, + want: rangeset[int64]{{50, 250}}, }, { desc: "prepend and combine", - set: rangeset{{100, 200}, {300, 400}, {500, 600}}, - add: i64range{50, 300}, - want: rangeset{{50, 400}, {500, 600}}, + set: rangeset[int64]{{100, 200}, {300, 400}, {500, 600}}, + add: i64range[int64]{50, 300}, + want: rangeset[int64]{{50, 400}, {500, 600}}, }, { desc: "combine several ranges", - set: rangeset{{100, 200}, {300, 400}, {500, 600}, {700, 800}, {900, 1000}}, - add: i64range{300, 850}, - want: rangeset{{100, 200}, {300, 850}, {900, 1000}}, + set: rangeset[int64]{{100, 200}, {300, 400}, {500, 600}, {700, 800}, {900, 1000}}, + add: i64range[int64]{300, 850}, + want: rangeset[int64]{{100, 200}, {300, 850}, {900, 1000}}, }} { test := test t.Run(test.desc, func(t *testing.T) { @@ -125,59 +125,59 @@ func TestRangesetAdd(t *testing.T) { func TestRangesetSub(t *testing.T) { for _, test := range []struct { desc string - set rangeset - sub i64range - want rangeset + set rangeset[int64] + sub i64range[int64] + want rangeset[int64] }{{ desc: "subtract from empty set", - set: rangeset{}, - sub: i64range{0, 100}, - want: rangeset{}, + set: rangeset[int64]{}, + sub: i64range[int64]{0, 100}, + want: rangeset[int64]{}, }, { desc: "subtract empty range", - set: rangeset{{0, 100}}, - sub: i64range{0, 0}, - want: rangeset{{0, 100}}, + set: rangeset[int64]{{0, 100}}, + sub: i64range[int64]{0, 0}, + want: rangeset[int64]{{0, 100}}, }, { desc: "subtract not present in set", - set: rangeset{{0, 100}, {200, 300}}, - sub: i64range{100, 200}, - want: rangeset{{0, 100}, {200, 300}}, + set: rangeset[int64]{{0, 100}, {200, 300}}, + sub: i64range[int64]{100, 200}, + want: rangeset[int64]{{0, 100}, {200, 300}}, }, { desc: "subtract prefix", - set: rangeset{{100, 200}}, - sub: i64range{0, 150}, - want: rangeset{{150, 200}}, + set: rangeset[int64]{{100, 200}}, + sub: i64range[int64]{0, 150}, + want: rangeset[int64]{{150, 200}}, }, { desc: "subtract suffix", - set: rangeset{{100, 200}}, - sub: i64range{150, 300}, - want: rangeset{{100, 150}}, + set: rangeset[int64]{{100, 200}}, + sub: i64range[int64]{150, 300}, + want: rangeset[int64]{{100, 150}}, }, { desc: "subtract middle", - set: rangeset{{0, 100}}, - sub: i64range{40, 60}, - want: rangeset{{0, 40}, {60, 100}}, + set: rangeset[int64]{{0, 100}}, + sub: i64range[int64]{40, 60}, + want: rangeset[int64]{{0, 40}, {60, 100}}, }, { desc: "subtract from two ranges", - set: rangeset{{0, 100}, {200, 300}}, - sub: i64range{50, 250}, - want: rangeset{{0, 50}, {250, 300}}, + set: rangeset[int64]{{0, 100}, {200, 300}}, + sub: i64range[int64]{50, 250}, + want: rangeset[int64]{{0, 50}, {250, 300}}, }, { desc: "subtract removes range", - set: rangeset{{0, 100}, {200, 300}, {400, 500}}, - sub: i64range{200, 300}, - want: rangeset{{0, 100}, {400, 500}}, + set: rangeset[int64]{{0, 100}, {200, 300}, {400, 500}}, + sub: i64range[int64]{200, 300}, + want: rangeset[int64]{{0, 100}, {400, 500}}, }, { desc: "subtract removes multiple ranges", - set: rangeset{{0, 100}, {200, 300}, {400, 500}, {600, 700}}, - sub: i64range{50, 650}, - want: rangeset{{0, 50}, {650, 700}}, + set: rangeset[int64]{{0, 100}, {200, 300}, {400, 500}, {600, 700}}, + sub: i64range[int64]{50, 650}, + want: rangeset[int64]{{0, 50}, {650, 700}}, }, { desc: "subtract only range", - set: rangeset{{0, 100}}, - sub: i64range{0, 100}, - want: rangeset{}, + set: rangeset[int64]{{0, 100}}, + sub: i64range[int64]{0, 100}, + want: rangeset[int64]{}, }} { test := test t.Run(test.desc, func(t *testing.T) { @@ -193,7 +193,7 @@ func TestRangesetSub(t *testing.T) { } func TestRangesetContains(t *testing.T) { - var s rangeset + var s rangeset[int64] s.add(10, 20) s.add(30, 40) for i := int64(0); i < 50; i++ { @@ -205,23 +205,23 @@ func TestRangesetContains(t *testing.T) { } func TestRangesetRangeContaining(t *testing.T) { - var s rangeset + var s rangeset[int64] s.add(10, 20) s.add(30, 40) for _, test := range []struct { v int64 - want i64range + want i64range[int64] }{ - {0, i64range{0, 0}}, - {9, i64range{0, 0}}, - {10, i64range{10, 20}}, - {15, i64range{10, 20}}, - {19, i64range{10, 20}}, - {20, i64range{0, 0}}, - {29, i64range{0, 0}}, - {30, i64range{30, 40}}, - {39, i64range{30, 40}}, - {40, i64range{0, 0}}, + {0, i64range[int64]{0, 0}}, + {9, i64range[int64]{0, 0}}, + {10, i64range[int64]{10, 20}}, + {15, i64range[int64]{10, 20}}, + {19, i64range[int64]{10, 20}}, + {20, i64range[int64]{0, 0}}, + {29, i64range[int64]{0, 0}}, + {30, i64range[int64]{30, 40}}, + {39, i64range[int64]{30, 40}}, + {40, i64range[int64]{0, 0}}, } { got := s.rangeContaining(test.v) if got != test.want { @@ -232,22 +232,22 @@ func TestRangesetRangeContaining(t *testing.T) { func TestRangesetLimits(t *testing.T) { for _, test := range []struct { - s rangeset + s rangeset[int64] wantMin int64 wantMax int64 wantEnd int64 }{{ - s: rangeset{}, + s: rangeset[int64]{}, wantMin: 0, wantMax: 0, wantEnd: 0, }, { - s: rangeset{{10, 20}}, + s: rangeset[int64]{{10, 20}}, wantMin: 10, wantMax: 19, wantEnd: 20, }, { - s: rangeset{{10, 20}, {30, 40}, {50, 60}}, + s: rangeset[int64]{{10, 20}, {30, 40}, {50, 60}}, wantMin: 10, wantMax: 59, wantEnd: 60, @@ -266,28 +266,28 @@ func TestRangesetLimits(t *testing.T) { func TestRangesetIsRange(t *testing.T) { for _, test := range []struct { - s rangeset - r i64range + s rangeset[int64] + r i64range[int64] want bool }{{ - s: rangeset{{0, 100}}, - r: i64range{0, 100}, + s: rangeset[int64]{{0, 100}}, + r: i64range[int64]{0, 100}, want: true, }, { - s: rangeset{{0, 100}}, - r: i64range{0, 101}, + s: rangeset[int64]{{0, 100}}, + r: i64range[int64]{0, 101}, want: false, }, { - s: rangeset{{0, 10}, {11, 100}}, - r: i64range{0, 100}, + s: rangeset[int64]{{0, 10}, {11, 100}}, + r: i64range[int64]{0, 100}, want: false, }, { - s: rangeset{}, - r: i64range{0, 0}, + s: rangeset[int64]{}, + r: i64range[int64]{0, 0}, want: true, }, { - s: rangeset{}, - r: i64range{0, 1}, + s: rangeset[int64]{}, + r: i64range[int64]{0, 1}, want: false, }} { if got := test.s.isrange(test.r.start, test.r.end); got != test.want { diff --git a/internal/quic/sent_packet_test.go b/internal/quic/sent_packet_test.go index c01a2fb33..c0b04e676 100644 --- a/internal/quic/sent_packet_test.go +++ b/internal/quic/sent_packet_test.go @@ -13,7 +13,7 @@ func TestSentPacket(t *testing.T) { byte(frameTypePing), byte(frameTypeStreamBase), uint64(1), - i64range{1 << 20, 1<<20 + 1024}, + i64range[int64]{1 << 20, 1<<20 + 1024}, } // Record sent frames. sent := newSentPacket() @@ -23,7 +23,7 @@ func TestSentPacket(t *testing.T) { sent.appendAckElicitingFrame(f) case uint64: sent.appendInt(f) - case i64range: + case i64range[int64]: sent.appendOffAndSize(f.start, int(f.size())) } } @@ -41,7 +41,7 @@ func TestSentPacket(t *testing.T) { if got := sent.nextInt(); got != want { t.Fatalf("%v: sent.nextInt() = %v, want %v", i, got, want) } - case i64range: + case i64range[int64]: if start, end := sent.nextRange(); start != want.start || end != want.end { t.Fatalf("%v: sent.nextRange() = [%v,%v), want %v", i, start, end, want) } From 10cf388024ca83e11a073bb365ccda2b2806d103 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 16 May 2023 16:41:53 -0700 Subject: [PATCH 23/28] quic: add a data structure for tracking lists of sent packets Store in-flight packets in a ring buffer. For golang/go#58547 Change-Id: I825c4e600bb496c2f8f6c195085aaae3e847445e Reviewed-on: https://go-review.googlesource.com/c/net/+/499285 TryBot-Result: Gopher Robot Reviewed-by: Jonathan Amsterdam Run-TryBot: Damien Neil --- internal/quic/sent_packet_list.go | 95 ++++++++++++++++++++++ internal/quic/sent_packet_list_test.go | 107 +++++++++++++++++++++++++ 2 files changed, 202 insertions(+) create mode 100644 internal/quic/sent_packet_list.go create mode 100644 internal/quic/sent_packet_list_test.go diff --git a/internal/quic/sent_packet_list.go b/internal/quic/sent_packet_list.go new file mode 100644 index 000000000..6fb712a7a --- /dev/null +++ b/internal/quic/sent_packet_list.go @@ -0,0 +1,95 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +// A sentPacketList is a ring buffer of sentPackets. +// +// Processing an ack for a packet causes all older packets past a small threshold +// to be discarded (RFC 9002, Section 6.1.1), so the list of in-flight packets is +// not sparse and will contain at most a few acked/lost packets we no longer +// care about. +type sentPacketList struct { + nextNum packetNumber // next packet number to add to the buffer + off int // offset of first packet in the buffer + size int // number of packets + p []*sentPacket +} + +// start is the first packet in the list. +func (s *sentPacketList) start() packetNumber { + return s.nextNum - packetNumber(s.size) +} + +// end is one after the last packet in the list. +// If the list is empty, start == end. +func (s *sentPacketList) end() packetNumber { + return s.nextNum +} + +// discard clears the list. +func (s *sentPacketList) discard() { + *s = sentPacketList{} +} + +// add appends a packet to the list. +func (s *sentPacketList) add(sent *sentPacket) { + if s.nextNum != sent.num { + panic("inserting out-of-order packet") + } + s.nextNum++ + if s.size >= len(s.p) { + s.grow() + } + i := (s.off + s.size) % len(s.p) + s.size++ + s.p[i] = sent +} + +// nth returns a packet by index. +func (s *sentPacketList) nth(n int) *sentPacket { + index := (s.off + n) % len(s.p) + return s.p[index] +} + +// num returns a packet by number. +// It returns nil if the packet is not in the list. +func (s *sentPacketList) num(num packetNumber) *sentPacket { + i := int(num - s.start()) + if i < 0 || i >= s.size { + return nil + } + return s.nth(i) +} + +// clean removes all acked or lost packets from the head of the list. +func (s *sentPacketList) clean() { + for s.size > 0 { + sent := s.p[s.off] + if !sent.acked && !sent.lost { + return + } + sent.recycle() + s.p[s.off] = nil + s.off = (s.off + 1) % len(s.p) + s.size-- + } + s.off = 0 +} + +// grow increases the buffer to hold more packaets. +func (s *sentPacketList) grow() { + newSize := len(s.p) * 2 + if newSize == 0 { + newSize = 64 + } + p := make([]*sentPacket, newSize) + for i := 0; i < s.size; i++ { + p[i] = s.nth(i) + } + s.p = p + s.off = 0 +} diff --git a/internal/quic/sent_packet_list_test.go b/internal/quic/sent_packet_list_test.go new file mode 100644 index 000000000..2f7f4d2c6 --- /dev/null +++ b/internal/quic/sent_packet_list_test.go @@ -0,0 +1,107 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import "testing" + +func TestSentPacketListSlidingWindow(t *testing.T) { + // Record 1000 sent packets, acking everything outside the most recent 10. + list := &sentPacketList{} + const window = 10 + for i := packetNumber(0); i < 1000; i++ { + list.add(&sentPacket{num: i}) + if i < window { + continue + } + prev := i - window + sent := list.num(prev) + if sent == nil { + t.Fatalf("packet %v not in list", prev) + } + if sent.num != prev { + t.Fatalf("list.num(%v) = packet %v", prev, sent.num) + } + if got := list.nth(0); got != sent { + t.Fatalf("list.nth(0) != list.num(%v)", prev) + } + sent.acked = true + list.clean() + if got := list.num(prev); got != nil { + t.Fatalf("list.num(%v) = packet %v, expected it to be discarded", prev, got.num) + } + if got, want := list.start(), prev+1; got != want { + t.Fatalf("list.start() = %v, want %v", got, want) + } + if got, want := list.end(), i+1; got != want { + t.Fatalf("list.end() = %v, want %v", got, want) + } + if got, want := list.size, window; got != want { + t.Fatalf("list.size = %v, want %v", got, want) + } + } +} + +func TestSentPacketListGrows(t *testing.T) { + // Record 1000 sent packets. + list := &sentPacketList{} + const count = 1000 + for i := packetNumber(0); i < count; i++ { + list.add(&sentPacket{num: i}) + } + if got, want := list.start(), packetNumber(0); got != want { + t.Fatalf("list.start() = %v, want %v", got, want) + } + if got, want := list.end(), packetNumber(count); got != want { + t.Fatalf("list.end() = %v, want %v", got, want) + } + if got, want := list.size, count; got != want { + t.Fatalf("list.size = %v, want %v", got, want) + } + for i := packetNumber(0); i < count; i++ { + sent := list.num(i) + if sent == nil { + t.Fatalf("packet %v not in list", i) + } + if sent.num != i { + t.Fatalf("list.num(%v) = packet %v", i, sent.num) + } + if got := list.nth(int(i)); got != sent { + t.Fatalf("list.nth(%v) != list.num(%v)", int(i), i) + } + } +} + +func TestSentPacketListCleanAll(t *testing.T) { + list := &sentPacketList{} + // Record 10 sent packets. + const count = 10 + for i := packetNumber(0); i < count; i++ { + list.add(&sentPacket{num: i}) + } + // Mark all the packets as acked. + for i := packetNumber(0); i < count; i++ { + list.num(i).acked = true + } + list.clean() + if got, want := list.size, 0; got != want { + t.Fatalf("list.size = %v, want %v", got, want) + } + list.add(&sentPacket{num: 10}) + if got, want := list.size, 1; got != want { + t.Fatalf("list.size = %v, want %v", got, want) + } + sent := list.num(10) + if sent == nil { + t.Fatalf("packet %v not in list", 10) + } + if sent.num != 10 { + t.Fatalf("list.num(10) = %v", sent.num) + } + if got := list.nth(0); got != sent { + t.Fatalf("list.nth(0) != list.num(10)") + } +} From 2796e09d6af2852bffb1ae1eefb63a21c823e22f Mon Sep 17 00:00:00 2001 From: Matt Layher Date: Tue, 6 Jun 2023 10:02:23 -0400 Subject: [PATCH 24/28] bpf: check for little endian CPU for OS VM comparison When I wrote these tests, I assumed the native endianness for the machine was little endian. Explicitly check this so that the emulated BPF VM tests can run on s390x, but we avoid test flakes related to endianness. Updates golang/go#55235. Change-Id: I9be430dfe7f97503af7a620ed80dcbacb66d73cc Reviewed-on: https://go-review.googlesource.com/c/net/+/501155 Reviewed-by: David Chase Reviewed-by: Ian Lance Taylor Run-TryBot: Ian Lance Taylor Reviewed-by: Tobias Klauser TryBot-Result: Gopher Robot Run-TryBot: Matt Layher Auto-Submit: Ian Lance Taylor --- bpf/vm_bpf_test.go | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/bpf/vm_bpf_test.go b/bpf/vm_bpf_test.go index 137eea160..559e5dd8e 100644 --- a/bpf/vm_bpf_test.go +++ b/bpf/vm_bpf_test.go @@ -14,6 +14,7 @@ import ( "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "golang.org/x/net/nettest" + "golang.org/x/sys/cpu" ) // A virtualMachine is a BPF virtual machine which can process an @@ -22,18 +23,6 @@ type virtualMachine interface { Run(in []byte) (int, error) } -// canUseOSVM indicates if the OS BPF VM is available on this platform. -func canUseOSVM() bool { - // OS BPF VM can only be used on platforms where x/net/ipv4 supports - // attaching a BPF program to a socket. - switch runtime.GOOS { - case "linux": - return true - } - - return false -} - // All BPF tests against both the Go VM and OS VM are assumed to // be used with a UDP socket. As a result, the entire contents // of a UDP datagram is sent through the BPF program, but only @@ -55,11 +44,10 @@ func testVM(t *testing.T, filter []bpf.Instruction) (virtualMachine, func(), err t: t, } - // If available, add the OS VM for tests which verify that both the Go - // VM and OS VM have exactly the same output for the same input program - // and packet. + // For linux with a little endian CPU, the Go VM and OS VM have exactly the + // same output for the same input program and packet. Compare both. done := func() {} - if canUseOSVM() { + if runtime.GOOS == "linux" && !cpu.IsBigEndian { osVM, osVMDone := testOSVM(t, filter) done = func() { osVMDone() } mvm.osVM = osVM From 7e6923f9c41329cf4f8ffe24e4503e030dab2ff2 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 25 May 2023 09:41:50 -0700 Subject: [PATCH 25/28] quic: add RTT estimator Implement the round-trip time estimation algorithm from RFC 9002, Section 5. For golang/go#58547 Change-Id: I494e692e710f77270c9ad28354366f384feb4ac7 Reviewed-on: https://go-review.googlesource.com/c/net/+/499286 TryBot-Result: Gopher Robot Run-TryBot: Damien Neil Reviewed-by: Jonathan Amsterdam --- internal/quic/math.go | 14 ++++ internal/quic/rtt.go | 73 +++++++++++++++++ internal/quic/rtt_test.go | 168 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 255 insertions(+) create mode 100644 internal/quic/math.go create mode 100644 internal/quic/rtt.go create mode 100644 internal/quic/rtt_test.go diff --git a/internal/quic/math.go b/internal/quic/math.go new file mode 100644 index 000000000..f9dd7545a --- /dev/null +++ b/internal/quic/math.go @@ -0,0 +1,14 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +func abs[T ~int | ~int64](a T) T { + if a < 0 { + return -a + } + return a +} diff --git a/internal/quic/rtt.go b/internal/quic/rtt.go new file mode 100644 index 000000000..5bd8861b0 --- /dev/null +++ b/internal/quic/rtt.go @@ -0,0 +1,73 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "time" +) + +type rttState struct { + minRTT time.Duration + latestRTT time.Duration + smoothedRTT time.Duration + rttvar time.Duration // RTT variation + firstSampleTime time.Time // time of first RTT sample +} + +func (r *rttState) init() { + r.minRTT = -1 // -1 indicates the first sample has not been taken yet + + // "[...] the initial RTT SHOULD be set to 333 milliseconds." + // https://www.rfc-editor.org/rfc/rfc9002.html#section-6.2.2-1 + const initialRTT = 333 * time.Millisecond + + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-12 + r.smoothedRTT = initialRTT + r.rttvar = initialRTT / 2 +} + +func (r *rttState) establishPersistentCongestion() { + // "Endpoints SHOULD set the min_rtt to the newest RTT sample + // after persistent congestion is established." + // https://www.rfc-editor.org/rfc/rfc9002#section-5.2-5 + r.minRTT = r.latestRTT +} + +// updateRTTSample is called when we generate a new RTT sample. +// https://www.rfc-editor.org/rfc/rfc9002.html#section-5 +func (r *rttState) updateSample(now time.Time, handshakeConfirmed bool, spaceID numberSpace, latestRTT, ackDelay, maxAckDelay time.Duration) { + r.latestRTT = latestRTT + + if r.minRTT < 0 { + // First RTT sample. + // "min_rtt MUST be set to the latest_rtt on the first RTT sample." + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-2 + r.minRTT = latestRTT + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-14 + r.smoothedRTT = latestRTT + r.rttvar = latestRTT / 2 + r.firstSampleTime = now + return + } + + // "min_rtt MUST be set to the lesser of min_rtt and latest_rtt [...] + // on all other samples." + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-2 + r.minRTT = min(r.minRTT, latestRTT) + + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.3-16 + if handshakeConfirmed { + ackDelay = min(ackDelay, maxAckDelay) + } + adjustedRTT := latestRTT - ackDelay + if adjustedRTT < r.minRTT { + adjustedRTT = latestRTT + } + r.smoothedRTT = ((7 * r.smoothedRTT) + adjustedRTT) / 8 + rttvarSample := abs(r.smoothedRTT - adjustedRTT) + r.rttvar = (3*r.rttvar + rttvarSample) / 4 +} diff --git a/internal/quic/rtt_test.go b/internal/quic/rtt_test.go new file mode 100644 index 000000000..63789c288 --- /dev/null +++ b/internal/quic/rtt_test.go @@ -0,0 +1,168 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "testing" + "time" +) + +func TestRTTMinRTT(t *testing.T) { + var ( + handshakeConfirmed = false + ackDelay = 0 * time.Millisecond + maxAckDelay = 25 * time.Millisecond + now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + ) + rtt := &rttState{} + rtt.init() + + // "min_rtt MUST be set to the latest_rtt on the first RTT sample." + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-2 + rtt.updateSample(now, handshakeConfirmed, initialSpace, 10*time.Millisecond, ackDelay, maxAckDelay) + if got, want := rtt.latestRTT, 10*time.Millisecond; got != want { + t.Errorf("on first sample: latest_rtt = %v, want %v", got, want) + } + if got, want := rtt.minRTT, 10*time.Millisecond; got != want { + t.Errorf("on first sample: min_rtt = %v, want %v", got, want) + } + + // "min_rtt MUST be set to the lesser of min_rtt and latest_rtt [...] + // on all other samples." + rtt.updateSample(now, handshakeConfirmed, initialSpace, 20*time.Millisecond, ackDelay, maxAckDelay) + if got, want := rtt.latestRTT, 20*time.Millisecond; got != want { + t.Errorf("on increasing sample: latest_rtt = %v, want %v", got, want) + } + if got, want := rtt.minRTT, 10*time.Millisecond; got != want { + t.Errorf("on increasing sample: min_rtt = %v, want %v (no change)", got, want) + } + + rtt.updateSample(now, handshakeConfirmed, initialSpace, 5*time.Millisecond, ackDelay, maxAckDelay) + if got, want := rtt.latestRTT, 5*time.Millisecond; got != want { + t.Errorf("on new minimum: latest_rtt = %v, want %v", got, want) + } + if got, want := rtt.minRTT, 5*time.Millisecond; got != want { + t.Errorf("on new minimum: min_rtt = %v, want %v", got, want) + } + + // "Endpoints SHOULD set the min_rtt to the newest RTT sample + // after persistent congestion is established." + // https://www.rfc-editor.org/rfc/rfc9002.html#section-5.2-5 + rtt.updateSample(now, handshakeConfirmed, initialSpace, 15*time.Millisecond, ackDelay, maxAckDelay) + if got, want := rtt.latestRTT, 15*time.Millisecond; got != want { + t.Errorf("on increasing sample: latest_rtt = %v, want %v", got, want) + } + if got, want := rtt.minRTT, 5*time.Millisecond; got != want { + t.Errorf("on increasing sample: min_rtt = %v, want %v (no change)", got, want) + } + rtt.establishPersistentCongestion() + if got, want := rtt.minRTT, 15*time.Millisecond; got != want { + t.Errorf("after persistent congestion: min_rtt = %v, want %v", got, want) + } +} + +func TestRTTInitialRTT(t *testing.T) { + var ( + handshakeConfirmed = false + ackDelay = 0 * time.Millisecond + maxAckDelay = 25 * time.Millisecond + now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + ) + rtt := &rttState{} + rtt.init() + + // "When no previous RTT is available, + // the initial RTT SHOULD be set to 333 milliseconds." + // https://www.rfc-editor.org/rfc/rfc9002#section-6.2.2-1 + if got, want := rtt.smoothedRTT, 333*time.Millisecond; got != want { + t.Errorf("initial smoothed_rtt = %v, want %v", got, want) + } + if got, want := rtt.rttvar, 333*time.Millisecond/2; got != want { + t.Errorf("initial rttvar = %v, want %v", got, want) + } + + rtt.updateSample(now, handshakeConfirmed, initialSpace, 10*time.Millisecond, ackDelay, maxAckDelay) + smoothedRTT := 10 * time.Millisecond + if got, want := rtt.smoothedRTT, smoothedRTT; got != want { + t.Errorf("after first rtt sample of 10ms, smoothed_rtt = %v, want %v", got, want) + } + rttvar := 5 * time.Millisecond + if got, want := rtt.rttvar, rttvar; got != want { + t.Errorf("after first rtt sample of 10ms, rttvar = %v, want %v", got, want) + } + + // "[...] MAY ignore the acknowledgment delay for Initial packets [...]" + // https://www.rfc-editor.org/rfc/rfc9002#section-5.3-7.1 + ackDelay = 1 * time.Millisecond + rtt.updateSample(now, handshakeConfirmed, initialSpace, 10*time.Millisecond, ackDelay, maxAckDelay) + adjustedRTT := 10 * time.Millisecond + smoothedRTT = (7*smoothedRTT + adjustedRTT) / 8 + if got, want := rtt.smoothedRTT, smoothedRTT; got != want { + t.Errorf("smoothed_rtt = %v, want %v", got, want) + } + rttvarSample := abs(smoothedRTT - adjustedRTT) + rttvar = (3*rttvar + rttvarSample) / 4 + if got, want := rtt.rttvar, rttvar; got != want { + t.Errorf("rttvar = %v, want %v", got, want) + } + + // "[...] SHOULD ignore the peer's max_ack_delay until the handshake is confirmed [...]" + // https://www.rfc-editor.org/rfc/rfc9002#section-5.3-7.2 + ackDelay = 30 * time.Millisecond + maxAckDelay = 25 * time.Millisecond + rtt.updateSample(now, handshakeConfirmed, handshakeSpace, 40*time.Millisecond, ackDelay, maxAckDelay) + adjustedRTT = 10 * time.Millisecond // latest_rtt (40ms) - ack_delay (30ms) + smoothedRTT = (7*smoothedRTT + adjustedRTT) / 8 + if got, want := rtt.smoothedRTT, smoothedRTT; got != want { + t.Errorf("smoothed_rtt = %v, want %v", got, want) + } + rttvarSample = abs(smoothedRTT - adjustedRTT) + rttvar = (3*rttvar + rttvarSample) / 4 + if got, want := rtt.rttvar, rttvar; got != want { + t.Errorf("rttvar = %v, want %v", got, want) + } + + // "[...] MUST use the lesser of the acknowledgment delay and + // the peer's max_ack_delay after the handshake is confirmed [...]" + // https://www.rfc-editor.org/rfc/rfc9002#section-5.3-7.3 + ackDelay = 30 * time.Millisecond + maxAckDelay = 25 * time.Millisecond + handshakeConfirmed = true + rtt.updateSample(now, handshakeConfirmed, handshakeSpace, 40*time.Millisecond, ackDelay, maxAckDelay) + adjustedRTT = 15 * time.Millisecond // latest_rtt (40ms) - max_ack_delay (25ms) + smoothedRTT = (7*smoothedRTT + adjustedRTT) / 8 + if got, want := rtt.smoothedRTT, smoothedRTT; got != want { + t.Errorf("smoothed_rtt = %v, want %v", got, want) + } + rttvarSample = abs(smoothedRTT - adjustedRTT) + rttvar = (3*rttvar + rttvarSample) / 4 + if got, want := rtt.rttvar, rttvar; got != want { + t.Errorf("rttvar = %v, want %v", got, want) + } + + // "[...] MUST NOT subtract the acknowledgment delay from + // the RTT sample if the resulting value is smaller than the min_rtt." + // https://www.rfc-editor.org/rfc/rfc9002#section-5.3-7.4 + ackDelay = 25 * time.Millisecond + maxAckDelay = 25 * time.Millisecond + handshakeConfirmed = true + rtt.updateSample(now, handshakeConfirmed, handshakeSpace, 30*time.Millisecond, ackDelay, maxAckDelay) + if got, want := rtt.minRTT, 10*time.Millisecond; got != want { + t.Errorf("min_rtt = %v, want %v", got, want) + } + // latest_rtt (30ms) - ack_delay (25ms) = 5ms, which is less than min_rtt (10ms) + adjustedRTT = 30 * time.Millisecond // latest_rtt + smoothedRTT = (7*smoothedRTT + adjustedRTT) / 8 + if got, want := rtt.smoothedRTT, smoothedRTT; got != want { + t.Errorf("smoothed_rtt = %v, want %v", got, want) + } + rttvarSample = abs(smoothedRTT - adjustedRTT) + rttvar = (3*rttvar + rttvarSample) / 4 + if got, want := rtt.rttvar, rttvar; got != want { + t.Errorf("rttvar = %v, want %v", got, want) + } +} From 88a50b64844d3c278f93e8007efd20742c386ded Mon Sep 17 00:00:00 2001 From: "Bryan C. Mills" Date: Wed, 7 Jun 2023 12:22:51 -0400 Subject: [PATCH 26/28] all: update x/sys to HEAD An update is needed to pull in CL 494856, to allow the bpf test to build on js and wasip1 after the test changes in CL 501155. Updates golang/go#55235. Updates golang/go#57237. Change-Id: Iff48bad97453932065c27b0c8b4a3706ddcf659a Reviewed-on: https://go-review.googlesource.com/c/net/+/501615 Reviewed-by: Tobias Klauser Run-TryBot: Bryan Mills Auto-Submit: Bryan Mills Reviewed-by: Matt Layher Reviewed-by: Ian Lance Taylor TryBot-Result: Gopher Robot --- go.mod | 2 +- go.sum | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index b1d3e5474..43451b0da 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.17 require ( golang.org/x/crypto v0.9.0 - golang.org/x/sys v0.8.0 + golang.org/x/sys v0.8.1-0.20230606214330-304f187cdba9 golang.org/x/term v0.8.0 golang.org/x/text v0.9.0 ) diff --git a/go.sum b/go.sum index af21d7cac..666172d8b 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,9 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.1-0.20230606214330-304f187cdba9 h1:M57WSsHXpJhTgA9/YbQISwowFNh+0HsWrZQt6NZf1go= +golang.org/x/sys v0.8.1-0.20230606214330-304f187cdba9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= From 5541298b83d714971b9a1e63219fbd68d60aa98d Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 24 May 2023 15:37:11 -0700 Subject: [PATCH 27/28] quic: add packet pacer The pacer rate-limits the transmission of packets to avoid creating bursts that may cause short-term congestion or loss. See RFC 9002, Section 7.7. For golang/go#58547 Change-Id: I75285c194a1048f988e4d5a829602d199829669d Reviewed-on: https://go-review.googlesource.com/c/net/+/499287 Run-TryBot: Damien Neil TryBot-Result: Gopher Robot Reviewed-by: Jonathan Amsterdam --- internal/quic/pacer.go | 131 +++++++++++++++++++++ internal/quic/pacer_test.go | 224 ++++++++++++++++++++++++++++++++++++ 2 files changed, 355 insertions(+) create mode 100644 internal/quic/pacer.go create mode 100644 internal/quic/pacer_test.go diff --git a/internal/quic/pacer.go b/internal/quic/pacer.go new file mode 100644 index 000000000..bcba76936 --- /dev/null +++ b/internal/quic/pacer.go @@ -0,0 +1,131 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "time" +) + +// A pacerState controls the rate at which packets are sent using a leaky-bucket rate limiter. +// +// The pacer limits the maximum size of a burst of packets. +// When a burst exceeds this limit, it spreads subsequent packets +// over time. +// +// The bucket is initialized to the maximum burst size (ten packets by default), +// and fills at the rate: +// +// 1.25 * congestion_window / smoothed_rtt +// +// A sender can send one congestion window of packets per RTT, +// since the congestion window consumed by each packet is returned +// one round-trip later by the responding ack. +// The pacer permits sending at slightly faster than this rate to +// avoid underutilizing the congestion window. +// +// The pacer permits the bucket to become negative, and permits +// sending when non-negative. This biases slightly in favor of +// sending packets over limiting them, and permits bursts one +// packet greater than the configured maximum, but permits the pacer +// to be ignorant of the maximum packet size. +// +// https://www.rfc-editor.org/rfc/rfc9002.html#section-7.7 +type pacerState struct { + bucket int // measured in bytes + maxBucket int + timerGranularity time.Duration + lastUpdate time.Time + nextSend time.Time +} + +func (p *pacerState) init(now time.Time, maxBurst int, timerGranularity time.Duration) { + // Bucket is limited to maximum burst size, which is the initial congestion window. + // https://www.rfc-editor.org/rfc/rfc9002#section-7.7-2 + p.maxBucket = maxBurst + p.bucket = p.maxBucket + p.timerGranularity = timerGranularity + p.lastUpdate = now + p.nextSend = now +} + +// pacerBytesForInterval returns the number of bytes permitted over an interval. +// +// rate = 1.25 * congestion_window / smoothed_rtt +// bytes = interval * rate +// +// https://www.rfc-editor.org/rfc/rfc9002#section-7.7-6 +func pacerBytesForInterval(interval time.Duration, congestionWindow int, rtt time.Duration) int { + bytes := (int64(interval) * int64(congestionWindow)) / int64(rtt) + bytes = (bytes * 5) / 4 // bytes *= 1.25 + return int(bytes) +} + +// pacerIntervalForBytes returns the amount of time required for a number of bytes. +// +// time_per_byte = (smoothed_rtt / congestion_window) / 1.25 +// interval = time_per_byte * bytes +// +// https://www.rfc-editor.org/rfc/rfc9002#section-7.7-8 +func pacerIntervalForBytes(bytes int, congestionWindow int, rtt time.Duration) time.Duration { + interval := (int64(rtt) * int64(bytes)) / int64(congestionWindow) + interval = (interval * 4) / 5 // interval /= 1.25 + return time.Duration(interval) +} + +// advance is called when time passes. +func (p *pacerState) advance(now time.Time, congestionWindow int, rtt time.Duration) { + elapsed := now.Sub(p.lastUpdate) + if elapsed < 0 { + // Time has gone backward? + elapsed = 0 + p.nextSend = now // allow a packet through to get back on track + if p.bucket < 0 { + p.bucket = 0 + } + } + p.lastUpdate = now + if rtt == 0 { + // Avoid divide by zero in the implausible case that we measure no RTT. + p.bucket = p.maxBucket + return + } + // Refill the bucket. + delta := pacerBytesForInterval(elapsed, congestionWindow, rtt) + p.bucket = min(p.bucket+delta, p.maxBucket) +} + +// packetSent is called to record transmission of a packet. +func (p *pacerState) packetSent(now time.Time, size, congestionWindow int, rtt time.Duration) { + p.bucket -= size + if p.bucket < -congestionWindow { + // Never allow the bucket to fall more than one congestion window in arrears. + // We can only fall this far behind if the sender is sending unpaced packets, + // the congestion window has been exceeded, or the RTT is less than the + // timer granularity. + // + // Limiting the minimum bucket size limits the maximum pacer delay + // to RTT/1.25. + p.bucket = -congestionWindow + } + if p.bucket >= 0 { + p.nextSend = now + return + } + // Next send occurs when the bucket has refilled to 0. + delay := pacerIntervalForBytes(-p.bucket, congestionWindow, rtt) + p.nextSend = now.Add(delay) +} + +// canSend reports whether a packet can be sent now. +// If it returns false, next is the time when the next packet can be sent. +func (p *pacerState) canSend(now time.Time) (canSend bool, next time.Time) { + // If the next send time is within the timer granularity, send immediately. + if p.nextSend.After(now.Add(p.timerGranularity)) { + return false, p.nextSend + } + return true, time.Time{} +} diff --git a/internal/quic/pacer_test.go b/internal/quic/pacer_test.go new file mode 100644 index 000000000..9c69da038 --- /dev/null +++ b/internal/quic/pacer_test.go @@ -0,0 +1,224 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "testing" + "time" +) + +func TestPacerStartup(t *testing.T) { + p := &pacerTest{ + cwnd: 10000, + rtt: 100 * time.Millisecond, + timerGranularity: 1 * time.Millisecond, + } + p.init(t) + t.Logf("# initial burst permits sending ten packets") + for i := 0; i < 10; i++ { + p.sendPacket(1000) + } + + t.Logf("# empty bucket allows for one more packet") + p.sendPacket(1000) + + t.Logf("# sending 1000 byte packets with 8ms interval:") + t.Logf("# (smoothed_rtt * packet_size / congestion_window) / 1.25") + t.Logf("# (100ms * 1000 / 10000) / 1.25 = 8ms") + p.wantSendDelay(8 * time.Millisecond) + p.advance(8 * time.Millisecond) + p.sendPacket(1000) + p.wantSendDelay(8 * time.Millisecond) + + t.Logf("# accumulate enough window for two packets") + p.advance(16 * time.Millisecond) + p.sendPacket(1000) + p.sendPacket(1000) + p.wantSendDelay(8 * time.Millisecond) + + t.Logf("# window does not grow to more than burst limit") + p.advance(1 * time.Second) + for i := 0; i < 11; i++ { + p.sendPacket(1000) + } + p.wantSendDelay(8 * time.Millisecond) +} + +func TestPacerTimerGranularity(t *testing.T) { + p := &pacerTest{ + cwnd: 10000, + rtt: 100 * time.Millisecond, + timerGranularity: 1 * time.Millisecond, + } + p.init(t) + t.Logf("# consume initial burst") + for i := 0; i < 11; i++ { + p.sendPacket(1000) + } + p.wantSendDelay(8 * time.Millisecond) + + t.Logf("# small advance in time does not permit sending") + p.advance(4 * time.Millisecond) + p.wantSendDelay(4 * time.Millisecond) + + t.Logf("# advancing to within timerGranularity of next send permits send") + p.advance(3 * time.Millisecond) + p.wantSendDelay(0) + + t.Logf("# early send adds skipped delay (1ms) to next send (8ms)") + p.sendPacket(1000) + p.wantSendDelay(9 * time.Millisecond) +} + +func TestPacerChangingRate(t *testing.T) { + p := &pacerTest{ + cwnd: 10000, + rtt: 100 * time.Millisecond, + timerGranularity: 0, + } + p.init(t) + t.Logf("# consume initial burst") + for i := 0; i < 11; i++ { + p.sendPacket(1000) + } + p.wantSendDelay(8 * time.Millisecond) + p.advance(8 * time.Millisecond) + + t.Logf("# set congestion window to 20000, 1000 byte interval is 4ms") + p.cwnd = 20000 + p.sendPacket(1000) + p.wantSendDelay(4 * time.Millisecond) + p.advance(4 * time.Millisecond) + + t.Logf("# set rtt to 200ms, 1000 byte interval is 8ms") + p.rtt = 200 * time.Millisecond + p.sendPacket(1000) + p.wantSendDelay(8 * time.Millisecond) + p.advance(8 * time.Millisecond) + + t.Logf("# set congestion window to 40000, 1000 byte interval is 4ms") + p.cwnd = 40000 + p.advance(8 * time.Millisecond) + p.sendPacket(1000) + p.sendPacket(1000) + p.sendPacket(1000) + p.wantSendDelay(4 * time.Millisecond) +} + +func TestPacerTimeReverses(t *testing.T) { + p := &pacerTest{ + cwnd: 10000, + rtt: 100 * time.Millisecond, + timerGranularity: 0, + } + p.init(t) + t.Logf("# consume initial burst") + for i := 0; i < 11; i++ { + p.sendPacket(1000) + } + p.wantSendDelay(8 * time.Millisecond) + t.Logf("# reverse time") + p.advance(-4 * time.Millisecond) + p.sendPacket(1000) + p.wantSendDelay(8 * time.Millisecond) + p.advance(8 * time.Millisecond) + p.sendPacket(1000) + p.wantSendDelay(8 * time.Millisecond) +} + +func TestPacerZeroRTT(t *testing.T) { + p := &pacerTest{ + cwnd: 10000, + rtt: 0, + timerGranularity: 0, + } + p.init(t) + t.Logf("# with rtt 0, the pacer does not limit sending") + for i := 0; i < 20; i++ { + p.sendPacket(1000) + } + p.advance(1 * time.Second) + for i := 0; i < 20; i++ { + p.sendPacket(1000) + } +} + +func TestPacerZeroCongestionWindow(t *testing.T) { + p := &pacerTest{ + cwnd: 10000, + rtt: 100 * time.Millisecond, + timerGranularity: 0, + } + p.init(t) + p.cwnd = 0 + t.Logf("# with cwnd 0, the pacer does not limit sending") + for i := 0; i < 20; i++ { + p.sendPacket(1000) + } +} + +type pacerTest struct { + t *testing.T + p pacerState + timerGranularity time.Duration + cwnd int + rtt time.Duration + now time.Time +} + +func newPacerTest(t *testing.T, congestionWindow int, rtt time.Duration) *pacerTest { + p := &pacerTest{ + now: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + cwnd: congestionWindow, + rtt: rtt, + } + p.p.init(p.now, congestionWindow, p.timerGranularity) + return p +} + +func (p *pacerTest) init(t *testing.T) { + p.t = t + p.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + p.p.init(p.now, p.cwnd, p.timerGranularity) + t.Logf("# initial congestion window: %v", p.cwnd) + t.Logf("# timer granularity: %v", p.timerGranularity) +} + +func (p *pacerTest) advance(d time.Duration) { + p.t.Logf("advance time %v", d) + p.now = p.now.Add(d) + p.p.advance(p.now, p.cwnd, p.rtt) +} + +func (p *pacerTest) sendPacket(size int) { + if canSend, next := p.p.canSend(p.now); !canSend { + p.t.Fatalf("ERROR: pacer unexpectedly blocked send, delay=%v", next.Sub(p.now)) + } + p.t.Logf("send packet of size %v", size) + p.p.packetSent(p.now, size, p.cwnd, p.rtt) +} + +func (p *pacerTest) wantSendDelay(want time.Duration) { + wantCanSend := want == 0 + gotCanSend, next := p.p.canSend(p.now) + var got time.Duration + if !gotCanSend { + got = next.Sub(p.now) + } + p.t.Logf("# pacer send delay: %v", got) + if got != want || gotCanSend != wantCanSend { + p.t.Fatalf("ERROR: pacer send delay = %v (can send: %v); want %v, %v", got, gotCanSend, want, wantCanSend) + } +} + +func (p *pacerTest) sendDelay() time.Duration { + canSend, next := p.p.canSend(p.now) + if canSend { + return 0 + } + return next.Sub(p.now) +} From 6c96ca5daff89298060438c3b5d24e1bd0900a52 Mon Sep 17 00:00:00 2001 From: Gopher Robot Date: Tue, 13 Jun 2023 11:08:24 +0000 Subject: [PATCH 28/28] go.mod: update golang.org/x dependencies Update golang.org/x dependencies to their latest tagged versions. Once this CL is submitted, and post-submit testing succeeds on all first-class ports across all supported Go versions, this repository will be tagged with its next minor version. Change-Id: Ifc35b03aeb994b74293ca0b2a4c79940cff8a66c Reviewed-on: https://go-review.googlesource.com/c/net/+/502795 Run-TryBot: Gopher Robot Auto-Submit: Gopher Robot Reviewed-by: Dmitri Shuralyov Reviewed-by: Carlos Amedee TryBot-Result: Gopher Robot Reviewed-by: Dmitri Shuralyov --- go.mod | 8 ++++---- go.sum | 14 ++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/go.mod b/go.mod index 43451b0da..2360e1940 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module golang.org/x/net go 1.17 require ( - golang.org/x/crypto v0.9.0 - golang.org/x/sys v0.8.1-0.20230606214330-304f187cdba9 - golang.org/x/term v0.8.0 - golang.org/x/text v0.9.0 + golang.org/x/crypto v0.10.0 + golang.org/x/sys v0.9.0 + golang.org/x/term v0.9.0 + golang.org/x/text v0.10.0 ) diff --git a/go.sum b/go.sum index 666172d8b..730429480 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -20,19 +20,21 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.1-0.20230606214330-304f187cdba9 h1:M57WSsHXpJhTgA9/YbQISwowFNh+0HsWrZQt6NZf1go= -golang.org/x/sys v0.8.1-0.20230606214330-304f187cdba9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28= +golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= +golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=