From 824fd8a2f2eb8a08fe6cef7a693fee6be3819e01 Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Mon, 21 Aug 2023 11:31:22 +0800 Subject: [PATCH] http3: automatically add content-length for small responses (#3989) * response writer: add content-length automatically when response is small enough and doesn't call Flush * fix comment * add integration test * Update http3/response_writer.go * Update integrationtests/self/http_test.go --------- Co-authored-by: Marten Seemann --- http3/response_writer.go | 97 ++++++++++++++++++++---------- http3/server.go | 8 ++- http3/server_test.go | 41 +++++++++++++ integrationtests/self/http_test.go | 12 ++++ 4 files changed, 126 insertions(+), 32 deletions(-) diff --git a/http3/response_writer.go b/http3/response_writer.go index 3d9271858d6..90a30497ae6 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -15,19 +15,61 @@ import ( "github.com/quic-go/qpack" ) +// The maximum length of an encoded HTTP/3 frame header is 16: +// The frame has a type and length field, both QUIC varints (maximum 8 bytes in length) +const frameHeaderLen = 16 + +// headerWriter wraps the stream, so that the first Write call flushes the header to the stream +type headerWriter struct { + str quic.Stream + header http.Header + status int // status code passed to WriteHeader + written bool + + logger utils.Logger +} + +// writeHeader encodes and flush header to the stream +func (hw *headerWriter) writeHeader() error { + var headers bytes.Buffer + enc := qpack.NewEncoder(&headers) + enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}) + + for k, v := range hw.header { + for index := range v { + enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) + } + } + + buf := make([]byte, 0, frameHeaderLen+headers.Len()) + buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf) + hw.logger.Infof("Responding with %d", hw.status) + buf = append(buf, headers.Bytes()...) + + _, err := hw.str.Write(buf) + return err +} + +// first Write will trigger flushing header +func (hw *headerWriter) Write(p []byte) (int, error) { + if !hw.written { + if err := hw.writeHeader(); err != nil { + return 0, err + } + hw.written = true + } + return hw.str.Write(p) +} + type responseWriter struct { + *headerWriter conn quic.Connection - str quic.Stream bufferedStr *bufio.Writer buf []byte - header http.Header - status int // status code passed to WriteHeader headerWritten bool contentLen int64 // if handler set valid Content-Length header numWritten int64 // bytes written - - logger utils.Logger } var ( @@ -37,13 +79,16 @@ var ( ) func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { + hw := &headerWriter{ + str: str, + header: http.Header{}, + logger: logger, + } return &responseWriter{ - header: http.Header{}, - buf: make([]byte, 16), - conn: conn, - str: str, - bufferedStr: bufio.NewWriter(str), - logger: logger, + headerWriter: hw, + buf: make([]byte, frameHeaderLen), + conn: conn, + bufferedStr: bufio.NewWriter(hw), } } @@ -83,27 +128,8 @@ func (w *responseWriter) WriteHeader(status int) { } w.status = status - var headers bytes.Buffer - enc := qpack.NewEncoder(&headers) - enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) - - for k, v := range w.header { - for index := range v { - enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) - } - } - - w.buf = w.buf[:0] - w.buf = (&headersFrame{Length: uint64(headers.Len())}).Append(w.buf) - w.logger.Infof("Responding with %d", status) - if _, err := w.bufferedStr.Write(w.buf); err != nil { - w.logger.Errorf("could not write headers frame: %s", err.Error()) - } - if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil { - w.logger.Errorf("could not write header frame payload: %s", err.Error()) - } if !w.headerWritten { - w.Flush() + w.writeHeader() } } @@ -146,6 +172,15 @@ func (w *responseWriter) Write(p []byte) (int, error) { } func (w *responseWriter) FlushError() error { + if !w.headerWritten { + w.WriteHeader(http.StatusOK) + } + if !w.written { + if err := w.writeHeader(); err != nil { + return err + } + w.written = true + } return w.bufferedStr.Flush() } diff --git a/http3/server.go b/http3/server.go index 8d39f3749da..4587a1fca77 100644 --- a/http3/server.go +++ b/http3/server.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "runtime" + "strconv" "strings" "sync" "time" @@ -627,7 +628,12 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q // only write response when there is no panic if !panicked { - r.WriteHeader(http.StatusOK) + // response not written to the client yet, set Content-Length + if !r.written { + if _, haveCL := r.header["Content-Length"]; !haveCL { + r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10)) + } + } r.Flush() } // If the EOF was read by the handler, CancelRead() is a no-op. diff --git a/http3/server_test.go b/http3/server_test.go index f180ef0d5c2..67713b922f7 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -180,6 +180,47 @@ var _ = Describe("Server", func() { Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) }) + It("sets Content-Length when the handler doesn't flush to the client", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foobar")) + }) + + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(exampleGetRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"})) + // status, content-length, date, content-type + Expect(hfs).To(HaveLen(4)) + }) + + It("not sets Content-Length when the handler flushes to the client", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foobar")) + // force flush + w.(http.Flusher).Flush() + }) + + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(exampleGetRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + // status, date, content-type + Expect(hfs).To(HaveLen(3)) + }) + It("handles a aborting handler", func() { s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { panic(http.ErrAbortHandler) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index b647d4d9c00..a1067070275 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -128,6 +128,18 @@ var _ = Describe("HTTP tests", func() { Expect(string(body)).To(Equal("Hello, World!\n")) }) + It("sets content-length for small response", func() { + mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + w.Write([]byte("foobar")) + }) + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/small", port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len("foobar")))) + }) + It("requests to different servers with the same udpconn", func() { resp, err := client.Get(fmt.Sprintf("https://localhost:%d/remoteAddr", port)) Expect(err).ToNot(HaveOccurred())