Skip to content

Commit

Permalink
http3: simplify buffering of small responses (#4432)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Apr 13, 2024
1 parent 857c31d commit 90627f6
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 78 deletions.
6 changes: 3 additions & 3 deletions http3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func encodeResponse(status int) []byte {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger)
rw.WriteHeader(status)
rw.Flush()
return buf.Bytes()
Expand Down Expand Up @@ -738,7 +738,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(rw)
gz.Write([]byte("gzipped response"))
Expand All @@ -764,7 +764,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
Expand Down
155 changes: 92 additions & 63 deletions http3/response_writer.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package http3

import (
"bufio"
"bytes"
"fmt"
"net/http"
Expand All @@ -21,10 +20,9 @@ const frameHeaderLen = 16

// headerWriter wraps the stream, so that the first Write call flushes the header to the stream
type headerWriter struct {
str quic.Stream
header http.Header
status int // status code passed to WriteHeader
written bool
str quic.Stream
header http.Header
status int // status code passed to WriteHeader

logger utils.Logger
}
Expand All @@ -33,11 +31,15 @@ type headerWriter struct {
func (hw *headerWriter) writeHeader() error {
var headers bytes.Buffer
enc := qpack.NewEncoder(&headers)
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)})
if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}); err != nil {
return err
}

for k, v := range hw.header {
for index := range v {
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil {
return err
}
}
}

Expand All @@ -50,27 +52,22 @@ func (hw *headerWriter) writeHeader() error {
return err
}

// first Write will trigger flushing header
func (hw *headerWriter) Write(p []byte) (int, error) {
if !hw.written {
if err := hw.writeHeader(); err != nil {
return 0, err
}
hw.written = true
}
return hw.str.Write(p)
}
const maxSmallResponseSize = 4096

type responseWriter struct {
*headerWriter
conn Connection
bufferedStr *bufio.Writer
buf []byte

contentLen int64 // if handler set valid Content-Length header
numWritten int64 // bytes written
headerWritten bool
isHead bool
conn Connection
buf []byte

// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
// and automatically add the Content-Length header
smallResponseBuf []byte

contentLen int64 // if handler set valid Content-Length header
numWritten int64 // bytes written
headerComplete bool // set once WriteHeader is called with a status code >= 200
headerWritten bool // set once the response header has been serialized to the stream
isHead bool
}

