Skip to content

Commit

Permalink
allow tracing connection reads
Browse files Browse the repository at this point in the history
  • Loading branch information
aldas committed Jul 27, 2023
1 parent cb503a9 commit f5f23d7
Showing 1 changed file with 37 additions and 5 deletions.
42 changes: 37 additions & 5 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
)

const (
readTimeout = 1 * time.Millisecond
writeTimeout = 1 * time.Millisecond
readTimeout = 5 * time.Millisecond
writeTimeout = 50 * time.Millisecond
idleTimeout = 25 * time.Second
)

Expand All @@ -33,6 +33,11 @@ type PacketAssembler interface {
ReceiveRead(ctx context.Context, received []byte, bytesRead int) (response []byte, closeConnection bool)
}

// RawReadTracer is PacketAssembler optional interface that is called for each Read from connection to allow tracing read results
type RawReadTracer interface {
Read(data []byte, n int, err error)
}

// ModbusHandler calls Handle method when it has received enough data to be parsed into Modbus packet.
type ModbusHandler interface {
Handle(ctx context.Context, received packet.Request) (packet.Response, error)
Expand All @@ -49,6 +54,11 @@ type Server struct {
activeConnections map[*connection]struct{}
activeConnectionCount atomic.Int64

// WriteTimeout is amount of time writing the request can take after it errors out
WriteTimeout time.Duration
// ReadTimeout is amount of time reading the response can take
ReadTimeout time.Duration

// AssemblerCreatorFunc creates Assembler for each connetion to assemble different read byte fragments into complete
// modbus packet. Could have different implementations for TCP or RTU packets
AssemblerCreatorFunc func(handler ModbusHandler) PacketAssembler
Expand All @@ -74,6 +84,9 @@ type connection struct {
isBeingHandled atomic.Bool
assembler PacketAssembler

writeTimeout time.Duration
readTimeout time.Duration

onErrorFunc func(error)
}

Expand All @@ -93,6 +106,9 @@ func (s *Server) Serve(ctx context.Context, listener net.Listener, handler Modbu
return s.serve(ctx, listener, handler)
}

// ContextRemoteAddr is context.Context value containing clients remote address
type ContextRemoteAddr struct{}

func (s *Server) serve(ctx context.Context, listener net.Listener, handler ModbusHandler) error {
if s.AssemblerCreatorFunc == nil {
s.AssemblerCreatorFunc = func(handler ModbusHandler) PacketAssembler {
Expand Down Expand Up @@ -139,10 +155,13 @@ func (s *Server) serve(ctx context.Context, listener net.Listener, handler Modbu
default:
}

cCtx := context.WithValue(ctx, ContextRemoteAddr{}, netConn.RemoteAddr())
c := &connection{
conn: netConn,
isBeingHandled: atomic.Bool{},
assembler: s.AssemblerCreatorFunc(handler),
writeTimeout: s.WriteTimeout,
readTimeout: s.ReadTimeout,
onErrorFunc: onErrorFunc,
}
s.trackConn(c, true)
Expand All @@ -160,7 +179,7 @@ func (s *Server) serve(ctx context.Context, listener net.Listener, handler Modbu
}
}()
conn.handle(ctx)
}(ctx, c)
}(cCtx, c)
}
}

Expand Down Expand Up @@ -200,6 +219,16 @@ func (c *connection) handle(ctx context.Context) {
cCtx, cCancel := context.WithCancel(ctx)
defer cCancel()

rTimeout := readTimeout
if c.readTimeout > 0 {
rTimeout = c.readTimeout
}
wTimeout := writeTimeout
if c.writeTimeout > 0 {
wTimeout = c.writeTimeout
}

rrt, debugRawRead := c.assembler.(RawReadTracer)
conn := c.conn
var lastReceived time.Time
received := make([]byte, 300)
Expand All @@ -210,8 +239,11 @@ func (c *connection) handle(ctx context.Context) {
default:
}

_ = conn.SetReadDeadline(time.Now().Add(readTimeout))
_ = conn.SetReadDeadline(time.Now().Add(rTimeout))
n, err := conn.Read(received)
if debugRawRead {
rrt.Read(received[0:n], n, err)
}
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
if !errors.Is(err, io.EOF) {
c.onErrorFunc(err)
Expand All @@ -229,7 +261,7 @@ func (c *connection) handle(ctx context.Context) {
c.isBeingHandled.Store(true)
toSend, closeConn := c.assembler.ReceiveRead(cCtx, received[0:n], n)
if toSend != nil {
_ = conn.SetWriteDeadline(time.Now().Add(writeTimeout))
_ = conn.SetWriteDeadline(time.Now().Add(wTimeout))
if _, err := conn.Write(toSend); err != nil {
c.onErrorFunc(err)
return // when write fails to client we close connection
Expand Down

0 comments on commit f5f23d7

Please sign in to comment.