Skip to content

Commit

Permalink
fix handling of CancelRead after receiving a RESET_STREAM
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Apr 26, 2024
1 parent a43a8ea commit 1a926a8
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 50 deletions.
90 changes: 49 additions & 41 deletions receive_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ type receiveStream struct {
// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
errorRead bool
completed bool // set once we've called streamSender.onStreamCompleted
cancelledRemotely bool
cancelledLocally bool
cancelErr *StreamError
closeForShutdownErr error
cancelReadErr error
resetRemotelyErr *StreamError

readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
Expand Down Expand Up @@ -87,41 +88,44 @@ func (s *receiveStream) Read(p []byte) (int, error) {

s.mutex.Lock()
n, err := s.readImpl(p)
if err != nil {
s.errorRead = true
}
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
if err != io.EOF {
s.flowController.Abandon()
}
s.sender.onStreamCompleted(s.streamID)
}
return n, err
}

func (s *receiveStream) isNewlyCompleted() bool {
// We're done with the stream once:
// 1. The application has consumed the io.EOF or the cancellation error
// 2. We know the final offset (for flow control accounting)
isNewlyCompleted := !s.completed && s.errorRead && s.finalOffset != protocol.MaxByteCount
if isNewlyCompleted {
if s.completed {
return false
}
// We need to know the final offset (either via FIN or RESET_STREAM) for flow control accounting.
if s.finalOffset == protocol.MaxByteCount {
return false
}
// We're done with the stream if it was cancelled locally...
if s.cancelledLocally {
s.completed = true
return true
}
return isNewlyCompleted
// ... or if the error (either io.EOF or the reset error) was read
if s.errorRead {
s.completed = true
return true
}
return false
}

func (s *receiveStream) readImpl(p []byte) (int, error) {
if s.currentFrameIsLast && s.currentFrame == nil {
s.errorRead = true
return 0, io.EOF
}
if s.cancelReadErr != nil {
return 0, s.cancelReadErr
}
if s.resetRemotelyErr != nil {
return 0, s.resetRemotelyErr
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
}
if s.closeForShutdownErr != nil {
return 0, s.closeForShutdownErr
Expand All @@ -142,11 +146,9 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
if s.closeForShutdownErr != nil {
return bytesRead, s.closeForShutdownErr
}
if s.cancelReadErr != nil {
return bytesRead, s.cancelReadErr
}
if s.resetRemotelyErr != nil {
return bytesRead, s.resetRemotelyErr
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
}

deadline := s.deadline
Expand Down Expand Up @@ -194,7 +196,7 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {

// when a RESET_STREAM was received, the flow controller was already
// informed about the final byteOffset for this stream
if s.resetRemotelyErr == nil {
if !s.cancelledRemotely {
s.flowController.AddBytesRead(protocol.ByteCount(m))
}

Expand All @@ -203,6 +205,7 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
if s.currentFrameDone != nil {
s.currentFrameDone()
}
s.errorRead = true
return bytesRead, io.EOF
}
}
Expand Down Expand Up @@ -233,10 +236,14 @@ func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
}

func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {
if s.errorRead || s.cancelReadErr != nil || s.resetRemotelyErr != nil {
if s.cancelledLocally { // duplicate call to CancelRead
return
}
s.cancelledLocally = true
if s.errorRead || s.cancelledRemotely {
return
}
s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead()
s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID,
Expand All @@ -246,7 +253,8 @@ func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {

func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
s.mutex.Lock()
completed, err := s.handleStreamFrameImpl(frame)
err := s.handleStreamFrameImpl(frame)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
Expand All @@ -256,24 +264,22 @@ func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
return err
}

func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* completed */, error) {
func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil {
return false, err
return err
}
var newlyRcvdFinalOffset bool
if frame.Fin {
newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount
s.finalOffset = maxOffset
}
if s.cancelReadErr != nil {
return newlyRcvdFinalOffset, nil
if s.cancelledLocally {
return nil
}
if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil {
return false, err
return err
}
s.signalRead()
return false, nil
return nil
}

func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
Expand All @@ -298,14 +304,16 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
s.finalOffset = frame.FinalSize

// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
if s.resetRemotelyErr != nil {
if s.cancelledRemotely {
return nil
}
s.resetRemotelyErr = &StreamError{
StreamID: s.streamID,
ErrorCode: frame.ErrorCode,
Remote: true,
s.flowController.Abandon()
// don't save the error if the RESET_STREAM frames was received after CancelRead was called
if s.cancelledLocally {
return nil
}
s.cancelledRemotely = true
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: frame.ErrorCode, Remote: true}
s.signalRead()
return nil
}
Expand Down
43 changes: 34 additions & 9 deletions receive_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ var _ = Describe("Receive Stream", func() {

It("returns an error when Read is called after the deadline", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes()
f := &wire.StreamFrame{Data: []byte("foobar")}
err := str.handleStreamFrame(f)
Expect(err).ToNot(HaveOccurred())
Expect(str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")})).To(Succeed())
str.SetReadDeadline(time.Now().Add(-time.Second))
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expand Down Expand Up @@ -542,11 +540,22 @@ var _ = Describe("Receive Stream", func() {

It("doesn't send a STOP_SENDING frame, if the stream was already reset", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
mockFC.EXPECT().Abandon().MinTimes(1)
Expect(str.handleResetStreamFrame(&wire.ResetStreamFrame{
ErrorCode: 1337,
StreamID: streamID,
FinalSize: 42,
})).To(Succeed())
mockSender.EXPECT().onStreamCompleted(gomock.Any())
str.CancelRead(1234)
// check that the error indicates a remote reset
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
var streamErr *StreamError
Expect(errors.As(err, &streamErr)).To(BeTrue())
Expect(streamErr.ErrorCode).To(BeEquivalentTo(1337))
Expect(streamErr.Remote).To(BeTrue())
})

It("sends a STOP_SENDING after receiving the final offset", func() {
Expand All @@ -557,9 +566,9 @@ var _ = Describe("Receive Stream", func() {
})).To(Succeed())
mockFC.EXPECT().Abandon()
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
str.CancelRead(1234)
// read the error
mockSender.EXPECT().onStreamCompleted(streamID)
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
Expand Down Expand Up @@ -651,6 +660,7 @@ var _ = Describe("Receive Stream", func() {

It("ignores duplicate RESET_STREAM frames", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
mockFC.EXPECT().Abandon()
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
})
Expand All @@ -662,21 +672,36 @@ var _ = Describe("Receive Stream", func() {
Offset: rst.FinalSize,
Fin: true,
})).To(Succeed())
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234)
mockFC.EXPECT().Abandon()
mockFC.EXPECT().Abandon().MinTimes(1)
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
// now read the error
mockSender.EXPECT().onStreamCompleted(streamID)
_, err := str.Read([]byte{0})
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
})

It("doesn't do anything when it was closed for shutdown", func() {
str.closeForShutdown(errors.New("shutdown"))
err := str.handleResetStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
})

It("handles RESET_STREAM after CancelRead", func() {
mockFC.EXPECT().Abandon()
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
// check that the error indicates a local reset
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
var streamErr *StreamError
Expect(errors.As(err, &streamErr)).To(BeTrue())
Expect(streamErr.Remote).To(BeFalse())
})
})
})

Expand Down

0 comments on commit 1a926a8

Please sign in to comment.