Skip to content

Commit

Permalink
pass a context to Transport.ConnContext
Browse files Browse the repository at this point in the history
This context is cancelled when the QUIC connection is closed, or when
the QUIC handshake fails. This allows the application to easily build
and garbage collect a map of active connections.
  • Loading branch information
marten-seemann committed May 28, 2024
1 parent 21b643e commit 9047fdd
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 23 deletions.
8 changes: 4 additions & 4 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ var (

var newConnection = func(
ctx context.Context,
ctxCancel context.CancelCauseFunc,
conn sendConn,
runner connRunner,
origDestConnID protocol.ConnectionID,
Expand All @@ -241,6 +242,8 @@ var newConnection = func(
v protocol.Version,
) quicConn {
s := &connection{
ctx: ctx,
ctxCancel: ctxCancel,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
Expand Down Expand Up @@ -274,7 +277,6 @@ var newConnection = func(
s.queueControlFrame,
connIDGenerator,
)
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
Expand Down Expand Up @@ -500,9 +502,7 @@ func (s *connection) preSetup() {
// run the connection main loop
func (s *connection) run() error {
var closeErr closeError
defer func() {
s.ctxCancel(closeErr.err)
}()
defer func() { s.ctxCancel(closeErr.err) }()

s.timer = *newTimer()

Expand Down
4 changes: 3 additions & 1 deletion connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ var _ = Describe("Connection", func() {
tracer.EXPECT().SentTransportParameters(gomock.Any())
tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes()
tracer.EXPECT().UpdatedCongestionState(gomock.Any())
ctx, cancel := context.WithCancelCause(context.Background())
conn = newConnection(
context.Background(),
ctx,
cancel,
mconn,
connRunner,
protocol.ConnectionID{},
Expand Down
38 changes: 25 additions & 13 deletions integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ var _ = Describe("Handshake tests", func() {
})

It("uses the context everywhere, on the server side", func() {
//nolint:staticcheck
serverCtx := context.WithValue(context.Background(), "foo", "bar")
tlsGetConfigForClientContextChan := make(chan context.Context, 1)
tlsGetCertificateContextChan := make(chan context.Context, 1)
tracerContextChan := make(chan context.Context, 1)
Expand All @@ -135,8 +133,11 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func() context.Context { return serverCtx },
Conn: conn,
ConnContext: func(ctx context.Context) context.Context {
//nolint:staticcheck
return context.WithValue(ctx, "foo", "bar")
},
}
defer tr.Close()
server, err := tr.Listen(
Expand Down Expand Up @@ -173,7 +174,7 @@ var _ = Describe("Handshake tests", func() {
str, err := conn.OpenUniStream()
Expect(err).ToNot(HaveOccurred())
streamContextChan <- str.Context()
str.Close()
str.Write([]byte{1, 2, 3})
}
}()

Expand All @@ -184,21 +185,32 @@ var _ = Describe("Handshake tests", func() {
getQuicConfig(nil),
)
Expect(err).ToNot(HaveOccurred())
defer c.CloseWithError(0, "")
_, err = c.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())
c.CloseWithError(1337, "bye")

checkContext := func(c <-chan context.Context) {
checkContext := func(c <-chan context.Context, checkCancellationCause bool) {
var ctx context.Context
Eventually(c).Should(Receive(&ctx))
EventuallyWithOffset(1, c).Should(Receive(&ctx))
val := ctx.Value("foo")
ExpectWithOffset(1, val).ToNot(BeNil())
v := val.(string)
ExpectWithOffset(1, v).To(Equal("bar"))
EventuallyWithOffset(1, ctx.Done).Should(BeClosed())
if !checkCancellationCause {
return
}
ctxErr := context.Cause(ctx)
var appErr *quic.ApplicationError
ExpectWithOffset(1, errors.As(ctxErr, &appErr)).To(BeTrue())
ExpectWithOffset(1, appErr.ErrorCode).To(BeEquivalentTo(1337))
}
checkContext(connContextChan)
checkContext(tlsGetConfigForClientContextChan)
checkContext(tlsGetCertificateContextChan)
checkContext(tracerContextChan)
checkContext(streamContextChan)
checkContext(connContextChan, true)
checkContext(tracerContextChan, true)
checkContext(streamContextChan, true)
// crypto/tls cancels the context when the TLS handshake completes.
checkContext(tlsGetConfigForClientContextChan, false)
checkContext(tlsGetCertificateContextChan, false)
})

It("uses the context everywhere, on the client side", func() {
Expand Down
10 changes: 6 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ type baseServer struct {
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true

connContext func() context.Context
connContext func(context.Context) context.Context

// set as a member, so they can be set in the tests
newConn func(
context.Context,
context.CancelCauseFunc,
sendConn,
connRunner,
protocol.ConnectionID, /* original dest connection ID */
Expand Down Expand Up @@ -233,7 +234,7 @@ func newServer(
conn rawConn,
connHandler packetHandlerManager,
connIDGenerator ConnectionIDGenerator,
connContext func() context.Context,
connContext func(context.Context) context.Context,
tlsConf *tls.Config,
config *Config,
tracer *logging.Tracer,
Expand Down Expand Up @@ -635,9 +636,9 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}

var conn quicConn
var ctx context.Context
ctx, cancel := context.WithCancelCause(context.Background())
if s.connContext != nil {
ctx = s.connContext()
ctx = s.connContext(ctx)
if ctx == nil {
panic("quic: ConnContext returned nil")
}
Expand All @@ -661,6 +662,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
s.logger.Debugf("Changing connection ID to %s.", connID)
conn = s.newConn(
ctx,
cancel,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
origDestConnID,
Expand Down
13 changes: 13 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
origDestConnID protocol.ConnectionID,
Expand Down Expand Up @@ -490,6 +491,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl)
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
origDestConnID protocol.ConnectionID,
Expand Down Expand Up @@ -558,6 +560,7 @@ var _ = Describe("Server", func() {
var counter atomic.Uint32
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -613,6 +616,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -662,6 +666,7 @@ var _ = Describe("Server", func() {
done := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -723,6 +728,7 @@ var _ = Describe("Server", func() {
It("decodes the token from the token field", func() {
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -956,6 +962,7 @@ var _ = Describe("Server", func() {
destroyed := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1023,6 +1030,7 @@ var _ = Describe("Server", func() {
handshakeChan := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1093,6 +1101,7 @@ var _ = Describe("Server", func() {
handshakeChan := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1163,6 +1172,7 @@ var _ = Describe("Server", func() {
ready := make(chan struct{})
serv.baseServer.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1204,6 +1214,7 @@ var _ = Describe("Server", func() {
wg.Add(protocol.MaxAcceptQueueSize)
serv.baseServer.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1263,6 +1274,7 @@ var _ = Describe("Server", func() {
conn := NewMockQUICConn(mockCtrl)
serv.baseServer.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
runner connRunner,
_ protocol.ConnectionID,
Expand Down Expand Up @@ -1385,6 +1397,7 @@ var _ = Describe("Server", func() {
called := make(chan struct{})
serv.newConn = func(
_ context.Context,
_ context.CancelCauseFunc,
_ sendConn,
_ connRunner,
_ protocol.ConnectionID,
Expand Down
3 changes: 2 additions & 1 deletion transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,15 @@ type Transport struct {
VerifySourceAddress func(net.Addr) bool

// ConnContext is called when the server accepts a new connection.
// The context is closed when the connection is closed, or when the handshake fails for any reason.
// The context returned from the callback is used to derive every other context used during the
// lifetime of the connection:
// * the context passed to crypto/tls (and used on the tls.ClientHelloInfo)
// * the context used in Config.Tracer
// * the context returned from Connection.Context
// * the context returned from SendStream.Context
// It is not used for dialed connections.
ConnContext func() context.Context
ConnContext func(context.Context) context.Context

// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.
Expand Down

0 comments on commit 9047fdd

Please sign in to comment.