Skip to content

Commit

Permalink
delay completion of the receive stream until the reset error was read
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Apr 25, 2024
1 parent 12aa638 commit 11bb2e7
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 49 deletions.
86 changes: 54 additions & 32 deletions receive_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ type receiveStream struct {
readPosInFrame int
currentFrameIsLast bool // is the currentFrame the last frame on this stream

finRead bool // set once we read a frame with a Fin
// Set once we read the io.EOF or the cancellation error.
// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
errorRead bool
completed bool // set once we've called streamSender.onStreamCompleted
closeForShutdownErr error
cancelReadErr error
resetRemotelyErr *StreamError
Expand Down Expand Up @@ -83,27 +86,45 @@ func (s *receiveStream) Read(p []byte) (int, error) {
defer func() { <-s.readOnce }()

s.mutex.Lock()
completed, n, err := s.readImpl(p)
n, err := s.readImpl(p)
if err != nil {
s.errorRead = true
}
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
if err != io.EOF {
s.flowController.Abandon()
}
s.sender.onStreamCompleted(s.streamID)
}
return n, err
}

func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, error) {
if s.finRead {
return false, 0, io.EOF
func (s *receiveStream) isNewlyCompleted() bool {
// We're done with the stream once:
// 1. The application has consumed the io.EOF or the cancellation error
// 2. We know the final offset (for flow control accounting)
isNewlyCompleted := !s.completed && s.errorRead && s.finalOffset != protocol.MaxByteCount
if isNewlyCompleted {
s.completed = true
}
return isNewlyCompleted
}

func (s *receiveStream) readImpl(p []byte) (int, error) {
if s.currentFrameIsLast && s.currentFrame == nil {
return 0, io.EOF
}
if s.cancelReadErr != nil {
return false, 0, s.cancelReadErr
return 0, s.cancelReadErr
}
if s.resetRemotelyErr != nil {
return false, 0, s.resetRemotelyErr
return 0, s.resetRemotelyErr
}
if s.closeForShutdownErr != nil {
return false, 0, s.closeForShutdownErr
return 0, s.closeForShutdownErr
}

var bytesRead int
Expand All @@ -113,25 +134,25 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
s.dequeueNextFrame()
}
if s.currentFrame == nil && bytesRead > 0 {
return false, bytesRead, s.closeForShutdownErr
return bytesRead, s.closeForShutdownErr
}

for {
// Stop waiting on errors
if s.closeForShutdownErr != nil {
return false, bytesRead, s.closeForShutdownErr
return bytesRead, s.closeForShutdownErr
}
if s.cancelReadErr != nil {
return false, bytesRead, s.cancelReadErr
return bytesRead, s.cancelReadErr
}
if s.resetRemotelyErr != nil {
return false, bytesRead, s.resetRemotelyErr
return bytesRead, s.resetRemotelyErr
}

deadline := s.deadline
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
return false, bytesRead, errDeadline
return bytesRead, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
Expand Down Expand Up @@ -161,10 +182,10 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
}

if bytesRead > len(p) {
return false, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > len(s.currentFrame) {
return false, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
}

m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
Expand All @@ -178,15 +199,14 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
}

if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
s.finRead = true
s.currentFrame = nil
if s.currentFrameDone != nil {
s.currentFrameDone()
}
return true, bytesRead, io.EOF
return bytesRead, io.EOF
}
}
return false, bytesRead, nil
return bytesRead, nil
}

func (s *receiveStream) dequeueNextFrame() {
Expand All @@ -202,7 +222,8 @@ func (s *receiveStream) dequeueNextFrame() {

func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
s.mutex.Lock()
completed := s.cancelReadImpl(errorCode)
s.cancelReadImpl(errorCode)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
Expand All @@ -211,18 +232,16 @@ func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
}
}

func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ {
if s.finRead || s.cancelReadErr != nil || s.resetRemotelyErr != nil {
return false
func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {
if s.errorRead || s.cancelReadErr != nil || s.resetRemotelyErr != nil {
return
}
s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead()
s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID,
ErrorCode: errorCode,
})
// We're done with this stream if the final offset was already received.
return s.finalOffset != protocol.MaxByteCount
}

func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
Expand Down Expand Up @@ -259,37 +278,40 @@ func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /*

func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
s.mutex.Lock()
completed, err := s.handleResetStreamFrameImpl(frame)
err := s.handleResetStreamFrameImpl(frame)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
s.flowController.Abandon()
s.sender.onStreamCompleted(s.streamID)
}
return err
}

func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) {
func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) error {
fmt.Println("h1")
if s.closeForShutdownErr != nil {
return false, nil
return nil
}
fmt.Println("h2")
if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil {
return false, err
return err
}
newlyRcvdFinalOffset := s.finalOffset == protocol.MaxByteCount
s.finalOffset = frame.FinalSize
fmt.Println("h3")

// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
if s.resetRemotelyErr != nil {
return false, nil
return nil
}
fmt.Println("h4")
s.resetRemotelyErr = &StreamError{
StreamID: s.streamID,
ErrorCode: frame.ErrorCode,
Remote: true,
}
s.signalRead()
return newlyRcvdFinalOffset, nil
return nil
}

func (s *receiveStream) SetReadDeadline(t time.Time) error {
Expand Down
36 changes: 19 additions & 17 deletions receive_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,34 +534,35 @@ var _ = Describe("Receive Stream", func() {
Fin: true,
})).To(Succeed())
mockSender.EXPECT().onStreamCompleted(streamID)
_, err := strWithTimeout.Read(make([]byte, 100))
n, err := strWithTimeout.Read(make([]byte, 100))
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(6))
str.CancelRead(1234)
})

It("doesn't send a STOP_SENDING frame, if the stream was already reset", func() {
gomock.InOrder(
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true),
mockFC.EXPECT().Abandon(),
)
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
Expect(str.handleResetStreamFrame(&wire.ResetStreamFrame{
StreamID: streamID,
FinalSize: 42,
})).To(Succeed())
str.CancelRead(1234)
})

It("sends a STOP_SENDING and completes the stream after receiving the final offset", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true)
It("sends a STOP_SENDING after receiving the final offset", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
Expect(str.handleStreamFrame(&wire.StreamFrame{
Offset: 1000,
Fin: true,
Data: []byte("foobar"),
Fin: true,
})).To(Succeed())
mockFC.EXPECT().Abandon()
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
str.CancelRead(1234)
// read the error
mockSender.EXPECT().onStreamCompleted(streamID)
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
})

It("completes the stream when receiving the Fin after the stream was canceled", func() {
Expand Down Expand Up @@ -649,25 +650,26 @@ var _ = Describe("Receive Stream", func() {
})

It("ignores duplicate RESET_STREAM frames", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().Abandon()
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
})

It("doesn't call onStreamCompleted again when the final offset was already received via Fin", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234)
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().Abandon()
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
Expect(str.handleStreamFrame(&wire.StreamFrame{
StreamID: streamID,
Offset: rst.FinalSize,
Fin: true,
})).To(Succeed())
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234)
mockFC.EXPECT().Abandon()
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
// now read the error
mockSender.EXPECT().onStreamCompleted(streamID)
_, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
})

It("doesn't do anything when it was closed for shutdown", func() {
Expand Down

0 comments on commit 11bb2e7

Please sign in to comment.