var (
Expand All @@ -79,7 +76,7 @@ var (
_ Hijacker = &responseWriter{}
)

func newResponseWriter(str quic.Stream, conn Connection, logger utils.Logger) *responseWriter {
func newResponseWriter(str quic.Stream, conn Connection, isHead bool, logger utils.Logger) *responseWriter {
hw := &headerWriter{
str: str,
header: http.Header{},
Expand All @@ -89,7 +86,7 @@ func newResponseWriter(str quic.Stream, conn Connection, logger utils.Logger) *r
headerWriter: hw,
buf: make([]byte, frameHeaderLen),
conn: conn,
bufferedStr: bufio.NewWriter(hw),
isHead: isHead,
}
}

Expand All @@ -98,45 +95,46 @@ func (w *responseWriter) Header() http.Header {
}

func (w *responseWriter) WriteHeader(status int) {
if w.headerWritten {
if w.headerComplete {
return
}

// http status must be 3 digits
if status < 100 || status > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", status))
}

if status >= 200 {
w.headerWritten = true
// Add Date header.
// This is what the standard library does.
// Can be disabled by setting the Date header to nil.
if _, ok := w.header["Date"]; !ok {
w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
}
// Content-Length checking
// use ParseUint instead of ParseInt, as negative values are invalid
if clen := w.header.Get("Content-Length"); clen != "" {
if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
w.contentLen = int64(cl)
} else {
// emit a warning for malformed Content-Length and remove it
w.logger.Errorf("Malformed Content-Length %s", clen)
w.header.Del("Content-Length")
}
}
}
w.status = status

if !w.headerWritten {
// immediately write 1xx headers
if status < 200 {
w.writeHeader()
return
}

// We're done with headers once we write a status >= 200.
w.headerComplete = true
// Add Date header.
// This is what the standard library does.
// Can be disabled by setting the Date header to nil.
if _, ok := w.header["Date"]; !ok {
w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
}
// Content-Length checking
// use ParseUint instead of ParseInt, as negative values are invalid
if clen := w.header.Get("Content-Length"); clen != "" {
if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
w.contentLen = int64(cl)
} else {
// emit a warning for malformed Content-Length and remove it
w.logger.Errorf("Malformed Content-Length %s", clen)
w.header.Del("Content-Length")
}
}
}

func (w *responseWriter) Write(p []byte) (int, error) {
bodyAllowed := bodyAllowedForStatus(w.status)
if !w.headerWritten {
if !w.headerComplete {
// If body is not allowed, we don't need to (and we can't) sniff the content type.
if bodyAllowed {
// If no content type, apply sniffing algorithm to body.
Expand Down Expand Up @@ -167,27 +165,58 @@ func (w *responseWriter) Write(p []byte) (int, error) {
return len(p), nil
}

df := &dataFrame{Length: uint64(len(p))}
if !w.headerWritten {
// Buffer small responses.
// This allows us to automatically set the Content-Length field.
if len(w.smallResponseBuf)+len(p) < maxSmallResponseSize {
w.smallResponseBuf = append(w.smallResponseBuf, p...)
return len(p), nil
}
}
return w.doWrite(p)
}

func (w *responseWriter) doWrite(p []byte) (int, error) {
if !w.headerWritten {
if err := w.writeHeader(); err != nil {
return 0, maybeReplaceError(err)
}
w.headerWritten = true
}

l := uint64(len(w.smallResponseBuf) + len(p))
if l == 0 {
return 0, nil
}
df := &dataFrame{Length: l}
w.buf = w.buf[:0]
w.buf = df.Append(w.buf)
if _, err := w.bufferedStr.Write(w.buf); err != nil {
if _, err := w.str.Write(w.buf); err != nil {
return 0, maybeReplaceError(err)
}
n, err := w.bufferedStr.Write(p)
return n, maybeReplaceError(err)
if len(w.smallResponseBuf) > 0 {
if _, err := w.str.Write(w.smallResponseBuf); err != nil {
return 0, maybeReplaceError(err)
}
w.smallResponseBuf = nil
}
var n int
if len(p) > 0 {
var err error
n, err = w.str.Write(p)
if err != nil {
return n, maybeReplaceError(err)
}
}
return n, nil
}

func (w *responseWriter) FlushError() error {
if !w.headerWritten {
if !w.headerComplete {
w.WriteHeader(http.StatusOK)
}
if !w.written {
if err := w.writeHeader(); err != nil {
return maybeReplaceError(err)
}
w.written = true
}
return w.bufferedStr.Flush()
_, err := w.doWrite(nil)
return err
}

func (w *responseWriter) Flush() {
Expand Down
2 changes: 1 addition & 1 deletion http3/response_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var _ = Describe("Response Writer", func() {
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
str.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes()
str.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes()
rw = newResponseWriter(str, nil, utils.DefaultLogger)
rw = newResponseWriter(str, nil, false, utils.DefaultLogger)
})

decodeHeader := func(str io.Reader) map[string][]string {
Expand Down
7 changes: 2 additions & 5 deletions http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,10 +554,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
}
}
req = req.WithContext(ctx)
r := newResponseWriter(str, conn, s.logger)
if req.Method == http.MethodHead {
r.isHead = true
}
r := newResponseWriter(str, conn, req.Method == http.MethodHead, s.logger)
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux
Expand Down Expand Up @@ -588,7 +585,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
// only write response when there is no panic
if !panicked {
// response not written to the client yet, set Content-Length
if !r.written {
if !r.headerWritten {
if _, haveCL := r.header["Content-Length"]; !haveCL {
r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
}
Expand Down
8 changes: 4 additions & 4 deletions http3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,12 @@ var _ = Describe("Server", func() {
Expect(hfs).To(HaveLen(3))
})

It("response to HEAD request should not have body", func() {
It("ignores calls to Write for responses to HEAD requests", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("foobar"))
})

headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil)
headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil)
Expect(err).ToNot(HaveOccurred())
responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(headRequest))
Expand All @@ -245,15 +245,15 @@ var _ = Describe("Server", func() {
s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(responseBuf.Bytes()).To(HaveLen(0))
Expect(responseBuf.Bytes()).To(BeEmpty())
})

It("response to HEAD request should also do content sniffing", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("<html></html>"))
})

headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil)
headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil)
Expect(err).ToNot(HaveOccurred())
responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(headRequest))
Expand Down
5 changes: 3 additions & 2 deletions integrationtests/self/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,14 @@ var _ = Describe("HTTP tests", func() {
It("sets content-length for small response", func() {
mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
w.Write([]byte("foobar"))
w.Write([]byte("foo"))
w.Write([]byte("bar"))
})

resp, err := client.Get(fmt.Sprintf("https://localhost:%d/small", port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len("foobar"))))
Expect(resp.Header.Get("Content-Length")).To(Equal("6"))
})

It("detects stream errors when server panics when writing response", func() {
Expand Down

0 comments on commit 90627f6

Please sign in to comment.