Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

send the CONNECTION_REFUSED error when refusing a connection #4250

Merged
merged 1 commit into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions closed_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.sendPacket(p.remoteAddr, p.info)
}

func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {}
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }

// A closedRemoteConn is a connection that was closed remotely.
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
Expand All @@ -57,6 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
return &closedRemoteConn{perspective: pers}
}

func (s *closedRemoteConn) handlePacket(receivedPacket) {}
func (s *closedRemoteConn) destroy(error) {}
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
func (c *closedRemoteConn) handlePacket(receivedPacket) {}
func (c *closedRemoteConn) destroy(error) {}
func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {}
func (c *closedRemoteConn) getPerspective() protocol.Perspective { return c.perspective }
5 changes: 5 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1581,6 +1581,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro
return nil
}

func (s *connection) closeWithTransportError(code TransportErrorCode) {
s.closeLocal(&qerr.TransportError{ErrorCode: code})
<-s.ctx.Done()
}

func (s *connection) handleCloseError(closeErr *closeError) {
e := closeErr.err
if e == nil {
Expand Down
35 changes: 35 additions & 0 deletions integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,41 @@ var _ = Describe("Handshake tests", func() {
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
})

It("closes handshaking connections when the server is closed", func() {
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
tr := quic.Transport{
Conn: udpConn,
}
defer tr.Close()
tlsConf := &tls.Config{}
done := make(chan struct{})
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
<-done
return nil, errors.New("closed")
}
ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())

errChan := make(chan error, 1)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
go func() {
defer GinkgoRecover()
_, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
errChan <- err
}()
time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
Expect(ln.Close()).To(Succeed())
close(done)
err = <-errChan
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
})
})

Context("ALPN", func() {
Expand Down
37 changes: 37 additions & 0 deletions mock_packet_handler_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions mock_quic_conn_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var ErrServerClosed = errors.New("quic: server closed")
type packetHandler interface {
handlePacket(receivedPacket)
destroy(error)
closeWithTransportError(qerr.TransportErrorCode)
getPerspective() protocol.Perspective
}

Expand All @@ -44,6 +45,7 @@ type quicConn interface {
getPerspective() protocol.Perspective
run() error
destroy(error)
closeWithTransportError(TransportErrorCode)
}

type zeroRTTQueue struct {
Expand Down Expand Up @@ -693,7 +695,7 @@ func (s *baseServer) handleNewConn(conn quicConn) {
// wait until the early connection is ready, the handshake fails, or the server is closed
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
conn.closeWithTransportError(ConnectionRefused)
return
case <-conn.earlyConnReady():
case <-connCtx.Done():
Expand All @@ -703,7 +705,7 @@ func (s *baseServer) handleNewConn(conn quicConn) {
// wait until the handshake is complete (or fails)
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
conn.closeWithTransportError(ConnectionRefused)
return
case <-conn.HandshakeComplete():
case <-connCtx.Done():
Expand Down
10 changes: 5 additions & 5 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ var _ = Describe("Server", func() {
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().destroy(gomock.Any())
conn.EXPECT().closeWithTransportError(gomock.Any())
})

It("sends a Version Negotiation Packet for unsupported versions", func() {
Expand Down Expand Up @@ -530,7 +530,7 @@ var _ = Describe("Server", func() {
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
})

It("drops packets if the receive queue is full", func() {
Expand Down Expand Up @@ -570,7 +570,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
// shutdown
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
return conn
}

Expand Down Expand Up @@ -1008,7 +1008,7 @@ var _ = Describe("Server", func() {
) quicConn {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}).Do(func(error) { close(destroyed) })
conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) })
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().run().MaxTimes(1)
conn.EXPECT().Context().Return(context.Background())
Expand Down Expand Up @@ -1468,7 +1468,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().Context().Return(context.Background())
close(called)
// shutdown
conn.EXPECT().destroy(gomock.Any())
conn.EXPECT().closeWithTransportError(gomock.Any())
return conn
}

Expand Down