Skip to content

Commit

Permalink
http3: automatically add content-length for small responses (#3989)
Browse files Browse the repository at this point in the history
* response writer: add content-length automatically when response is small enough and doesn't call Flush

* fix comment

* add integration test

* Update http3/response_writer.go

* Update integrationtests/self/http_test.go

---------

Co-authored-by: Marten Seemann <[email protected]>
  • Loading branch information
WeidiDeng and marten-seemann committed Aug 21, 2023
1 parent ced65c0 commit 824fd8a
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 32 deletions.
97 changes: 66 additions & 31 deletions http3/response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,61 @@ import (
"github.com/quic-go/qpack"
)

// The maximum length of an encoded HTTP/3 frame header is 16:
// The frame has a type and length field, both QUIC varints (maximum 8 bytes in length)
const frameHeaderLen = 16

// headerWriter wraps the stream, so that the first Write call flushes the header to the stream
type headerWriter struct {
str quic.Stream
header http.Header
status int // status code passed to WriteHeader
written bool

logger utils.Logger
}

// writeHeader encodes and flush header to the stream
func (hw *headerWriter) writeHeader() error {
var headers bytes.Buffer
enc := qpack.NewEncoder(&headers)
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)})

for k, v := range hw.header {
for index := range v {
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
}
}

buf := make([]byte, 0, frameHeaderLen+headers.Len())
buf = (&headersFrame{Length: uint64(headers.Len())}).Append(buf)
hw.logger.Infof("Responding with %d", hw.status)
buf = append(buf, headers.Bytes()...)

_, err := hw.str.Write(buf)
return err
}

// first Write will trigger flushing header
func (hw *headerWriter) Write(p []byte) (int, error) {
if !hw.written {
if err := hw.writeHeader(); err != nil {
return 0, err
}
hw.written = true
}
return hw.str.Write(p)
}

type responseWriter struct {
*headerWriter
conn quic.Connection
str quic.Stream
bufferedStr *bufio.Writer
buf []byte

header http.Header
status int // status code passed to WriteHeader
headerWritten bool
contentLen int64 // if handler set valid Content-Length header
numWritten int64 // bytes written

logger utils.Logger
}

var (
Expand All @@ -37,13 +79,16 @@ var (
)

func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter {
hw := &headerWriter{
str: str,
header: http.Header{},
logger: logger,
}
return &responseWriter{
header: http.Header{},
buf: make([]byte, 16),
conn: conn,
str: str,
bufferedStr: bufio.NewWriter(str),
logger: logger,
headerWriter: hw,
buf: make([]byte, frameHeaderLen),
conn: conn,
bufferedStr: bufio.NewWriter(hw),
}
}

Expand Down Expand Up @@ -83,27 +128,8 @@ func (w *responseWriter) WriteHeader(status int) {
}
w.status = status

var headers bytes.Buffer
enc := qpack.NewEncoder(&headers)
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})

for k, v := range w.header {
for index := range v {
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
}
}

w.buf = w.buf[:0]
w.buf = (&headersFrame{Length: uint64(headers.Len())}).Append(w.buf)
w.logger.Infof("Responding with %d", status)
if _, err := w.bufferedStr.Write(w.buf); err != nil {
w.logger.Errorf("could not write headers frame: %s", err.Error())
}
if _, err := w.bufferedStr.Write(headers.Bytes()); err != nil {
w.logger.Errorf("could not write header frame payload: %s", err.Error())
}
if !w.headerWritten {
w.Flush()
w.writeHeader()
}
}

Expand Down Expand Up @@ -146,6 +172,15 @@ func (w *responseWriter) Write(p []byte) (int, error) {
}

func (w *responseWriter) FlushError() error {
if !w.headerWritten {
w.WriteHeader(http.StatusOK)
}
if !w.written {
if err := w.writeHeader(); err != nil {
return err
}
w.written = true
}
return w.bufferedStr.Flush()
}

Expand Down
8 changes: 7 additions & 1 deletion http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net"
"net/http"
"runtime"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -627,7 +628,12 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q

// only write response when there is no panic
if !panicked {
r.WriteHeader(http.StatusOK)
// response not written to the client yet, set Content-Length
if !r.written {
if _, haveCL := r.header["Content-Length"]; !haveCL {
r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
}
}
r.Flush()
}
// If the EOF was read by the handler, CancelRead() is a no-op.
Expand Down
41 changes: 41 additions & 0 deletions http3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,47 @@ var _ = Describe("Server", func() {
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
})

It("sets Content-Length when the handler doesn't flush to the client", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("foobar"))
})

responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"}))
// status, content-length, date, content-type
Expect(hfs).To(HaveLen(4))
})

It("not sets Content-Length when the handler flushes to the client", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("foobar"))
// force flush
w.(http.Flusher).Flush()
})

responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(exampleGetRequest))
str.EXPECT().Context().Return(reqContext)
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes()
str.EXPECT().CancelRead(gomock.Any())

serr := s.handleRequest(conn, str, qpackDecoder, nil)
Expect(serr.err).ToNot(HaveOccurred())
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
// status, date, content-type
Expect(hfs).To(HaveLen(3))
})

It("handles a aborting handler", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic(http.ErrAbortHandler)
Expand Down
12 changes: 12 additions & 0 deletions integrationtests/self/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,18 @@ var _ = Describe("HTTP tests", func() {
Expect(string(body)).To(Equal("Hello, World!\n"))
})

It("sets content-length for small response", func() {
mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
w.Write([]byte("foobar"))
})

resp, err := client.Get(fmt.Sprintf("https://localhost:%d/small", port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len("foobar"))))
})

It("requests to different servers with the same udpconn", func() {
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/remoteAddr", port))
Expect(err).ToNot(HaveOccurred())
Expand Down

0 comments on commit 824fd8a

Please sign in to comment.