Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http3: simplify buffering of small responses #4432

Merged
merged 1 commit into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions http3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func encodeResponse(status int) []byte {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger)
rw.WriteHeader(status)
rw.Flush()
return buf.Bytes()
Expand Down Expand Up @@ -738,7 +738,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(rw)
gz.Write([]byte("gzipped response"))
Expand All @@ -764,7 +764,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, false, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
Expand Down
155 changes: 92 additions & 63 deletions http3/response_writer.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package http3

import (
"bufio"
"bytes"
"fmt"
"net/http"
Expand All @@ -21,10 +20,9 @@ const frameHeaderLen = 16

// headerWriter wraps the stream, so that the first Write call flushes the header to the stream
type headerWriter struct {
str quic.Stream
header http.Header
status int // status code passed to WriteHeader
written bool
str quic.Stream
header http.Header
status int // status code passed to WriteHeader

logger utils.Logger
}
Expand All @@ -33,11 +31,15 @@ type headerWriter struct {
func (hw *headerWriter) writeHeader() error {
var headers bytes.Buffer
enc := qpack.NewEncoder(&headers)
enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)})
if err := enc.WriteField(qpack.HeaderField{Name: ":status", Value: strconv.Itoa(hw.status)}); err != nil {
return err
}

for k, v := range hw.header {
for index := range v {
enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
if err := enc.WriteField(qpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}); err != nil {
return err
}
}
}

Expand All @@ -50,27 +52,22 @@ func (hw *headerWriter) writeHeader() error {
return err
}

// first Write will trigger flushing header
func (hw *headerWriter) Write(p []byte) (int, error) {
if !hw.written {
if err := hw.writeHeader(); err != nil {
return 0, err
}
hw.written = true
}
return hw.str.Write(p)
}
const maxSmallResponseSize = 4096

type responseWriter struct {
*headerWriter
conn Connection
bufferedStr *bufio.Writer
buf []byte

contentLen int64 // if handler set valid Content-Length header
numWritten int64 // bytes written
headerWritten bool
isHead bool
conn Connection
buf []byte

// for responses smaller than maxSmallResponseSize, we buffer calls to Write,
// and automatically add the Content-Length header
smallResponseBuf []byte

contentLen int64 // if handler set valid Content-Length header
numWritten int64 // bytes written
headerComplete bool // set once WriteHeader is called with a status code >= 200
headerWritten bool // set once the response header has been serialized to the stream
isHead bool
}

