Skip to content

Commit

Permalink
server: expose connection counts and close event
Browse files Browse the repository at this point in the history
  • Loading branch information
aldas committed Jul 27, 2023
1 parent c30cd6a commit cb503a9
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
2 changes: 0 additions & 2 deletions examples/server_and_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ func TestRequestToServer(t *testing.T) {
serverAddrCh <- addr.String()
log.Printf("listening on: %v\n", addr.String())
},
OnErrorFunc: nil,
OnAcceptFunc: nil,
}

tCtx, tCancel := context.WithTimeout(context.Background(), 1*time.Second)
Expand Down
36 changes: 27 additions & 9 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,30 @@ type ModbusHandler interface {
//
// Public fields are not designed to be goroutine safe. Do not mutate after server has been started
type Server struct {
mu sync.RWMutex
listener net.Listener // for simplicity, we only allow serving one listener
isShutdown atomic.Bool
activeConnections map[*connection]struct{}
mu sync.RWMutex
listener net.Listener // for simplicity, we only allow serving one listener
isShutdown atomic.Bool
activeConnections map[*connection]struct{}
activeConnectionCount atomic.Int64

// AssemblerCreatorFunc creates Assembler for each connetion to assemble different read byte fragments into complete
// modbus packet. Could have different implementations for TCP or RTU packets
AssemblerCreatorFunc func(handler ModbusHandler) PacketAssembler

// OnServeFunc allows capturing listener address just before server starts to accepting connections. This is useful
// for testing when listener is started with random port `:0`.
OnServeFunc func(addr net.Addr)
OnErrorFunc func(err error)
OnAcceptFunc func(ctx context.Context, remoteAddr net.Addr) error
OnServeFunc func(addr net.Addr)

OnErrorFunc func(err error)

// OnAcceptConnFunc is called when server accepts new connection. When method returns an error the connection will be closed.
// connectionCount indicated currently active connection count.
//
// This is where firewall rules and other limits can be implemented
OnAcceptConnFunc func(ctx context.Context, remoteAddr net.Addr, connectionCount uint64) error

// OnCloseConnFunc is called at the end of connection. isServerShutdown indicated if method is called at server shutdown.
OnCloseConnFunc func(ctx context.Context, remoteAddr net.Addr, isServerShutdown bool)
}

type connection struct {
Expand Down Expand Up @@ -114,8 +124,11 @@ func (s *Server) serve(ctx context.Context, listener net.Listener, handler Modbu
return err
}

if s.OnAcceptFunc != nil {
if err := s.OnAcceptFunc(ctx, netConn.RemoteAddr()); err != nil {
if s.OnAcceptConnFunc != nil {
if err := s.OnAcceptConnFunc(ctx, netConn.RemoteAddr(), uint64(s.activeConnectionCount.Load())); err != nil {
if err := netConn.Close(); err != nil {
onErrorFunc(fmt.Errorf("connection.close error, err: %w", err))
}
continue
}
}
Expand All @@ -142,6 +155,9 @@ func (s *Server) serve(ctx context.Context, listener net.Listener, handler Modbu
conn.onErrorFunc(fmt.Errorf("failed to close handler connection, err: %w", err))
}
s.trackConn(c, false)
if s.OnAcceptConnFunc != nil {
s.OnCloseConnFunc(ctx, conn.conn.RemoteAddr(), s.isShutdown.Load())
}
}()
conn.handle(ctx)
}(ctx, c)
Expand Down Expand Up @@ -173,8 +189,10 @@ func (s *Server) trackConn(c *connection, isAdd bool) {
}
if isAdd {
s.activeConnections[c] = struct{}{}
s.activeConnectionCount.Add(1)
} else {
delete(s.activeConnections, c)
s.activeConnectionCount.Add(-1)
}
}

Expand Down
5 changes: 3 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ func TestRequestToServer(t *testing.T) {
OnServeFunc: func(addr net.Addr) {
serverAddrCh <- addr.String()
},
OnErrorFunc: nil,
OnAcceptFunc: nil,
OnErrorFunc: nil,
OnAcceptConnFunc: nil,
OnCloseConnFunc: nil,
}

tCtx, tCancel := context.WithTimeout(context.Background(), 1*time.Second)
Expand Down

0 comments on commit cb503a9

Please sign in to comment.