From e2fbf3cdcd4b803d237ef9752f1a555b022f47a4 Mon Sep 17 00:00:00 2001 From: George Date: Sun, 19 May 2024 03:15:32 +0100 Subject: [PATCH] http3: fix memory leak in stream state tracking (#4523) * fix(http3): handle streamStateSendAndReceiveClosed in onStreamStateChange Signed-off-by: George MacRorie * refactor(http3): adjust stateTrackingStream to operate over streamClearer and errorSetter * test(http3): remove duplicate test case * chore(http3): rename test spies to be mocks --------- Signed-off-by: George MacRorie --- http3/conn.go | 24 +-- http3/datagram.go | 6 +- http3/datagram_test.go | 2 +- http3/state_tracking_stream.go | 80 +++++--- http3/state_tracking_stream_test.go | 273 ++++++++++++++++++++-------- 5 files changed, 258 insertions(+), 127 deletions(-) diff --git a/http3/conn.go b/http3/conn.go index e411e233726..7ea4b292918 100644 --- a/http3/conn.go +++ b/http3/conn.go @@ -70,27 +70,11 @@ func newConnection( return c } -func (c *connection) onStreamStateChange(id quic.StreamID, state streamState, e error) { +func (c *connection) clearStream(id quic.StreamID) { c.streamMx.Lock() defer c.streamMx.Unlock() - d, ok := c.streams[id] - if !ok { // should never happen - return - } - var isDone bool - //nolint:exhaustive // These are all the cases we care about. - switch state { - case streamStateReceiveClosed: - isDone = d.SetReceiveError(e) - case streamStateSendClosed: - isDone = d.SetSendError(e) - default: - return - } - if isDone { - delete(c.streams, id) - } + delete(c.streams, id) } func (c *connection) openRequestStream( @@ -108,7 +92,7 @@ func (c *connection) openRequestStream( c.streamMx.Lock() c.streams[str.StreamID()] = datagrams c.streamMx.Unlock() - qstr := newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) }) + qstr := newStateTrackingStream(str, c, datagrams) hstr := newStream(qstr, c, datagrams) return newRequestStream(hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes), nil } @@ -124,7 +108,7 @@ func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagramme c.streamMx.Lock() c.streams[strID] = datagrams c.streamMx.Unlock() - str = newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(strID, s, e) }) + str = newStateTrackingStream(str, c, datagrams) } return str, datagrams, nil } diff --git a/http3/datagram.go b/http3/datagram.go index 491e97ed713..6d570e6b006 100644 --- a/http3/datagram.go +++ b/http3/datagram.go @@ -27,21 +27,19 @@ func newDatagrammer(sendDatagram func([]byte) error) *datagrammer { } } -func (d *datagrammer) SetReceiveError(err error) (isDone bool) { +func (d *datagrammer) SetReceiveError(err error) { d.mx.Lock() defer d.mx.Unlock() d.receiveErr = err d.signalHasData() - return d.sendErr != nil } -func (d *datagrammer) SetSendError(err error) (isDone bool) { +func (d *datagrammer) SetSendError(err error) { d.mx.Lock() defer d.mx.Unlock() d.sendErr = err - return d.receiveErr != nil } func (d *datagrammer) Send(b []byte) error { diff --git a/http3/datagram_test.go b/http3/datagram_test.go index 85a6a823b36..647d3dcd72b 100644 --- a/http3/datagram_test.go +++ b/http3/datagram_test.go @@ -54,7 +54,7 @@ var _ = Describe("Datagrams", func() { dg := newDatagrammer(nil) dg.enqueue([]byte("foo")) testErr := errors.New("test error") - Expect(dg.SetReceiveError(testErr)).To(BeFalse()) + dg.SetReceiveError(testErr) dg.enqueue([]byte("bar")) data, err := dg.Receive(context.Background()) Expect(err).ToNot(HaveOccurred()) diff --git a/http3/state_tracking_stream.go b/http3/state_tracking_stream.go index a5cd834ce14..9cf17f5e68e 100644 --- a/http3/state_tracking_stream.go +++ b/http3/state_tracking_stream.go @@ -9,59 +9,79 @@ import ( "github.com/quic-go/quic-go" ) -type streamState uint8 - -const ( - streamStateOpen streamState = iota - streamStateReceiveClosed - streamStateSendClosed - streamStateSendAndReceiveClosed -) +var _ quic.Stream = &stateTrackingStream{} +// stateTrackingStream is an implementation of quic.Stream that delegates +// to an underlying stream +// it takes care of proxying send and receive errors onto an implementation of +// the errorSetter interface (intended to be occupied by a datagrammer) +// it is also responsible for clearing the stream based on its ID from its +// parent connection, this is done through the streamClearer interface when +// both the send and receive sides are closed type stateTrackingStream struct { quic.Stream - mx sync.Mutex - state streamState + mx sync.Mutex + sendErr error + recvErr error + + clearer streamClearer + setter errorSetter +} + +type streamClearer interface { + clearStream(quic.StreamID) +} - onStateChange func(streamState, error) +type errorSetter interface { + SetSendError(error) + SetReceiveError(error) } -func newStateTrackingStream(s quic.Stream, onStateChange func(streamState, error)) *stateTrackingStream { +func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream { + t := &stateTrackingStream{ + Stream: s, + clearer: clearer, + setter: setter, + } + context.AfterFunc(s.Context(), func() { - onStateChange(streamStateSendClosed, context.Cause(s.Context())) + t.closeSend(context.Cause(s.Context())) }) - return &stateTrackingStream{ - Stream: s, - state: streamStateOpen, - onStateChange: onStateChange, - } -} -var _ quic.Stream = &stateTrackingStream{} + return t +} func (s *stateTrackingStream) closeSend(e error) { s.mx.Lock() defer s.mx.Unlock() - if s.state == streamStateReceiveClosed || s.state == streamStateSendAndReceiveClosed { - s.state = streamStateSendAndReceiveClosed - } else { - s.state = streamStateSendClosed + // clear the stream the first time both the send + // and receive are finished + if s.sendErr == nil { + if s.recvErr != nil { + s.clearer.clearStream(s.StreamID()) + } + + s.setter.SetSendError(e) + s.sendErr = e } - s.onStateChange(s.state, e) } func (s *stateTrackingStream) closeReceive(e error) { s.mx.Lock() defer s.mx.Unlock() - if s.state == streamStateSendClosed || s.state == streamStateSendAndReceiveClosed { - s.state = streamStateSendAndReceiveClosed - } else { - s.state = streamStateReceiveClosed + // clear the stream the first time both the send + // and receive are finished + if s.recvErr == nil { + if s.sendErr != nil { + s.clearer.clearStream(s.StreamID()) + } + + s.setter.SetReceiveError(e) + s.recvErr = e } - s.onStateChange(s.state, e) } func (s *stateTrackingStream) Close() error { diff --git a/http3/state_tracking_stream_test.go b/http3/state_tracking_stream_test.go index e900cf1ae0a..cc42ff9e21a 100644 --- a/http3/state_tracking_stream_test.go +++ b/http3/state_tracking_stream_test.go @@ -15,145 +15,180 @@ import ( "go.uber.org/mock/gomock" ) -type stateTransition struct { - state streamState - err error -} +var someStreamID = quic.StreamID(12) var _ = Describe("State Tracking Stream", func() { It("recognizes when the receive side is closed", func() { qstr := mockquic.NewMockStream(mockCtrl) - qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) qstr.EXPECT().Context().Return(context.Background()).AnyTimes() - var states []stateTransition - str := newStateTrackingStream(qstr, func(state streamState, err error) { - states = append(states, stateTransition{state, err}) - }) + + var ( + clearer mockStreamClearer + setter mockErrorSetter + str = newStateTrackingStream(qstr, &clearer, &setter) + ) buf := bytes.NewBuffer([]byte("foobar")) qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() for i := 0; i < 3; i++ { _, err := str.Read([]byte{0}) Expect(err).ToNot(HaveOccurred()) - Expect(states).To(BeEmpty()) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(BeEmpty()) } _, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) - Expect(states).To(HaveLen(1)) - Expect(states[0].state).To(Equal(streamStateReceiveClosed)) - Expect(states[0].err).To(Equal(io.EOF)) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(HaveLen(1)) + Expect(setter.recvErrs[0]).To(Equal(io.EOF)) + Expect(setter.sendErrs).To(BeEmpty()) }) It("recognizes local read cancellations", func() { qstr := mockquic.NewMockStream(mockCtrl) - qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) qstr.EXPECT().Context().Return(context.Background()).AnyTimes() - var states []stateTransition - str := newStateTrackingStream(qstr, func(state streamState, err error) { - states = append(states, stateTransition{state, err}) - }) + + var ( + clearer mockStreamClearer + setter mockErrorSetter + str = newStateTrackingStream(qstr, &clearer, &setter) + ) buf := bytes.NewBuffer([]byte("foobar")) qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() qstr.EXPECT().CancelRead(quic.StreamErrorCode(1337)) _, err := str.Read(make([]byte, 3)) Expect(err).ToNot(HaveOccurred()) - Expect(states).To(BeEmpty()) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(BeEmpty()) + str.CancelRead(1337) - Expect(states).To(HaveLen(1)) - Expect(states[0].state).To(Equal(streamStateReceiveClosed)) - Expect(states[0].err).To(Equal(&quic.StreamError{ErrorCode: 1337})) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(HaveLen(1)) + Expect(setter.recvErrs[0]).To(Equal(&quic.StreamError{StreamID: someStreamID, ErrorCode: 1337})) + Expect(setter.sendErrs).To(BeEmpty()) }) It("recognizes remote cancellations", func() { qstr := mockquic.NewMockStream(mockCtrl) - qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) qstr.EXPECT().Context().Return(context.Background()).AnyTimes() - var states []stateTransition - str := newStateTrackingStream(qstr, func(state streamState, err error) { - states = append(states, stateTransition{state, err}) - }) + + var ( + clearer mockStreamClearer + setter mockErrorSetter + str = newStateTrackingStream(qstr, &clearer, &setter) + ) testErr := errors.New("test error") qstr.EXPECT().Read(gomock.Any()).Return(0, testErr) _, err := str.Read(make([]byte, 3)) Expect(err).To(MatchError(testErr)) - Expect(states).To(HaveLen(1)) - Expect(states[0].state).To(Equal(streamStateReceiveClosed)) - Expect(states[0].err).To(MatchError(testErr)) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(HaveLen(1)) + Expect(setter.recvErrs[0]).To(Equal(testErr)) + Expect(setter.sendErrs).To(BeEmpty()) }) It("doesn't misinterpret read deadline errors", func() { qstr := mockquic.NewMockStream(mockCtrl) - qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) qstr.EXPECT().Context().Return(context.Background()).AnyTimes() - var states []stateTransition - str := newStateTrackingStream(qstr, func(state streamState, err error) { - states = append(states, stateTransition{state, err}) - }) + + var ( + clearer mockStreamClearer + setter mockErrorSetter + str = newStateTrackingStream(qstr, &clearer, &setter) + ) qstr.EXPECT().Read(gomock.Any()).Return(0, os.ErrDeadlineExceeded) _, err := str.Read(make([]byte, 3)) Expect(err).To(MatchError(os.ErrDeadlineExceeded)) - Expect(states).To(BeEmpty()) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(BeEmpty()) }) It("recognizes when the send side is closed, when write errors", func() { qstr := mockquic.NewMockStream(mockCtrl) - qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) qstr.EXPECT().Context().Return(context.Background()).AnyTimes() - var states []stateTransition - str := newStateTrackingStream(qstr, func(state streamState, err error) { - states = append(states, stateTransition{state, err}) - }) + + var ( + clearer mockStreamClearer + setter mockErrorSetter + str = newStateTrackingStream(qstr, &clearer, &setter) + ) testErr := errors.New("test error") qstr.EXPECT().Write([]byte("foo")).Return(3, nil) qstr.EXPECT().Write([]byte("bar")).Return(0, testErr) + _, err := str.Write([]byte("foo")) Expect(err).ToNot(HaveOccurred()) - Expect(states).To(BeEmpty()) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(BeEmpty()) + _, err = str.Write([]byte("bar")) Expect(err).To(MatchError(testErr)) - Expect(states).To(HaveLen(1)) - Expect(states[0].state).To(Equal(streamStateSendClosed)) - Expect(states[0].err).To(Equal(testErr)) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(HaveLen(1)) + Expect(setter.sendErrs[0]).To(Equal(testErr)) }) It("recognizes when the send side is closed, when write errors", func() { qstr := mockquic.NewMockStream(mockCtrl) - qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) qstr.EXPECT().Context().Return(context.Background()).AnyTimes() - var states []stateTransition - str := newStateTrackingStream(qstr, func(state streamState, err error) { - states = append(states, stateTransition{state, err}) - }) + + var ( + clearer mockStreamClearer + setter mockErrorSetter + str = newStateTrackingStream(qstr, &clearer, &setter) + ) qstr.EXPECT().Write([]byte("foo")).Return(0, os.ErrDeadlineExceeded) - Expect(states).To(BeEmpty()) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(BeEmpty()) + _, err := str.Write([]byte("foo")) Expect(err).To(MatchError(os.ErrDeadlineExceeded)) - Expect(states).To(BeEmpty()) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(BeEmpty()) }) It("recognizes when the send side is closed, when CancelWrite is called", func() { qstr := mockquic.NewMockStream(mockCtrl) - qstr.EXPECT().StreamID().AnyTimes() + qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) qstr.EXPECT().Context().Return(context.Background()).AnyTimes() - var states []stateTransition - str := newStateTrackingStream(qstr, func(state streamState, err error) { - states = append(states, stateTransition{state, err}) - }) + + var ( + clearer mockStreamClearer + setter mockErrorSetter + str = newStateTrackingStream(qstr, &clearer, &setter) + ) qstr.EXPECT().Write(gomock.Any()) qstr.EXPECT().CancelWrite(quic.StreamErrorCode(1337)) _, err := str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) - Expect(states).To(BeEmpty()) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(BeEmpty()) + str.CancelWrite(1337) - Expect(states).To(HaveLen(1)) - Expect(states[0].state).To(Equal(streamStateSendClosed)) - Expect(states[0].err).To(Equal(&quic.StreamError{ErrorCode: 1337})) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(HaveLen(1)) + Expect(setter.sendErrs[0]).To(Equal(&quic.StreamError{StreamID: someStreamID, ErrorCode: 1337})) }) It("recognizes when the send side is closed, when the stream context is canceled", func() { @@ -161,20 +196,114 @@ var _ = Describe("State Tracking Stream", func() { qstr.EXPECT().StreamID().AnyTimes() ctx, cancel := context.WithCancelCause(context.Background()) qstr.EXPECT().Context().Return(ctx).AnyTimes() - var states []stateTransition - done := make(chan struct{}) - newStateTrackingStream(qstr, func(state streamState, err error) { - states = append(states, stateTransition{state, err}) - close(done) - }) + var ( + clearer mockStreamClearer + setter = mockErrorSetter{ + sendSent: make(chan struct{}), + } + ) + + _ = newStateTrackingStream(qstr, &clearer, &setter) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(BeEmpty()) - Expect(states).To(BeEmpty()) testErr := errors.New("test error") cancel(testErr) - Eventually(done).Should(BeClosed()) - Expect(states).To(HaveLen(1)) - Expect(states[0].state).To(Equal(streamStateSendClosed)) - Expect(states[0].err).To(Equal(testErr)) + Eventually(setter.sendSent).Should(BeClosed()) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(BeEmpty()) + Expect(setter.sendErrs).To(HaveLen(1)) + Expect(setter.sendErrs[0]).To(Equal(testErr)) + }) + + It("clears the stream when receive is closed followed by send is closed", func() { + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() + + var ( + clearer mockStreamClearer + setter mockErrorSetter + str = newStateTrackingStream(qstr, &clearer, &setter) + ) + + buf := bytes.NewBuffer([]byte("foobar")) + qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + _, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + + Expect(clearer.cleared).To(BeNil()) + Expect(setter.recvErrs).To(HaveLen(1)) + Expect(setter.recvErrs[0]).To(Equal(io.EOF)) + + testErr := errors.New("test error") + qstr.EXPECT().Write([]byte("bar")).Return(0, testErr) + + _, err = str.Write([]byte("bar")) + Expect(err).To(MatchError(testErr)) + Expect(setter.sendErrs).To(HaveLen(1)) + Expect(setter.sendErrs[0]).To(Equal(testErr)) + + Expect(clearer.cleared).To(Equal(&someStreamID)) + }) + + It("clears the stream when send is closed followed by receive is closed", func() { + qstr := mockquic.NewMockStream(mockCtrl) + qstr.EXPECT().StreamID().AnyTimes().Return(someStreamID) + qstr.EXPECT().Context().Return(context.Background()).AnyTimes() + + var ( + clearer mockStreamClearer + setter mockErrorSetter + str = newStateTrackingStream(qstr, &clearer, &setter) + ) + + testErr := errors.New("test error") + qstr.EXPECT().Write([]byte("bar")).Return(0, testErr) + + _, err := str.Write([]byte("bar")) + Expect(err).To(MatchError(testErr)) + Expect(clearer.cleared).To(BeNil()) + Expect(setter.sendErrs).To(HaveLen(1)) + Expect(setter.sendErrs[0]).To(Equal(testErr)) + + buf := bytes.NewBuffer([]byte("foobar")) + qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + + _, err = io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(setter.recvErrs).To(HaveLen(1)) + Expect(setter.recvErrs[0]).To(Equal(io.EOF)) + + Expect(clearer.cleared).To(Equal(&someStreamID)) }) }) + +type mockStreamClearer struct { + cleared *quic.StreamID +} + +func (s *mockStreamClearer) clearStream(id quic.StreamID) { + s.cleared = &id +} + +type mockErrorSetter struct { + sendErrs []error + recvErrs []error + + sendSent chan struct{} +} + +func (e *mockErrorSetter) SetSendError(err error) { + e.sendErrs = append(e.sendErrs, err) + + if e.sendSent != nil { + close(e.sendSent) + } +} + +func (e *mockErrorSetter) SetReceiveError(err error) { + e.recvErrs = append(e.recvErrs, err) +}