var (
Expand All @@ -79,7 +76,7 @@ var (
_ Hijacker = &responseWriter{}
)

func newResponseWriter(str quic.Stream, conn Connection, logger utils.Logger) *responseWriter {
func newResponseWriter(str quic.Stream, conn Connection, isHead bool, logger utils.Logger) *responseWriter {
hw := &headerWriter{
str: str,
header: http.Header{},
Expand All @@ -89,7 +86,7 @@ func newResponseWriter(str quic.Stream, conn Connection, logger utils.Logger) *r
headerWriter: hw,
buf: make([]byte, frameHeaderLen),
conn: conn,
bufferedStr: bufio.NewWriter(hw),
isHead: isHead,
}
}

Expand All @@ -98,45 +95,46 @@ func (w *responseWriter) Header() http.Header {
}

func (w *responseWriter) WriteHeader(status int) {
if w.headerWritten {
if w.headerComplete {
return
}

// http status must be 3 digits
if status < 100 || status > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", status))
}

if status >= 200 {
w.headerWritten = true
// Add Date header.
// This is what the standard library does.
// Can be disabled by setting the Date header to nil.
if _, ok := w.header["Date"]; !ok {
w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
}
// Content-Length checking
// use ParseUint instead of ParseInt, as negative values are invalid
if clen := w.header.Get("Content-Length"); clen != "" {
if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
w.contentLen = int64(cl)
} else {
// emit a warning for malformed Content-Length and remove it
w.logger.Errorf("Malformed Content-Length %s", clen)
w.header.Del("Content-Length")
}
}
}
w.status = status

if !w.headerWritten {
// immediately write 1xx headers
if status < 200 {
w.writeHeader()
return
}

// We're done with headers once we write a status >= 200.
w.headerComplete = true
// Add Date header.
// This is what the standard library does.
// Can be disabled by setting the Date header to nil.
if _, ok := w.header["Date"]; !ok {
w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
}
// Content-Length checking
// use ParseUint instead of ParseInt, as negative values are invalid
if clen := w.header.Get("Content-Length"); clen != "" {
if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
w.contentLen = int64(cl)
} else {
// emit a warning for malformed Content-Length and remove it
w.logger.Errorf("Malformed Content-Length %s", clen)
w.header.Del("Content-Length")
}
}
}

func (w *responseWriter) Write(p []byte) (int, error) {
bodyAllowed := bodyAllowedForStatus(w.status)
if !w.headerWritten {
if !w.headerComplete {
// If body is not allowed, we don't need to (and we can't) sniff the content type.
if bodyAllowed {
// If no content type, apply sniffing algorithm to body.
Expand Down Expand Up @@ -167,27 +165,58 @@ func (w *responseWriter) Write(p []byte) (int, error) {
return len(p), nil
}

df := &dataFrame{Length: uint64(len(p))}
if !w.headerWritten {
// Buffer small responses.
// This allows us to automatically set the Content-Length field.
if len(w.smallResponseBuf)+len(p) < maxSmallResponseSize {
w.smallResponseBuf = append(w.smallResponseBuf, p...)
return len(p), nil
}
}
return w.doWrite(p)
}

func (w *responseWriter) doWrite(p []byte) (int, error) {
if !w.headerWritten {
if err := w.writeHeader(); err != nil {
return 0, maybeReplaceError(err)
}
w.headerWritten = true
}

l := uint64(len(w.smallResponseBuf) + len(p))
if l == 0 {
return 0, nil
}
df := &dataFrame{Length: l}
w.buf = w.buf[:0]
w.buf = df.Append(w.buf)
if _, err := w.bufferedStr.Write(w.buf); err != nil {
if _, err := w.str.Write(w.buf); err != nil {
return 0, maybeReplaceError(err)
}
n, err := w.bufferedStr.Write(p)
return n, maybeReplaceError(err)
if len(w.smallResponseBuf) > 0 {
if _, err := w.str.Write(w.smallResponseBuf); err != nil {
return 0, maybeReplaceError(err)
}
w.smallResponseBuf = nil
}
var n int
if len(p) > 0 {
var err error
n, err = w.str.Write(p)
if err != nil {
return n, maybeReplaceError(err)
}
}
return n, nil
}

func (w *responseWriter) FlushError() error {
if !w.headerWritten {
if !w.headerComplete {
w.WriteHeader(http.StatusOK)
}
if !w.written {
if err := w.writeHeader(); err != nil {
return maybeReplaceError(err)
}
w.written = true
}
return w.bufferedStr.Flush()
_, err := w.doWrite(nil)
return err
}

func (w *responseWriter) Flush() {
Expand Down
2 changes: 1 addition & 1 deletion http3/response_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var _ = Describe("Response Writer", func() {
str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
str.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes()
str.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).AnyTimes()
rw = newResponseWriter(str, nil, utils.DefaultLogger)
rw = newResponseWriter(str, nil, false, utils.DefaultLogger)
})

decodeHeader := func(str io.Reader) map[string][]string {
Expand Down
7 changes: 2 additions & 5 deletions http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,10 +554,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
}
}
req = req.WithContext(ctx)
r := newResponseWriter(str, conn, s.logger)
if req.Method == http.MethodHead {
r.isHead = true
}
r := newResponseWriter(str, conn, req.Method == http.MethodHead, s.logger)
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux
Expand Down Expand Up @@ -588,7 +585,7 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack
// only write response when there is no panic
if !panicked {
// response not written to the client yet, set Content-Length
if !r.written {
if !r.headerWritten {
if _, haveCL := r.header["Content-Length"]; !haveCL {
r.header.Set("Content-Length", strconv.FormatInt(r.numWritten, 10))
}
Expand Down
8 changes: 4 additions & 4 deletions http3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,12 @@ var _ = Describe("Server", func() {
Expect(hfs).To(HaveLen(3))
})

It("response to HEAD request should not have body", func() {
It("ignores calls to Write for responses to HEAD requests", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("foobar"))
})

headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil)
headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil)
Expect(err).ToNot(HaveOccurred())
responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(headRequest))
Expand All @@ -245,15 +245,15 @@ var _ = Describe("Server", func() {
s.handleRequest(conn, str, qpackDecoder)
hfs := decodeHeader(responseBuf)
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
Expect(responseBuf.Bytes()).To(HaveLen(0))
Expect(responseBuf.Bytes()).To(BeEmpty())
})

It("response to HEAD request should also do content sniffing", func() {
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("<html></html>"))
})

headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil)
headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil)
Expect(err).ToNot(HaveOccurred())
responseBuf := &bytes.Buffer{}
setRequest(encodeRequest(headRequest))
Expand Down
5 changes: 3 additions & 2 deletions integrationtests/self/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,14 @@ var _ = Describe("HTTP tests", func() {
It("sets content-length for small response", func() {
mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
w.Write([]byte("foobar"))
w.Write([]byte("foo"))
w.Write([]byte("bar"))
})

resp, err := client.Get(fmt.Sprintf("https://localhost:%d/small", port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len("foobar"))))
Expect(resp.Header.Get("Content-Length")).To(Equal("6"))
})

It("detects stream errors when server panics when writing response", func() {
Expand Down
Loading