Skip to content

Commit

Permalink
http3: fix memory leak in stream state tracking (#4523)
Browse files Browse the repository at this point in the history
* fix(http3): handle streamStateSendAndReceiveClosed in onStreamStateChange

Signed-off-by: George MacRorie <[email protected]>

* refactor(http3): adjust stateTrackingStream to operate over streamClearer and errorSetter

* test(http3): remove duplicate test case

* chore(http3): rename test spies to be mocks

---------

Signed-off-by: George MacRorie <[email protected]>
  • Loading branch information
GeorgeMac committed May 19, 2024
1 parent f3cecf9 commit e2fbf3c
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 127 deletions.
24 changes: 4 additions & 20 deletions http3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,11 @@ func newConnection(
return c
}

func (c *connection) onStreamStateChange(id quic.StreamID, state streamState, e error) {
func (c *connection) clearStream(id quic.StreamID) {
c.streamMx.Lock()
defer c.streamMx.Unlock()

d, ok := c.streams[id]
if !ok { // should never happen
return
}
var isDone bool
//nolint:exhaustive // These are all the cases we care about.
switch state {
case streamStateReceiveClosed:
isDone = d.SetReceiveError(e)
case streamStateSendClosed:
isDone = d.SetSendError(e)
default:
return
}
if isDone {
delete(c.streams, id)
}
delete(c.streams, id)
}

func (c *connection) openRequestStream(
Expand All @@ -108,7 +92,7 @@ func (c *connection) openRequestStream(
c.streamMx.Lock()
c.streams[str.StreamID()] = datagrams
c.streamMx.Unlock()
qstr := newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
qstr := newStateTrackingStream(str, c, datagrams)
hstr := newStream(qstr, c, datagrams)
return newRequestStream(hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes), nil
}
Expand All @@ -124,7 +108,7 @@ func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagramme
c.streamMx.Lock()
c.streams[strID] = datagrams
c.streamMx.Unlock()
str = newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(strID, s, e) })
str = newStateTrackingStream(str, c, datagrams)
}
return str, datagrams, nil
}
Expand Down
6 changes: 2 additions & 4 deletions http3/datagram.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,19 @@ func newDatagrammer(sendDatagram func([]byte) error) *datagrammer {
}
}

func (d *datagrammer) SetReceiveError(err error) (isDone bool) {
func (d *datagrammer) SetReceiveError(err error) {
d.mx.Lock()
defer d.mx.Unlock()

d.receiveErr = err
d.signalHasData()
return d.sendErr != nil
}

func (d *datagrammer) SetSendError(err error) (isDone bool) {
func (d *datagrammer) SetSendError(err error) {
d.mx.Lock()
defer d.mx.Unlock()

d.sendErr = err
return d.receiveErr != nil
}

func (d *datagrammer) Send(b []byte) error {
Expand Down
2 changes: 1 addition & 1 deletion http3/datagram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ var _ = Describe("Datagrams", func() {
dg := newDatagrammer(nil)
dg.enqueue([]byte("foo"))
testErr := errors.New("test error")
Expect(dg.SetReceiveError(testErr)).To(BeFalse())
dg.SetReceiveError(testErr)
dg.enqueue([]byte("bar"))
data, err := dg.Receive(context.Background())
Expect(err).ToNot(HaveOccurred())
Expand Down
80 changes: 50 additions & 30 deletions http3/state_tracking_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,79 @@ import (
"github.com/quic-go/quic-go"
)

type streamState uint8

const (
streamStateOpen streamState = iota
streamStateReceiveClosed
streamStateSendClosed
streamStateSendAndReceiveClosed
)
var _ quic.Stream = &stateTrackingStream{}

// stateTrackingStream is an implementation of quic.Stream that delegates
// to an underlying stream
// it takes care of proxying send and receive errors onto an implementation of
// the errorSetter interface (intended to be occupied by a datagrammer)
// it is also responsible for clearing the stream based on its ID from its
// parent connection, this is done through the streamClearer interface when
// both the send and receive sides are closed
type stateTrackingStream struct {
quic.Stream

mx sync.Mutex
state streamState
mx sync.Mutex
sendErr error
recvErr error

clearer streamClearer
setter errorSetter
}

type streamClearer interface {
clearStream(quic.StreamID)
}

onStateChange func(streamState, error)
type errorSetter interface {
SetSendError(error)
SetReceiveError(error)
}

func newStateTrackingStream(s quic.Stream, onStateChange func(streamState, error)) *stateTrackingStream {
func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream {
t := &stateTrackingStream{
Stream: s,
clearer: clearer,
setter: setter,
}

context.AfterFunc(s.Context(), func() {
onStateChange(streamStateSendClosed, context.Cause(s.Context()))
t.closeSend(context.Cause(s.Context()))
})
return &stateTrackingStream{
Stream: s,
state: streamStateOpen,
onStateChange: onStateChange,
}
}

var _ quic.Stream = &stateTrackingStream{}
return t
}

func (s *stateTrackingStream) closeSend(e error) {
s.mx.Lock()
defer s.mx.Unlock()

if s.state == streamStateReceiveClosed || s.state == streamStateSendAndReceiveClosed {
s.state = streamStateSendAndReceiveClosed
} else {
s.state = streamStateSendClosed
// clear the stream the first time both the send
// and receive are finished
if s.sendErr == nil {
if s.recvErr != nil {
s.clearer.clearStream(s.StreamID())
}

s.setter.SetSendError(e)
s.sendErr = e
}
s.onStateChange(s.state, e)
}

func (s *stateTrackingStream) closeReceive(e error) {
s.mx.Lock()
defer s.mx.Unlock()

if s.state == streamStateSendClosed || s.state == streamStateSendAndReceiveClosed {
s.state = streamStateSendAndReceiveClosed
} else {
s.state = streamStateReceiveClosed
// clear the stream the first time both the send
// and receive are finished
if s.recvErr == nil {
if s.sendErr != nil {
s.clearer.clearStream(s.StreamID())
}

s.setter.SetReceiveError(e)
s.recvErr = e
}
s.onStateChange(s.state, e)
}

func (s *stateTrackingStream) Close() error {
Expand Down

0 comments on commit e2fbf3c

Please sign in to comment.