Skip to content

Commit

Permalink
RetryClient: add ResponseTimeout and DirectlyPublishQoS0 options (#190)
Browse files Browse the repository at this point in the history
* Add ResponseTimeout option to RetryClient
* Add option to set RetryClient to ReconnectClient
* Add RetryClient.OnError callback
  • Loading branch information
at-wat committed Oct 20, 2021
1 parent 3b3e9f7 commit 29f4e22
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 13 deletions.
10 changes: 10 additions & 0 deletions error.go
Expand Up @@ -41,6 +41,16 @@ func (e *errorWithRetry) Retry(ctx context.Context, cli *BaseClient) error {
return e.retryFn(ctx, cli)
}

// RequestTimeoutError is a context deadline exceeded error caused by RetryClient.ResponseTimeout.
//
// The underlying context error is embedded and is also exposed through Unwrap,
// so errors.Is and errors.As can match the wrapped error.
type RequestTimeoutError struct {
	error
}

// Error implements error.
// A nil receiver, or a value that carries no underlying error, yields the bare
// prefix instead of dereferencing a nil error.
func (e *RequestTimeoutError) Error() string {
	if e == nil || e.error == nil {
		return "request timeout exceeded"
	}
	return fmt.Sprintf("request timeout exceeded: %v", e.error.Error())
}

// Unwrap returns the underlying error, or nil when there is none.
func (e *RequestTimeoutError) Unwrap() error {
	if e == nil {
		return nil
	}
	return e.error
}

type errorInterface interface {
Error() string
Unwrap() error
Expand Down
13 changes: 12 additions & 1 deletion reconnclient.go
Expand Up @@ -41,14 +41,15 @@ func NewReconnectClient(dialer Dialer, opts ...ReconnectOption) (ReconnectClient
options := &ReconnectOptions{
ReconnectWaitBase: time.Second,
ReconnectWaitMax: 10 * time.Second,
RetryClient: &RetryClient{},
}
for _, opt := range opts {
if err := opt(options); err != nil {
return nil, err
}
}
return &reconnectClient{
RetryClient: &RetryClient{},
RetryClient: options.RetryClient,
done: make(chan struct{}),
disconnected: make(chan struct{}),
options: options,
Expand Down Expand Up @@ -201,6 +202,7 @@ type ReconnectOptions struct {
ReconnectWaitBase time.Duration
ReconnectWaitMax time.Duration
PingInterval time.Duration
RetryClient *RetryClient
}

// ReconnectOption sets option for Connect.
Expand Down Expand Up @@ -232,3 +234,12 @@ func WithPingInterval(interval time.Duration) ReconnectOption {
return nil
}
}

// WithRetryClient sets RetryClient.
// Default value is zero RetryClient.
//
// A nil cli is ignored and the current value is kept: NewReconnectClient
// stores options.RetryClient in the client it returns, so accepting nil here
// would replace the default with a nil pointer.
func WithRetryClient(cli *RetryClient) ReconnectOption {
	return func(o *ReconnectOptions) error {
		if cli == nil {
			return nil
		}
		o.RetryClient = cli
		return nil
	}
}
89 changes: 77 additions & 12 deletions retryclient.go
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"sync"
"time"
)

// ErrClosedClient means operation was requested on closed client.
Expand All @@ -29,16 +30,28 @@ type RetryClient struct {
chConnectErr chan error
chConnSwitch chan struct{}

retryQueue []retryFn
subEstablished subscriptions // acknoledged subscriptions
mu sync.RWMutex
handler Handler
chTask chan struct{}
stopped bool
taskQueue []func(ctx context.Context, cli *BaseClient)
newRetryByError bool
retryQueue []retryFn
subEstablished subscriptions // acknoledged subscriptions
mu sync.RWMutex
handler Handler
chTask chan struct{}
stopped bool
taskQueue []func(ctx context.Context, cli *BaseClient)

muStats sync.RWMutex
stats RetryStats

// Maximum duration to wait for an acknowledgment response.
// Messages with QoS1 and QoS2 will be retried.
ResponseTimeout time.Duration

// Directly publish QoS0 messages without queuing.
// Messages may then be delivered out of order, but performance may be increased.
DirectlyPublishQoS0 bool

// Callback to receive background errors on raw message publish/subscribe operations.
OnError func(error)
}

// Retryer is an interface to control message retrying.
Expand Down Expand Up @@ -86,6 +99,10 @@ func (c *RetryClient) Publish(ctx context.Context, message *Message) error {
cli := c.cli
c.mu.RUnlock()

if c.DirectlyPublishQoS0 && message.QoS == QoS0 {
return cli.Publish(ctx, message)
}

if cli != nil {
if err := cli.ValidateMessage(message); err != nil {
return wrapError(err, "validating publishing message")
Expand Down Expand Up @@ -119,7 +136,10 @@ func (c *RetryClient) publish(ctx context.Context, cli *BaseClient, message *Mes
return
}
publish := func(ctx context.Context, cli *BaseClient, message *Message) {
if err := cli.Publish(ctx, message); err != nil {
ctx2, cancel := c.requestContext(ctx)
defer cancel()
if err := cli.Publish(ctx2, message); err != nil {
c.onError(err)
select {
case <-ctx.Done():
// User cancelled; don't queue.
Expand All @@ -128,6 +148,7 @@ func (c *RetryClient) publish(ctx context.Context, cli *BaseClient, message *Mes
}
if retryErr, ok := err.(ErrorWithRetry); ok {
c.retryQueue = append(c.retryQueue, retryErr.Retry)
c.newRetryByError = true
}
}
return
Expand All @@ -151,7 +172,11 @@ func (c *RetryClient) publish(ctx context.Context, cli *BaseClient, message *Mes
func (c *RetryClient) subscribe(ctx context.Context, retry bool, cli *BaseClient, subs ...Subscription) {
subscribe := func(ctx context.Context, cli *BaseClient) error {
subscriptions(subs).applyTo(&c.subEstablished)
if _, err := cli.Subscribe(ctx, subs...); err != nil {

ctx2, cancel := c.requestContext(ctx)
defer cancel()
if _, err := cli.Subscribe(ctx2, subs...); err != nil {
c.onError(err)
select {
case <-ctx.Done():
if !retry {
Expand All @@ -162,6 +187,7 @@ func (c *RetryClient) subscribe(ctx context.Context, retry bool, cli *BaseClient
}
if retryErr, ok := err.(ErrorWithRetry); ok {
c.retryQueue = append(c.retryQueue, retryErr.Retry)
c.newRetryByError = true
}
}
return nil
Expand All @@ -178,7 +204,11 @@ func (c *RetryClient) subscribe(ctx context.Context, retry bool, cli *BaseClient
func (c *RetryClient) unsubscribe(ctx context.Context, cli *BaseClient, topics ...string) {
unsubscribe := func(ctx context.Context, cli *BaseClient) error {
unsubscriptions(topics).applyTo(&c.subEstablished)
if err := cli.Unsubscribe(ctx, topics...); err != nil {

ctx2, cancel := c.requestContext(ctx)
defer cancel()
if err := cli.Unsubscribe(ctx2, topics...); err != nil {
c.onError(err)
select {
case <-ctx.Done():
// User cancelled; don't queue.
Expand All @@ -187,6 +217,7 @@ func (c *RetryClient) unsubscribe(ctx context.Context, cli *BaseClient, topics .
}
if retryErr, ok := err.(ErrorWithRetry); ok {
c.retryQueue = append(c.retryQueue, retryErr.Retry)
c.newRetryByError = true
}
}
return nil
Expand All @@ -203,7 +234,11 @@ func (c *RetryClient) unsubscribe(ctx context.Context, cli *BaseClient, topics .
// Disconnect from the broker.
func (c *RetryClient) Disconnect(ctx context.Context) error {
err := wrapError(c.pushTask(ctx, func(ctx context.Context, cli *BaseClient) {
cli.Disconnect(ctx)
ctx2, cancel := c.requestContext(ctx)
defer cancel()
if err := cli.Disconnect(ctx2); err != nil {
c.onError(err)
}
}), "retryclient: disconnecting")
c.mu.Lock()
close(c.chTask)
Expand All @@ -217,7 +252,9 @@ func (c *RetryClient) Ping(ctx context.Context) error {
c.mu.RLock()
cli := c.cli
c.mu.RUnlock()
return wrapError(cli.Ping(ctx), "retryclient: pinging")
ctx2, cancel := c.requestContext(ctx)
defer cancel()
return wrapError(cli.Ping(ctx2), "retryclient: pinging")
}

// Client returns the base client.
Expand Down Expand Up @@ -302,10 +339,32 @@ func (c *RetryClient) SetClient(ctx context.Context, cli *BaseClient) {
c.muStats.Unlock()

task(ctx, cli)

if c.newRetryByError {
_ = cli.Close()
connected = false
c.newRetryByError = false
}
}
}()
}

// requestContext derives the context used for a single request.
// With a zero ResponseTimeout the given ctx is returned unchanged together
// with a no-op cancel function. Otherwise ctx is limited by ResponseTimeout
// and wrapped in a *requestContext; the returned function cancels it.
func (c *RetryClient) requestContext(ctx context.Context) (context.Context, func()) {
	if timeout := c.ResponseTimeout; timeout != 0 {
		limited, stop := context.WithTimeout(ctx, timeout)
		return &requestContext{limited}, stop
	}
	noop := func() {}
	return ctx, noop
}

// requestContext wraps the timeout-limited context created for a single
// request so that an expired deadline is reported as *RequestTimeoutError.
type requestContext struct {
	context.Context
}

// Err implements context.Context.
// It returns nil while the wrapped context is not done, as the Context
// contract requires. A deadline-exceeded error is wrapped in
// *RequestTimeoutError; any other error, such as context.Canceled, is
// returned unchanged so that cancellation is not reported as a timeout.
func (c *requestContext) Err() error {
	err := c.Context.Err()
	if errors.Is(err, context.DeadlineExceeded) {
		return &RequestTimeoutError{err}
	}
	return err
}

func (c *RetryClient) pushTask(ctx context.Context, task func(ctx context.Context, cli *BaseClient)) error {
c.mu.Lock()
defer c.mu.Unlock()
Expand All @@ -322,6 +381,12 @@ func (c *RetryClient) pushTask(ctx context.Context, task func(ctx context.Contex
return nil
}

// onError forwards err to the OnError callback.
// It does nothing when no callback has been set.
func (c *RetryClient) onError(err error) {
	report := c.OnError
	if report == nil {
		return
	}
	report(err)
}

// Connect to the broker.
func (c *RetryClient) Connect(ctx context.Context, clientID string, opts ...ConnectOption) (sessionPresent bool, err error) {
c.mu.Lock()
Expand Down
137 changes: 137 additions & 0 deletions retryclient_integration_test.go
Expand Up @@ -20,6 +20,8 @@ package mqtt
import (
"context"
"crypto/tls"
"io"
"net"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -359,3 +361,138 @@ func TestIntegration_RetryClient_RetryInitialRequest(t *testing.T) {
})
}
}

// baseCliRemovePacket dials the broker at urls["MQTT"] and returns a
// BaseClient whose outgoing traffic passes through packetPass.
//
// The client's transport is replaced by one end of an in-memory pipe. Every
// chunk the client writes is handed to packetPass; chunks for which it returns
// false are dropped, the others are forwarded to the original transport.
// Traffic from the broker is copied back to the client unfiltered.
//
// When either direction stops, both the pipe end and the original transport
// are closed, so the peer goroutine exits and the broker connection is
// released instead of leaking.
//
// NOTE(review): packetPass sees whatever a single Read returns (up to 1024
// bytes); callers presumably rely on that being one whole packet - confirm.
// NOTE(review): assumes BaseClient.Transport provides Close - confirm.
func baseCliRemovePacket(ctx context.Context, t *testing.T, packetPass func([]byte) bool) *BaseClient {
	t.Helper()
	cliBase, err := DialContext(ctx, urls["MQTT"], WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
	if err != nil {
		t.Fatalf("Unexpected error: '%v'", err)
	}
	ca, cb := net.Pipe()
	connOrig := cliBase.Transport
	cliBase.Transport = cb

	// Client -> broker, filtered.
	go func() {
		// Closing connOrig unblocks the io.Copy below; closing ca signals
		// the client side of the pipe.
		defer connOrig.Close()
		defer ca.Close()

		buf := make([]byte, 1024)
		for {
			n, err := ca.Read(buf)
			if err != nil {
				return
			}
			if !packetPass(buf[:n]) {
				continue
			}
			if _, err := connOrig.Write(buf[:n]); err != nil {
				return
			}
		}
	}()

	// Broker -> client, unfiltered.
	go func() {
		// Closing ca makes the client observe the end of the connection and
		// stops the filtering goroutine above.
		defer ca.Close()
		// The copy error is not needed: any failure simply ends the relay.
		_, _ = io.Copy(ca, connOrig)
	}()
	return cliBase
}

func TestIntegration_RetryClient_ResponseTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cli := &RetryClient{
ResponseTimeout: 50 * time.Millisecond,
}
cli.SetClient(ctx, baseCliRemovePacket(ctx, t, func(b []byte) bool {
return b[0] != byte(packetPublish)|byte(publishFlagQoS1)
}))

if _, err := cli.Connect(ctx, "RetryClientTimeout"); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

if err := cli.Publish(ctx, &Message{
Topic: "test/ResponseTimeout",
QoS: QoS1,
Payload: []byte("message"),
}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

select {
case <-time.After(200 * time.Millisecond):
t.Fatal("Timeout")
case <-cli.Client().Done():
// Client must be closed due to response timeout.
}
expectRetryStats(t, RetryStats{
QueuedRetries: 1,
TotalTasks: 1,
}, cli.Stats())

cli.SetClient(ctx, baseCliRemovePacket(ctx, t, func([]byte) bool {
return true
}))
cli.Retry(ctx)
if _, err := cli.Connect(ctx, "RetryClientTimeout"); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

time.Sleep(150 * time.Millisecond)
expectRetryStats(t, RetryStats{
TotalRetries: 1,
TotalTasks: 2,
}, cli.Stats())

if err := cli.Disconnect(ctx); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
}

// TestIntegration_RetryClient_DirectlyPublishQoS0 enables DirectlyPublishQoS0
// and checks the task counter: a QoS1 publish is counted as one task, while a
// following QoS0 publish reaches the transport (observed by the packet filter)
// and leaves the counter at one.
func TestIntegration_RetryClient_DirectlyPublishQoS0(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	cli := &RetryClient{DirectlyPublishQoS0: true}

	// Closed by the filter when a QoS0 PUBLISH packet is written.
	publishQoS0Msg := make(chan struct{})
	filter := func(b []byte) bool {
		header := b[0]
		if header == byte(packetPublish) {
			close(publishQoS0Msg)
		}
		// QoS1 PUBLISH packets are dropped; everything else passes.
		return header != byte(packetPublish)|byte(publishFlagQoS1)
	}
	cli.SetClient(ctx, baseCliRemovePacket(ctx, t, filter))

	if _, err := cli.Connect(ctx, "RetryClientDirectlyPublishQoS0"); err != nil {
		t.Fatalf("Unexpected error: '%v'", err)
	}

	qos1Msg := &Message{
		Topic:   "test/DirectlyPublishQoS0",
		QoS:     QoS1,
		Payload: []byte("message"),
	}
	if err := cli.Publish(ctx, qos1Msg); err != nil {
		t.Fatalf("Unexpected error: '%v'", err)
	}

	time.Sleep(150 * time.Millisecond)
	wantAfterQoS1 := RetryStats{
		TotalRetries: 0,
		TotalTasks:   1,
	}
	expectRetryStats(t, wantAfterQoS1, cli.Stats())

	qos0Msg := &Message{
		Topic:   "test/DirectlyPublishQoS0",
		QoS:     QoS0,
		Payload: []byte("message"),
	}
	if err := cli.Publish(ctx, qos0Msg); err != nil {
		t.Fatalf("Unexpected error: '%v'", err)
	}

	limit := time.After(150 * time.Millisecond)
	select {
	case <-limit:
		t.Fatal("Timeout")
	case <-publishQoS0Msg:
	}
	wantAfterQoS0 := RetryStats{
		TotalRetries: 0,
		TotalTasks:   1,
	}
	expectRetryStats(t, wantAfterQoS0, cli.Stats())
}
2 changes: 2 additions & 0 deletions retryclient_test.go
Expand Up @@ -20,6 +20,8 @@ import (
"testing"
)

// Compile-time check: *RetryClient must implement Retryer.
var _ Retryer = (*RetryClient)(nil)

func TestRetryClientPublish_MessageValidationError(t *testing.T) {
cli := &RetryClient{
cli: &BaseClient{
Expand Down

0 comments on commit 29f4e22

Please sign in to comment.