Skip to content

Commit

Permalink
http3: pass tracing ID instead of quic.Connection to stream hijackers (
Browse files Browse the repository at this point in the history
…#4401)

The stream hijackers only need to be able to associate the stream with
the underlying QUIC connection. They are not supposed to call any
functions on the quic.Connection. As such, the better API is to just
pass them a unique identifier.
  • Loading branch information
marten-seemann committed Apr 2, 2024
1 parent 27a06f3 commit 183d42a
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 27 deletions.
7 changes: 4 additions & 3 deletions http3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ type roundTripperOpts struct {
EnableDatagram bool
MaxHeaderBytes int64
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
}

// client is a HTTP3 client doing requests
Expand Down Expand Up @@ -183,7 +183,8 @@ func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) {
}
go func(str quic.Stream) {
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
return c.opts.StreamHijacker(ft, conn, str, e)
id := conn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
return c.opts.StreamHijacker(ft, id, str, e)
})
if err == errHijacked {
return
Expand Down
32 changes: 25 additions & 7 deletions http3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,11 @@ var _ = Describe("Client", func() {
})

It("hijacks a bidirectional stream of unknown frame type", func() {
id := quic.ConnectionTracingID(1234)
frameTypeChan := make(chan FrameType, 1)
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
Expect(connTracingID).To(Equal(id))
frameTypeChan <- ft
return true, nil
}
Expand All @@ -252,6 +254,8 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
conn.EXPECT().Context().Return(ctx).AnyTimes()
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
Expand All @@ -260,7 +264,7 @@ var _ = Describe("Client", func() {

It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
frameTypeChan := make(chan FrameType, 1)
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, nil
Expand All @@ -274,6 +278,8 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Expand All @@ -282,7 +288,7 @@ var _ = Describe("Client", func() {

It("closes the connection when hijacker returned error", func() {
frameTypeChan := make(chan FrameType, 1)
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, errors.New("error in hijacker")
Expand All @@ -296,6 +302,8 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Expand All @@ -306,7 +314,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("test error")
unknownStr := mockquic.NewMockStream(mockCtrl)
done := make(chan struct{})
cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
cl.opts.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, e error) (hijacked bool, err error) {
defer close(done)
Expect(e).To(MatchError(testErr))
Expect(ft).To(BeZero())
Expand All @@ -320,6 +328,8 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := cl.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Expand Down Expand Up @@ -363,8 +373,10 @@ var _ = Describe("Client", func() {
})

It("hijacks an unidirectional stream of unknown stream type", func() {
id := quic.ConnectionTracingID(100)
streamTypeChan := make(chan StreamType, 1)
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
cl.opts.UniStreamHijacker = func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
Expect(connTracingID).To(Equal(id))
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return true
Expand All @@ -380,6 +392,8 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
conn.EXPECT().Context().Return(ctx).AnyTimes()
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
Expand All @@ -390,7 +404,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool {
defer close(done)
Expect(st).To(BeZero())
Expect(str).To(Equal(unknownStr))
Expand All @@ -404,6 +418,8 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
Expand All @@ -412,7 +428,7 @@ var _ = Describe("Client", func() {

It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return false
Expand All @@ -429,6 +445,8 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
_, err := cl.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
Expand Down
18 changes: 13 additions & 5 deletions http3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type connection struct {
logger utils.Logger

enableDatagrams bool
uniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)

settings *Settings
receivedSettings chan struct{}
Expand All @@ -26,7 +26,7 @@ type connection struct {
func newConnection(
quicConn quic.Connection,
enableDatagrams bool,
uniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool),
uniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool),
perspective protocol.Perspective,
logger utils.Logger,
) *connection {
Expand Down Expand Up @@ -57,7 +57,8 @@ func (c *connection) HandleUnidirectionalStreams() {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), c.Connection, str, err) {
id := c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), id, str, err) {
return
}
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
Expand Down Expand Up @@ -89,8 +90,15 @@ func (c *connection) HandleUnidirectionalStreams() {
}
return
default:
if c.uniStreamHijacker != nil && c.uniStreamHijacker(StreamType(streamType), c.Connection, str, nil) {
return
if c.uniStreamHijacker != nil {
if c.uniStreamHijacker(
StreamType(streamType),
c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID),
str,
nil,
) {
return
}
}
str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
return
Expand Down
4 changes: 2 additions & 2 deletions http3/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ type RoundTripper struct {
// Callers can either ignore the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)

// When set, this callback is called for unknown unidirectional stream of unknown stream type.
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)

// Dial specifies an optional dial function for creating QUIC
// connections for requests.
Expand Down
13 changes: 10 additions & 3 deletions http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,12 @@ type Server struct {
// Callers can either ignore the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)

// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)

// ConnContext optionally specifies a function that modifies
// the context used for a new connection c. The provided ctx
Expand Down Expand Up @@ -512,7 +512,14 @@ func (s *Server) maxHeaderBytes() uint64 {
func (s *Server) handleRequest(conn *connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
var ufh unknownFrameHandlerFunc
if s.StreamHijacker != nil {
ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) }
ufh = func(ft FrameType, e error) (processed bool, err error) {
return s.StreamHijacker(
ft,
conn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID),
str,
e,
)
}
}
frame, err := parseNextFrame(str, ufh)
if err != nil {
Expand Down
33 changes: 26 additions & 7 deletions http3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,12 @@ var _ = Describe("Server", func() {
AfterEach(func() { testDone <- struct{}{} })

It("hijacks a bidirectional stream of unknown frame type", func() {
id := quic.ConnectionTracingID(1337)
frameTypeChan := make(chan FrameType, 1)
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
s.StreamHijacker = func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
defer GinkgoRecover()
Expect(e).ToNot(HaveOccurred())
Expect(connTracingID).To(Equal(id))
frameTypeChan <- ft
return true, nil
}
Expand All @@ -331,14 +334,16 @@ var _ = Describe("Server", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels writing when hijacker didn't hijack a bidirectional stream", func() {
frameTypeChan := make(chan FrameType, 1)
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, nil
Expand All @@ -354,14 +359,16 @@ var _ = Describe("Server", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels writing when hijacker returned error", func() {
frameTypeChan := make(chan FrameType, 1)
s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, errors.New("error in hijacker")
Expand All @@ -377,6 +384,8 @@ var _ = Describe("Server", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
Expand All @@ -386,7 +395,7 @@ var _ = Describe("Server", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) {
s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, err error) (bool, error) {
defer close(done)
Expect(ft).To(BeZero())
Expect(str).To(Equal(unknownStr))
Expand All @@ -401,6 +410,8 @@ var _ = Describe("Server", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
Expand All @@ -425,9 +436,11 @@ var _ = Describe("Server", func() {
AfterEach(func() { testDone <- struct{}{} })

It("hijacks an unidirectional stream of unknown stream type", func() {
id := quic.ConnectionTracingID(42)
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
s.UniStreamHijacker = func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
Expect(connTracingID).To(Equal(id))
streamTypeChan <- st
return true
}
Expand All @@ -442,6 +455,8 @@ var _ = Describe("Server", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id)
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
Expand All @@ -451,7 +466,7 @@ var _ = Describe("Server", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
s.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool {
defer close(done)
Expect(st).To(BeZero())
Expect(str).To(Equal(unknownStr))
Expand All @@ -465,14 +480,16 @@ var _ = Describe("Server", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
s.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return false
Expand All @@ -490,6 +507,8 @@ var _ = Describe("Server", func() {
<-testDone
return nil, errors.New("test done")
})
ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234))
conn.EXPECT().Context().Return(ctx).AnyTimes()
s.handleConn(conn)
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
Expand Down

0 comments on commit 183d42a

Please sign in to comment.