Skip to content

Commit

Permalink
send out the CONNECTION_REFUSED error when refusing a connection (#4250)
Browse files Browse the repository at this point in the history
So far, we used Connection.destroy, which destroys a connection without
sending out a CONNECTION_CLOSE frame. This is useful (for example) when
receiving a stateless reset, but it's not what we want when the server
refuses an incoming connection. In this case, we want to send out a
packet with a CONNECTION_CLOSE frame to inform the client that the
connection attempt is being rejected.
  • Loading branch information
marten-seemann committed Jan 19, 2024
1 parent b3eb375 commit cb1775a
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 12 deletions.
12 changes: 7 additions & 5 deletions closed_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.sendPacket(p.remoteAddr, p.info)
}

func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }
func (c *closedLocalConn) destroy(error) {}
func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {}
func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective }

// A closedRemoteConn is a connection that was closed remotely.
// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE.
Expand All @@ -57,6 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
return &closedRemoteConn{perspective: pers}
}

func (s *closedRemoteConn) handlePacket(receivedPacket) {}
func (s *closedRemoteConn) destroy(error) {}
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
func (c *closedRemoteConn) handlePacket(receivedPacket) {}
func (c *closedRemoteConn) destroy(error) {}
func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {}
func (c *closedRemoteConn) getPerspective() protocol.Perspective { return c.perspective }
5 changes: 5 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1581,6 +1581,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro
return nil
}

func (s *connection) closeWithTransportError(code TransportErrorCode) {
s.closeLocal(&qerr.TransportError{ErrorCode: code})
<-s.ctx.Done()
}

func (s *connection) handleCloseError(closeErr *closeError) {
e := closeErr.err
if e == nil {
Expand Down
35 changes: 35 additions & 0 deletions integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,41 @@ var _ = Describe("Handshake tests", func() {
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
})

It("closes handshaking connections when the server is closed", func() {
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
tr := quic.Transport{
Conn: udpConn,
}
defer tr.Close()
tlsConf := &tls.Config{}
done := make(chan struct{})
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
<-done
return nil, errors.New("closed")
}
ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())

errChan := make(chan error, 1)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
go func() {
defer GinkgoRecover()
_, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
errChan <- err
}()
time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
Expect(ln.Close()).To(Succeed())
close(done)
err = <-errChan
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
})
})

Context("ALPN", func() {
Expand Down
37 changes: 37 additions & 0 deletions mock_packet_handler_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions mock_quic_conn_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var ErrServerClosed = errors.New("quic: server closed")
type packetHandler interface {
handlePacket(receivedPacket)
destroy(error)
closeWithTransportError(qerr.TransportErrorCode)
getPerspective() protocol.Perspective
}

Expand All @@ -44,6 +45,7 @@ type quicConn interface {
getPerspective() protocol.Perspective
run() error
destroy(error)
closeWithTransportError(TransportErrorCode)
}

type zeroRTTQueue struct {
Expand Down Expand Up @@ -693,7 +695,7 @@ func (s *baseServer) handleNewConn(conn quicConn) {
// wait until the early connection is ready, the handshake fails, or the server is closed
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
conn.closeWithTransportError(ConnectionRefused)
return
case <-conn.earlyConnReady():
case <-connCtx.Done():
Expand All @@ -703,7 +705,7 @@ func (s *baseServer) handleNewConn(conn quicConn) {
// wait until the handshake is complete (or fails)
select {
case <-s.errorChan:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
conn.closeWithTransportError(ConnectionRefused)
return
case <-conn.HandshakeComplete():
case <-connCtx.Done():
Expand Down
10 changes: 5 additions & 5 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ var _ = Describe("Server", func() {
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().destroy(gomock.Any())
conn.EXPECT().closeWithTransportError(gomock.Any())
})

It("sends a Version Negotiation Packet for unsupported versions", func() {
Expand Down Expand Up @@ -530,7 +530,7 @@ var _ = Describe("Server", func() {
Eventually(run).Should(BeClosed())
Eventually(done).Should(BeClosed())
// shutdown
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
})

It("drops packets if the receive queue is full", func() {
Expand Down Expand Up @@ -570,7 +570,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().Context().Return(context.Background()).MaxTimes(1)
conn.EXPECT().HandshakeComplete().Return(make(chan struct{})).MaxTimes(1)
// shutdown
conn.EXPECT().destroy(gomock.Any()).MaxTimes(1)
conn.EXPECT().closeWithTransportError(gomock.Any()).MaxTimes(1)
return conn
}

Expand Down Expand Up @@ -1008,7 +1008,7 @@ var _ = Describe("Server", func() {
) quicConn {
conn := NewMockQUICConn(mockCtrl)
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().destroy(&qerr.TransportError{ErrorCode: ConnectionRefused}).Do(func(error) { close(destroyed) })
conn.EXPECT().closeWithTransportError(ConnectionRefused).Do(func(TransportErrorCode) { close(destroyed) })
conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))
conn.EXPECT().run().MaxTimes(1)
conn.EXPECT().Context().Return(context.Background())
Expand Down Expand Up @@ -1468,7 +1468,7 @@ var _ = Describe("Server", func() {
conn.EXPECT().Context().Return(context.Background())
close(called)
// shutdown
conn.EXPECT().destroy(gomock.Any())
conn.EXPECT().closeWithTransportError(gomock.Any())
return conn
}

Expand Down

0 comments on commit cb1775a

Please sign in to comment.