Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

delay completion of the receive stream until the reset error was read #4460

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix handling of CancelRead after receiving a RESET_STREAM
  • Loading branch information
marten-seemann committed Apr 26, 2024
commit 1a926a8a66b4139de9310c36be92901e1eaa5771
90 changes: 49 additions & 41 deletions receive_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ type receiveStream struct {
// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
errorRead bool
completed bool // set once we've called streamSender.onStreamCompleted
cancelledRemotely bool
cancelledLocally bool
cancelErr *StreamError
closeForShutdownErr error
cancelReadErr error
resetRemotelyErr *StreamError

readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
Expand Down Expand Up @@ -87,41 +88,44 @@ func (s *receiveStream) Read(p []byte) (int, error) {

s.mutex.Lock()
n, err := s.readImpl(p)
if err != nil {
s.errorRead = true
}
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
if err != io.EOF {
s.flowController.Abandon()
}
s.sender.onStreamCompleted(s.streamID)
}
return n, err
}

func (s *receiveStream) isNewlyCompleted() bool {
// We're done with the stream once:
// 1. The application has consumed the io.EOF or the cancellation error
// 2. We know the final offset (for flow control accounting)
isNewlyCompleted := !s.completed && s.errorRead && s.finalOffset != protocol.MaxByteCount
if isNewlyCompleted {
if s.completed {
return false
}
// We need to know the final offset (either via FIN or RESET_STREAM) for flow control accounting.
if s.finalOffset == protocol.MaxByteCount {
return false
}
// We're done with the stream if it was cancelled locally...
if s.cancelledLocally {
s.completed = true
return true
}
return isNewlyCompleted
// ... or if the error (either io.EOF or the reset error) was read
if s.errorRead {
s.completed = true
return true
}
return false
}

func (s *receiveStream) readImpl(p []byte) (int, error) {
if s.currentFrameIsLast && s.currentFrame == nil {
s.errorRead = true
return 0, io.EOF
}
if s.cancelReadErr != nil {
return 0, s.cancelReadErr
}
if s.resetRemotelyErr != nil {
return 0, s.resetRemotelyErr
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
}
if s.closeForShutdownErr != nil {
return 0, s.closeForShutdownErr
Expand All @@ -142,11 +146,9 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
if s.closeForShutdownErr != nil {
return bytesRead, s.closeForShutdownErr
}
if s.cancelReadErr != nil {
return bytesRead, s.cancelReadErr
}
if s.resetRemotelyErr != nil {
return bytesRead, s.resetRemotelyErr
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
}

deadline := s.deadline
Expand Down Expand Up @@ -194,7 +196,7 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {

// when a RESET_STREAM was received, the flow controller was already
// informed about the final byteOffset for this stream
if s.resetRemotelyErr == nil {
if !s.cancelledRemotely {
s.flowController.AddBytesRead(protocol.ByteCount(m))
}

Expand All @@ -203,6 +205,7 @@ func (s *receiveStream) readImpl(p []byte) (int, error) {
if s.currentFrameDone != nil {
s.currentFrameDone()
}
s.errorRead = true
return bytesRead, io.EOF
}
}
Expand Down Expand Up @@ -233,10 +236,14 @@ func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
}

func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {
if s.errorRead || s.cancelReadErr != nil || s.resetRemotelyErr != nil {
if s.cancelledLocally { // duplicate call to CancelRead
return
}
s.cancelledLocally = true
if s.errorRead || s.cancelledRemotely {
return
}
s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead()
s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID,
Expand All @@ -246,7 +253,8 @@ func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {

func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
s.mutex.Lock()
completed, err := s.handleStreamFrameImpl(frame)
err := s.handleStreamFrameImpl(frame)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
Expand All @@ -256,24 +264,22 @@ func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
return err
}

func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* completed */, error) {
func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil {
return false, err
return err
}
var newlyRcvdFinalOffset bool
if frame.Fin {
newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount
s.finalOffset = maxOffset
}
if s.cancelReadErr != nil {
return newlyRcvdFinalOffset, nil
if s.cancelledLocally {
return nil
}
if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil {
return false, err
return err
}
s.signalRead()
return false, nil
return nil
}

func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
Expand All @@ -298,14 +304,16 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame)
s.finalOffset = frame.FinalSize

// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
if s.resetRemotelyErr != nil {
if s.cancelledRemotely {
return nil
}
s.resetRemotelyErr = &StreamError{
StreamID: s.streamID,
ErrorCode: frame.ErrorCode,
Remote: true,
s.flowController.Abandon()
// don't save the error if the RESET_STREAM frames was received after CancelRead was called
if s.cancelledLocally {
return nil
}
s.cancelledRemotely = true
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: frame.ErrorCode, Remote: true}
s.signalRead()
return nil
}
Expand Down
43 changes: 34 additions & 9 deletions receive_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ var _ = Describe("Receive Stream", func() {

It("returns an error when Read is called after the deadline", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes()
f := &wire.StreamFrame{Data: []byte("foobar")}
err := str.handleStreamFrame(f)
Expect(err).ToNot(HaveOccurred())
Expect(str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")})).To(Succeed())
str.SetReadDeadline(time.Now().Add(-time.Second))
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expand Down Expand Up @@ -542,11 +540,22 @@ var _ = Describe("Receive Stream", func() {

It("doesn't send a STOP_SENDING frame, if the stream was already reset", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
mockFC.EXPECT().Abandon().MinTimes(1)
Expect(str.handleResetStreamFrame(&wire.ResetStreamFrame{
ErrorCode: 1337,
StreamID: streamID,
FinalSize: 42,
})).To(Succeed())
mockSender.EXPECT().onStreamCompleted(gomock.Any())
str.CancelRead(1234)
// check that the error indicates a remote reset
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
var streamErr *StreamError
Expect(errors.As(err, &streamErr)).To(BeTrue())
Expect(streamErr.ErrorCode).To(BeEquivalentTo(1337))
Expect(streamErr.Remote).To(BeTrue())
})

It("sends a STOP_SENDING after receiving the final offset", func() {
Expand All @@ -557,9 +566,9 @@ var _ = Describe("Receive Stream", func() {
})).To(Succeed())
mockFC.EXPECT().Abandon()
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
str.CancelRead(1234)
// read the error
mockSender.EXPECT().onStreamCompleted(streamID)
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
Expand Down Expand Up @@ -651,6 +660,7 @@ var _ = Describe("Receive Stream", func() {

It("ignores duplicate RESET_STREAM frames", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
mockFC.EXPECT().Abandon()
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
})
Expand All @@ -662,21 +672,36 @@ var _ = Describe("Receive Stream", func() {
Offset: rst.FinalSize,
Fin: true,
})).To(Succeed())
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234)
mockFC.EXPECT().Abandon()
mockFC.EXPECT().Abandon().MinTimes(1)
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
// now read the error
mockSender.EXPECT().onStreamCompleted(streamID)
_, err := str.Read([]byte{0})
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
})

It("doesn't do anything when it was closed for shutdown", func() {
str.closeForShutdown(errors.New("shutdown"))
err := str.handleResetStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
})

It("handles RESET_STREAM after CancelRead", func() {
mockFC.EXPECT().Abandon()
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
// check that the error indicates a local reset
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
var streamErr *StreamError
Expect(errors.As(err, &streamErr)).To(BeTrue())
Expect(streamErr.Remote).To(BeFalse())
})
})
})

Expand Down
Loading