Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass a context to Transport.ConnContext #4536

Merged
merged 2 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ var (

var newConnection = func(
ctx context.Context,
ctxCancel context.CancelCauseFunc,
conn sendConn,
runner connRunner,
origDestConnID protocol.ConnectionID,
Expand All @@ -240,6 +241,8 @@ var newConnection = func(
v protocol.Version,
) quicConn {
s := &connection{
ctx: ctx,
ctxCancel: ctxCancel,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
Expand Down Expand Up @@ -273,7 +276,6 @@ var newConnection = func(
s.queueControlFrame,
connIDGenerator,
)
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
Expand Down Expand Up @@ -499,9 +501,7 @@ func (s *connection) preSetup() {
// run the connection main loop
func (s *connection) run() error {
var closeErr closeError
defer func() {
s.ctxCancel(closeErr.err)
}()
defer func() { s.ctxCancel(closeErr.err) }()

s.timer = *newTimer()

Expand Down
4 changes: 3 additions & 1 deletion connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().SentTransportParameters(gomock.Any())
tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes()
tracer.EXPECT().UpdatedCongestionState(gomock.Any())
ctx, cancel := context.WithCancelCause(context.Background())
conn = newConnection(
context.Background(),
ctx,
cancel,
mconn,
connRunner,
protocol.ConnectionID{},
Expand Down
95 changes: 73 additions & 22 deletions integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ var _ = Describe("Handshake tests", func() {
})

It("uses the context everywhere, on the server side", func() {
//nolint:staticcheck
serverCtx := context.WithValue(context.Background(), "foo", "bar")
tlsGetConfigForClientContextChan := make(chan context.Context, 1)
tlsGetCertificateContextChan := make(chan context.Context, 1)
tracerContextChan := make(chan context.Context, 1)
Expand All @@ -135,8 +133,11 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func() context.Context { return serverCtx },
Conn: conn,
ConnContext: func(ctx context.Context) context.Context {
//nolint:staticcheck
return context.WithValue(ctx, "foo", "bar")
},
}
defer tr.Close()
server, err := tr.Listen(
Expand Down Expand Up @@ -173,7 +174,7 @@ var _ = Describe("Handshake tests", func() {
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
streamContextChan <- str.Context()
str.Close()
str.Write([]byte{1, 2, 3})
}
}()

Expand All @@ -184,21 +185,63 @@ var _ = Describe("Handshake tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
_, err = c.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
c.CloseWithError(1337, "bye")

checkContext := func(c <-chan context.Context) {
checkContext := func(c <-chan context.Context, checkCancellationCause bool) {
var ctx context.Context
Eventually(c).Should(Receive(&ctx))
EventuallyWithOffset(1, c).Should(Receive(&ctx))
val := ctx.Value("foo")
ExpectWithOffset(1, val).ToNot(BeNil())
v := val.(string)
ExpectWithOffset(1, v).To(Equal("bar"))
EventuallyWithOffset(1, ctx.Done).Should(BeClosed())
if !checkCancellationCause {
return
}
ctxErr := context.Cause(ctx)
var appErr *quic.ApplicationError
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
}
checkContext(connContextChan)
checkContext(tlsGetConfigForClientContextChan)
checkContext(tlsGetCertificateContextChan)
checkContext(tracerContextChan)
checkContext(streamContextChan)
checkContext(connContextChan, true)
checkContext(tracerContextChan, true)
checkContext(streamContextChan, true)
// crypto/tls cancels the context when the TLS handshake completes.
checkContext(tlsGetConfigForClientContextChan, false)
checkContext(tlsGetCertificateContextChan, false)
})

It("correctly handles a fresh context returned from ConnContext", func() {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func(ctx context.Context) context.Context { return context.Background() },
}
server, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
conn, err := server.Accept(context.Background())
if err != nil {
return
}
Eventually(conn.Context().Done).Should(BeClosed())
}()

c, err := quic.DialAddr(
context.Background(),
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
c.CloseWithError(1337, "bye")
})

It("uses the context everywhere, on the client side", func() {
Expand Down Expand Up @@ -227,31 +270,39 @@ var _ = Describe("Handshake tests", func() {
}),
)
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
cancel()
// make sure the connection context is not cancelled (even though derived from the ctx passed to Dial)
Expect(ctx.Done()).ToNot(Receive())

checkContext := func(ctx context.Context) {
checkContext := func(ctx context.Context, checkCancellationCause bool) {
val := ctx.Value("foo")
ExpectWithOffset(2, val).ToNot(BeNil())
v := val.(string)
ExpectWithOffset(2, v).To(Equal("bar"))
if !checkCancellationCause {
return
}
ctxErr := context.Cause(ctx)
var appErr *quic.ApplicationError
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
}

checkContextFromChan := func(c <-chan context.Context) {
checkContextFromChan := func(c <-chan context.Context, checkCancellationCause bool) {
var ctx context.Context
EventuallyWithOffset(1, c).Should(Receive(&ctx))
checkContext(ctx)
checkContext(ctx, checkCancellationCause)
}

checkContext(conn.Context())
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
checkContext(str.Context())
str.Close()
checkContextFromChan(tlsContextChan)
checkContextFromChan(tracerContextChan)
conn.CloseWithError(1337, "bye")

checkContext(conn.Context(), true)
checkContext(str.Context(), true)
// crypto/tls cancels the context when the TLS handshake completes.
checkContextFromChan(tlsContextChan, false)
checkContextFromChan(tracerContextChan, false)
})

Context("using different cipher suites", func() {
Expand Down
4 changes: 2 additions & 2 deletions integrationtests/self/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,9 @@ var _ = Describe("HTTP tests", func() {
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func() context.Context {
ConnContext: func(ctx context.Context) context.Context {
//nolint:staticcheck
return context.WithValue(context.Background(), "foo", "bar")
return context.WithValue(ctx, "foo", "bar")
},
}
defer tr.Close()
Expand Down
22 changes: 17 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ type baseServer struct {
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true

connContext func() context.Context
connContext func(context.Context) context.Context

// set as a member, so they can be set in the tests
newConn func(
context.Context,
context.CancelCauseFunc,
sendConn,
connRunner,
protocol.ConnectionID, /* original dest connection ID */
Expand Down Expand Up @@ -233,7 +234,7 @@ func newServer(
conn rawConn,
connHandler packetHandlerManager,
connIDGenerator ConnectionIDGenerator,
connContext func() context.Context,
connContext func(context.Context) context.Context,
tlsConf *tls.Config,
config *Config,
tracer *logging.Tracer,
Expand Down Expand Up @@ -635,14 +636,24 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}

var conn quicConn
var ctx context.Context
var cancel context.CancelCauseFunc
ctx, cancel1 := context.WithCancelCause(context.Background())
if s.connContext != nil {
ctx = s.connContext()
ctx = s.connContext(ctx)
if ctx == nil {
panic("quic: ConnContext returned nil")
}
// There's no guarantee that the application returns a context
// that's derived from the context we passed into ConnContext.
// We need to make sure that both contexts are cancelled.
var cancel2 context.CancelCauseFunc
ctx, cancel2 = context.WithCancelCause(ctx)
cancel = func(cause error) {
cancel1(cause)
cancel2(cause)
}
} else {
ctx = context.Background()
cancel = cancel1
}
ctx = context.WithValue(ctx, ConnectionTracingKey, nextConnTracingID())
var tracer *logging.ConnectionTracer
Expand All @@ -661,6 +672,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
s.logger.Debugf("Changing connection ID to %s.", connID)
conn = s.newConn(
ctx,
cancel,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
origDestConnID,
Expand Down
13 changes: 13 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
origDestConnID protocol.ConnectionID,
Expand Down Expand Up @@ -490,6 +491,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
origDestConnID protocol.ConnectionID,
Expand Down Expand Up @@ -558,6 +560,7 @@ var _ = Describe("Server", func() {
var counter atomic.Uint32
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -613,6 +616,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -662,6 +666,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -723,6 +728,7 @@ var _ = Describe("Server", func() {
It("decodes the token from the token field", func() {
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -956,6 +962,7 @@ var _ = Describe("Server", func() {
destroyed := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1023,6 +1030,7 @@ var _ = Describe("Server", func() {
handshakeChan := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1093,6 +1101,7 @@ var _ = Describe("Server", func() {
handshakeChan := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1163,6 +1172,7 @@ var _ = Describe("Server", func() {
ready := make(chan struct{})
serv.baseServer.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1204,6 +1214,7 @@ var _ = Describe("Server", func() {
wg.Add(protocol.MaxAcceptQueueSize)
serv.baseServer.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1263,6 +1274,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl)
serv.baseServer.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1385,6 +1397,7 @@ var _ = Describe("Server", func() {
called := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
Expand Down
3 changes: 2 additions & 1 deletion transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,15 @@ type Transport struct {
VerifySourceAddress func(net.Addr) bool

// ConnContext is called when the server accepts a new connection.
// The context is closed when the connection is closed, or when the handshake fails for any reason.
// The context returned from the callback is used to derive every other context used during the
// lifetime of the connection:
// * the context passed to crypto/tls (and used on the tls.ClientHelloInfo)
// * the context used in Config.Tracer
// * the context returned from Connection.Context
// * the context returned from SendStream.Context
// It is not used for dialed connections.
ConnContext func() context.Context
ConnContext func(context.Context) context.Context
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's another restriction here: The context returned MUST be derived from the context passed into ConnContext.

We can:

  1. Document it.
  2. Document and enforce it, by attaching an internal value to the context, and checking that it exists on the returned context (and panic if not).
  3. Remove this requirement. This means we'll have to pass a second context.CancelCauseFunc to the connection struct.

Copy link
Collaborator

@sukunrt sukunrt Jun 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we derive a context from the context returned from transport.ConnContext?

I mean

ctx, cancel = context.WithCancelCause(s.connContext())
...
conn = s.newConn(ctx, cancel ...)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise 3 seems the nicest.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we derive a context from the context returned from transport.ConnContext?

I mean

ctx, cancel = context.WithCancelCause(s.connContext())
...
conn = s.newConn(ctx, cancel ...)

I don't think that works, since we want the context we pass to ConnContext to be canceled when the connection is closed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we can do 3 without any major refactor. A context.ContextCancelFunc is just a fancy name for a func(error), so we can just wrap the cancel func of the original and the returned context into one function.


// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.
Expand Down
Loading