diff --git a/error.go b/error.go index 5a89651..569e292 100644 --- a/error.go +++ b/error.go @@ -41,6 +41,16 @@ func (e *errorWithRetry) Retry(ctx context.Context, cli *BaseClient) error { return e.retryFn(ctx, cli) } +// RequestTimeoutError is a context deadline exceeded error caused by RetryClient.ResponseTimeout. +type RequestTimeoutError struct { + error +} + +// Error implements error. +func (e *RequestTimeoutError) Error() string { + return fmt.Sprintf("request timeout exceeded: %v", e.error.Error()) +} + type errorInterface interface { Error() string Unwrap() error diff --git a/reconnclient.go b/reconnclient.go index 15ef946..fcad1b4 100644 --- a/reconnclient.go +++ b/reconnclient.go @@ -41,6 +41,7 @@ func NewReconnectClient(dialer Dialer, opts ...ReconnectOption) (ReconnectClient options := &ReconnectOptions{ ReconnectWaitBase: time.Second, ReconnectWaitMax: 10 * time.Second, + RetryClient: &RetryClient{}, } for _, opt := range opts { if err := opt(options); err != nil { @@ -48,7 +49,7 @@ func NewReconnectClient(dialer Dialer, opts ...ReconnectOption) (ReconnectClient } } return &reconnectClient{ - RetryClient: &RetryClient{}, + RetryClient: options.RetryClient, done: make(chan struct{}), disconnected: make(chan struct{}), options: options, @@ -201,6 +202,7 @@ type ReconnectOptions struct { ReconnectWaitBase time.Duration ReconnectWaitMax time.Duration PingInterval time.Duration + RetryClient *RetryClient } // ReconnectOption sets option for Connect. @@ -232,3 +234,12 @@ func WithPingInterval(interval time.Duration) ReconnectOption { return nil } } + +// WithRetryClient sets RetryClient. +// Default value is zero RetryClient. 
+func WithRetryClient(cli *RetryClient) ReconnectOption { + return func(o *ReconnectOptions) error { + o.RetryClient = cli + return nil + } +} diff --git a/retryclient.go b/retryclient.go index 9a9ff41..eeaafb4 100644 --- a/retryclient.go +++ b/retryclient.go @@ -18,6 +18,7 @@ import ( "context" "errors" "sync" + "time" ) // ErrClosedClient means operation was requested on closed client. @@ -29,16 +30,28 @@ type RetryClient struct { chConnectErr chan error chConnSwitch chan struct{} - retryQueue []retryFn - subEstablished subscriptions // acknoledged subscriptions - mu sync.RWMutex - handler Handler - chTask chan struct{} - stopped bool - taskQueue []func(ctx context.Context, cli *BaseClient) + newRetryByError bool + retryQueue []retryFn + subEstablished subscriptions // acknowledged subscriptions + mu sync.RWMutex + handler Handler + chTask chan struct{} + stopped bool + taskQueue []func(ctx context.Context, cli *BaseClient) muStats sync.RWMutex stats RetryStats + + // Maximum duration to wait for an acknowledgement response. + // Messages with QoS1 and QoS2 will be retried. + ResponseTimeout time.Duration + + // Directly publish QoS0 messages without queuing. + // Messages may be delivered out of order, but performance may be improved. + DirectlyPublishQoS0 bool + + // Callback to receive background errors on raw message publish/subscribe operations. + OnError func(error) } // Retryer is an interface to control message retrying. 
@@ -86,6 +99,10 @@ func (c *RetryClient) Publish(ctx context.Context, message *Message) error { cli := c.cli c.mu.RUnlock() + if c.DirectlyPublishQoS0 && message.QoS == QoS0 { + return cli.Publish(ctx, message) + } + if cli != nil { if err := cli.ValidateMessage(message); err != nil { return wrapError(err, "validating publishing message") @@ -119,7 +136,10 @@ func (c *RetryClient) publish(ctx context.Context, cli *BaseClient, message *Mes return } publish := func(ctx context.Context, cli *BaseClient, message *Message) { - if err := cli.Publish(ctx, message); err != nil { + ctx2, cancel := c.requestContext(ctx) + defer cancel() + if err := cli.Publish(ctx2, message); err != nil { + c.onError(err) select { case <-ctx.Done(): // User cancelled; don't queue. @@ -128,6 +148,7 @@ func (c *RetryClient) publish(ctx context.Context, cli *BaseClient, message *Mes } if retryErr, ok := err.(ErrorWithRetry); ok { c.retryQueue = append(c.retryQueue, retryErr.Retry) + c.newRetryByError = true } } return @@ -151,7 +172,11 @@ func (c *RetryClient) publish(ctx context.Context, cli *BaseClient, message *Mes func (c *RetryClient) subscribe(ctx context.Context, retry bool, cli *BaseClient, subs ...Subscription) { subscribe := func(ctx context.Context, cli *BaseClient) error { subscriptions(subs).applyTo(&c.subEstablished) - if _, err := cli.Subscribe(ctx, subs...); err != nil { + + ctx2, cancel := c.requestContext(ctx) + defer cancel() + if _, err := cli.Subscribe(ctx2, subs...); err != nil { + c.onError(err) select { case <-ctx.Done(): if !retry { @@ -162,6 +187,7 @@ func (c *RetryClient) subscribe(ctx context.Context, retry bool, cli *BaseClient } if retryErr, ok := err.(ErrorWithRetry); ok { c.retryQueue = append(c.retryQueue, retryErr.Retry) + c.newRetryByError = true } } return nil @@ -178,7 +204,11 @@ func (c *RetryClient) subscribe(ctx context.Context, retry bool, cli *BaseClient func (c *RetryClient) unsubscribe(ctx context.Context, cli *BaseClient, topics ...string) { 
unsubscribe := func(ctx context.Context, cli *BaseClient) error { unsubscriptions(topics).applyTo(&c.subEstablished) - if err := cli.Unsubscribe(ctx, topics...); err != nil { + + ctx2, cancel := c.requestContext(ctx) + defer cancel() + if err := cli.Unsubscribe(ctx2, topics...); err != nil { + c.onError(err) select { case <-ctx.Done(): // User cancelled; don't queue. @@ -187,6 +217,7 @@ func (c *RetryClient) unsubscribe(ctx context.Context, cli *BaseClient, topics . } if retryErr, ok := err.(ErrorWithRetry); ok { c.retryQueue = append(c.retryQueue, retryErr.Retry) + c.newRetryByError = true } } return nil @@ -203,7 +234,11 @@ func (c *RetryClient) unsubscribe(ctx context.Context, cli *BaseClient, topics . // Disconnect from the broker. func (c *RetryClient) Disconnect(ctx context.Context) error { err := wrapError(c.pushTask(ctx, func(ctx context.Context, cli *BaseClient) { - cli.Disconnect(ctx) + ctx2, cancel := c.requestContext(ctx) + defer cancel() + if err := cli.Disconnect(ctx2); err != nil { + c.onError(err) + } }), "retryclient: disconnecting") c.mu.Lock() close(c.chTask) @@ -217,7 +252,9 @@ func (c *RetryClient) Ping(ctx context.Context) error { c.mu.RLock() cli := c.cli c.mu.RUnlock() - return wrapError(cli.Ping(ctx), "retryclient: pinging") + ctx2, cancel := c.requestContext(ctx) + defer cancel() + return wrapError(cli.Ping(ctx2), "retryclient: pinging") } // Client returns the base client. 
@@ -302,10 +339,32 @@ func (c *RetryClient) SetClient(ctx context.Context, cli *BaseClient) { c.muStats.Unlock() task(ctx, cli) + + if c.newRetryByError { + _ = cli.Close() + connected = false + c.newRetryByError = false + } } }() } +func (c *RetryClient) requestContext(ctx context.Context) (context.Context, func()) { + if c.ResponseTimeout == 0 { + return ctx, func() {} + } + ctx2, cancel := context.WithTimeout(ctx, c.ResponseTimeout) + return &requestContext{ctx2}, cancel +} + +type requestContext struct { + context.Context +} + +func (c *requestContext) Err() error { + return &RequestTimeoutError{c.Context.Err()} +} + func (c *RetryClient) pushTask(ctx context.Context, task func(ctx context.Context, cli *BaseClient)) error { c.mu.Lock() defer c.mu.Unlock() @@ -322,6 +381,12 @@ func (c *RetryClient) pushTask(ctx context.Context, task func(ctx context.Contex return nil } +func (c *RetryClient) onError(err error) { + if c.OnError != nil { + c.OnError(err) + } +} + // Connect to the broker. 
func (c *RetryClient) Connect(ctx context.Context, clientID string, opts ...ConnectOption) (sessionPresent bool, err error) { c.mu.Lock() diff --git a/retryclient_integration_test.go b/retryclient_integration_test.go index bf9e827..1cbfd4b 100644 --- a/retryclient_integration_test.go +++ b/retryclient_integration_test.go @@ -20,6 +20,8 @@ package mqtt import ( "context" "crypto/tls" + "io" + "net" "sync/atomic" "testing" "time" @@ -359,3 +361,138 @@ func TestIntegration_RetryClient_RetryInitialRequest(t *testing.T) { }) } } + +func baseCliRemovePacket(ctx context.Context, t *testing.T, packetPass func([]byte) bool) *BaseClient { + t.Helper() + cliBase, err := DialContext(ctx, urls["MQTT"], WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) + if err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + ca, cb := net.Pipe() + connOrig := cliBase.Transport + cliBase.Transport = cb + go func() { + buf := make([]byte, 1024) + for { + n, err := ca.Read(buf) + if err != nil { + return + } + if !packetPass(buf[:n]) { + continue + } + if _, err := connOrig.Write(buf[:n]); err != nil { + return + } + } + }() + go func() { + io.Copy(ca, connOrig) + }() + return cliBase +} + +func TestIntegration_RetryClient_ResponseTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + cli := &RetryClient{ + ResponseTimeout: 50 * time.Millisecond, + } + cli.SetClient(ctx, baseCliRemovePacket(ctx, t, func(b []byte) bool { + return b[0] != byte(packetPublish)|byte(publishFlagQoS1) + })) + + if _, err := cli.Connect(ctx, "RetryClientTimeout"); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + if err := cli.Publish(ctx, &Message{ + Topic: "test/ResponseTimeout", + QoS: QoS1, + Payload: []byte("message"), + }); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + select { + case <-time.After(200 * time.Millisecond): + t.Fatal("Timeout") + case <-cli.Client().Done(): + // Client must be closed due to 
response timeout. + } + expectRetryStats(t, RetryStats{ + QueuedRetries: 1, + TotalTasks: 1, + }, cli.Stats()) + + cli.SetClient(ctx, baseCliRemovePacket(ctx, t, func([]byte) bool { + return true + })) + cli.Retry(ctx) + if _, err := cli.Connect(ctx, "RetryClientTimeout"); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + time.Sleep(150 * time.Millisecond) + expectRetryStats(t, RetryStats{ + TotalRetries: 1, + TotalTasks: 2, + }, cli.Stats()) + + if err := cli.Disconnect(ctx); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } +} + +func TestIntegration_RetryClient_DirectlyPublishQoS0(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cli := &RetryClient{ + DirectlyPublishQoS0: true, + } + + publishQoS0Msg := make(chan struct{}) + cli.SetClient(ctx, baseCliRemovePacket(ctx, t, func(b []byte) bool { + if b[0] == byte(packetPublish) { + close(publishQoS0Msg) + } + return b[0] != byte(packetPublish)|byte(publishFlagQoS1) + })) + + if _, err := cli.Connect(ctx, "RetryClientDirectlyPublishQoS0"); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + if err := cli.Publish(ctx, &Message{ + Topic: "test/DirectlyPublishQoS0", + QoS: QoS1, + Payload: []byte("message"), + }); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + time.Sleep(150 * time.Millisecond) + expectRetryStats(t, RetryStats{ + TotalRetries: 0, + TotalTasks: 1, + }, cli.Stats()) + + if err := cli.Publish(ctx, &Message{ + Topic: "test/DirectlyPublishQoS0", + QoS: QoS0, + Payload: []byte("message"), + }); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + select { + case <-publishQoS0Msg: + case <-time.After(150 * time.Millisecond): + t.Fatal("Timeout") + } + expectRetryStats(t, RetryStats{ + TotalRetries: 0, + TotalTasks: 1, + }, cli.Stats()) +} diff --git a/retryclient_test.go b/retryclient_test.go index f847722..8293e77 100644 --- a/retryclient_test.go +++ b/retryclient_test.go @@ -20,6 +20,8 
@@ import ( "testing" ) +var _ Retryer = &RetryClient{} // RetryClient must implement Retryer. + func TestRetryClientPublish_MessageValidationError(t *testing.T) { cli := &RetryClient{ cli: &BaseClient{