Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

delay completion of the receive stream until the reset error was read #4460

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 81 additions & 55 deletions receive_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ type receiveStream struct {
readPosInFrame int
currentFrameIsLast bool // is the currentFrame the last frame on this stream

finRead bool // set once we read a frame with a Fin
// Set once we read the io.EOF or the cancellation error.
// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
errorRead bool
completed bool // set once we've called streamSender.onStreamCompleted
cancelledRemotely bool
cancelledLocally bool
cancelErr *StreamError
closeForShutdownErr error
cancelReadErr error
resetRemotelyErr *StreamError

readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
Expand Down Expand Up @@ -83,7 +87,8 @@ func (s *receiveStream) Read(p []byte) (int, error) {
defer func() { <-s.readOnce }()

s.mutex.Lock()
completed, n, err := s.readImpl(p)
n, err := s.readImpl(p)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
Expand All @@ -92,18 +97,38 @@ func (s *receiveStream) Read(p []byte) (int, error) {
return n, err
}

func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, error) {
if s.finRead {
return false, 0, io.EOF
func (s *receiveStream) isNewlyCompleted() bool {
if s.completed {
return false
}
// We need to know the final offset (either via FIN or RESET_STREAM) for flow control accounting.
if s.finalOffset == protocol.MaxByteCount {
return false
}
// We're done with the stream if it was cancelled locally...
if s.cancelledLocally {
s.completed = true
return true
}
if s.cancelReadErr != nil {
return false, 0, s.cancelReadErr
// ... or if the error (either io.EOF or the reset error) was read
if s.errorRead {
s.completed = true
return true
}
return false
}

func (s *receiveStream) readImpl(p []byte) (int, error) {
if s.currentFrameIsLast && s.currentFrame == nil {
s.errorRead = true
return 0, io.EOF
}
if s.resetRemotelyErr != nil {
return false, 0, s.resetRemotelyErr
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
}
if s.closeForShutdownErr != nil {
return false, 0, s.closeForShutdownErr
return 0, s.closeForShutdownErr
}

var bytesRead int
Expand All @@ -113,25 +138,23 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
s.dequeueNextFrame()
}
if s.currentFrame == nil && bytesRead > 0 {
return false, bytesRead, s.closeForShutdownErr
return bytesRead, s.closeForShutdownErr
}

for {
// Stop waiting on errors
if s.closeForShutdownErr != nil {
return false, bytesRead, s.closeForShutdownErr
return bytesRead, s.closeForShutdownErr
}
if s.cancelReadErr != nil {
return false, bytesRead, s.cancelReadErr
}
if s.resetRemotelyErr != nil {
return false, bytesRead, s.resetRemotelyErr
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
}

deadline := s.deadline
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
return false, bytesRead, errDeadline
return bytesRead, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
Expand Down Expand Up @@ -161,10 +184,10 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
}

if bytesRead > len(p) {
return false, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > len(s.currentFrame) {
return false, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
}

m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
Expand All @@ -173,20 +196,20 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err

// when a RESET_STREAM was received, the flow controller was already
// informed about the final byteOffset for this stream
if s.resetRemotelyErr == nil {
if !s.cancelledRemotely {
s.flowController.AddBytesRead(protocol.ByteCount(m))
}

if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
s.finRead = true
s.currentFrame = nil
if s.currentFrameDone != nil {
s.currentFrameDone()
}
return true, bytesRead, io.EOF
s.errorRead = true
return bytesRead, io.EOF
}
}
return false, bytesRead, nil
return bytesRead, nil
}

func (s *receiveStream) dequeueNextFrame() {
Expand All @@ -202,7 +225,8 @@ func (s *receiveStream) dequeueNextFrame() {

func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
s.mutex.Lock()
completed := s.cancelReadImpl(errorCode)
s.cancelReadImpl(errorCode)
completed := s.isNewlyCompleted()
marten-seemann marked this conversation as resolved.
Show resolved Hide resolved
s.mutex.Unlock()

if completed {
Expand All @@ -211,23 +235,26 @@ func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
}
}

func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ {
if s.finRead || s.cancelReadErr != nil || s.resetRemotelyErr != nil {
return false
func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {
if s.cancelledLocally { // duplicate call to CancelRead
return
}
s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.cancelledLocally = true
if s.errorRead || s.cancelledRemotely {
return
}
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead()
s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID,
ErrorCode: errorCode,
})
// We're done with this stream if the final offset was already received.
return s.finalOffset != protocol.MaxByteCount
}

func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
s.mutex.Lock()
completed, err := s.handleStreamFrameImpl(frame)
err := s.handleStreamFrameImpl(frame)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
Expand All @@ -237,59 +264,58 @@ func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
return err
}

func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* completed */, error) {
func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil {
return false, err
return err
}
var newlyRcvdFinalOffset bool
if frame.Fin {
newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount
s.finalOffset = maxOffset
}
if s.cancelReadErr != nil {
return newlyRcvdFinalOffset, nil
if s.cancelledLocally {
return nil
}
if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil {
return false, err
return err
}
s.signalRead()
return false, nil
return nil
}

func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
s.mutex.Lock()
completed, err := s.handleResetStreamFrameImpl(frame)
err := s.handleResetStreamFrameImpl(frame)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
s.flowController.Abandon()
s.sender.onStreamCompleted(s.streamID)
}
return err
}

func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) {
func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) error {
if s.closeForShutdownErr != nil {
return false, nil
return nil
}
if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil {
return false, err
return err
}
newlyRcvdFinalOffset := s.finalOffset == protocol.MaxByteCount
s.finalOffset = frame.FinalSize

// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
if s.resetRemotelyErr != nil {
return false, nil
if s.cancelledRemotely {
return nil
}
s.resetRemotelyErr = &StreamError{
StreamID: s.streamID,
ErrorCode: frame.ErrorCode,
Remote: true,
s.flowController.Abandon()
// don't save the error if the RESET_STREAM frames was received after CancelRead was called
if s.cancelledLocally {
return nil
}
s.cancelledRemotely = true
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: frame.ErrorCode, Remote: true}
s.signalRead()
return newlyRcvdFinalOffset, nil
return nil
}

func (s *receiveStream) SetReadDeadline(t time.Time) error {
Expand Down
65 changes: 46 additions & 19 deletions receive_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ var _ = Describe("Receive Stream", func() {

It("returns an error when Read is called after the deadline", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes()
f := &wire.StreamFrame{Data: []byte("foobar")}
err := str.handleStreamFrame(f)
Expect(err).ToNot(HaveOccurred())
Expect(str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")})).To(Succeed())
str.SetReadDeadline(time.Now().Add(-time.Second))
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expand Down Expand Up @@ -534,34 +532,46 @@ var _ = Describe("Receive Stream", func() {
Fin: true,
})).To(Succeed())
mockSender.EXPECT().onStreamCompleted(streamID)
_, err := strWithTimeout.Read(make([]byte, 100))
n, err := strWithTimeout.Read(make([]byte, 100))
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(6))
str.CancelRead(1234)
})

It("doesn't send a STOP_SENDING frame, if the stream was already reset", func() {
gomock.InOrder(
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true),
mockFC.EXPECT().Abandon(),
)
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
mockFC.EXPECT().Abandon().MinTimes(1)
Expect(str.handleResetStreamFrame(&wire.ResetStreamFrame{
ErrorCode: 1337,
StreamID: streamID,
FinalSize: 42,
})).To(Succeed())
mockSender.EXPECT().onStreamCompleted(gomock.Any())
str.CancelRead(1234)
// check that the error indicates a remote reset
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
var streamErr *StreamError
Expect(errors.As(err, &streamErr)).To(BeTrue())
Expect(streamErr.ErrorCode).To(BeEquivalentTo(1337))
Expect(streamErr.Remote).To(BeTrue())
})

It("sends a STOP_SENDING and completes the stream after receiving the final offset", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true)
It("sends a STOP_SENDING after receiving the final offset", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
Expect(str.handleStreamFrame(&wire.StreamFrame{
Offset: 1000,
Fin: true,
Data: []byte("foobar"),
Fin: true,
})).To(Succeed())
mockFC.EXPECT().Abandon()
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
str.CancelRead(1234)
// read the error
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
})

It("completes the stream when receiving the Fin after the stream was canceled", func() {
Expand Down Expand Up @@ -649,32 +659,49 @@ var _ = Describe("Receive Stream", func() {
})

It("ignores duplicate RESET_STREAM frames", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().Abandon()
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
mockFC.EXPECT().Abandon()
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
})

It("doesn't call onStreamCompleted again when the final offset was already received via Fin", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234)
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().Abandon()
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
Expect(str.handleStreamFrame(&wire.StreamFrame{
StreamID: streamID,
Offset: rst.FinalSize,
Fin: true,
})).To(Succeed())
mockFC.EXPECT().Abandon().MinTimes(1)
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
// now read the error
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
})

It("doesn't do anything when it was closed for shutdown", func() {
str.closeForShutdown(errors.New("shutdown"))
err := str.handleResetStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
})

It("handles RESET_STREAM after CancelRead", func() {
mockFC.EXPECT().Abandon()
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
// check that the error indicates a local reset
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
var streamErr *StreamError
Expect(errors.As(err, &streamErr)).To(BeTrue())
Expect(streamErr.Remote).To(BeFalse())
})
})
})

Expand Down
Loading