Skip to content

Commit

Permalink
Release Permits on Conversion Errors (awspring#1090)
Browse files Browse the repository at this point in the history
Fixes awspring#1051

When a conversion error occurred, permits were not being released properly eventually leading to permit depletion.
  • Loading branch information
tomazfernandes committed Mar 15, 2024
1 parent 9a9dd7c commit 4bd05af
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
import io.awspring.cloud.sqs.support.converter.MessageConversionContext;
import io.awspring.cloud.sqs.support.converter.MessagingMessageConverter;
import java.util.Collection;
import java.util.Objects;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;

Expand All @@ -46,6 +49,8 @@
*/
public abstract class AbstractMessageConvertingMessageSource<T, S> implements MessageSource<T> {

private static final Logger logger = LoggerFactory.getLogger(AbstractMessageConvertingMessageSource.class);

private MessagingMessageConverter<S> messagingMessageConverter;

@Nullable
Expand Down Expand Up @@ -82,14 +87,22 @@ private MessageConversionContext maybeCreateConversionContext() {
}

protected Collection<Message<T>> convertMessages(Collection<S> messages) {
return messages.stream().map(this::convertMessage).collect(Collectors.toList());
return messages.stream().map(this::convertMessage).filter(Objects::nonNull).collect(Collectors.toList());
}

@Nullable
@SuppressWarnings("unchecked")
protected Message<T> convertMessage(S msg) {
return this.messagingMessageConverter instanceof ContextAwareMessagingMessageConverter
? (Message<T>) getContextAwareConverter().toMessagingMessage(msg, this.messageConversionContext)
: (Message<T>) this.messagingMessageConverter.toMessagingMessage(msg);
try {
logger.trace("Converting message {}", msg);
return this.messagingMessageConverter instanceof ContextAwareMessagingMessageConverter
? (Message<T>) getContextAwareConverter().toMessagingMessage(msg, this.messageConversionContext)
: (Message<T>) this.messagingMessageConverter.toMessagingMessage(msg);
}
catch (Exception e) {
logger.error("Error converting message {}, ignoring.", msg, e);
return null;
}
}

private ContextAwareMessagingMessageConverter<S> getContextAwareConverter() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ private void pollAndEmitMessages() {
managePollingFuture(doPollForMessages(acquiredPermits))
.thenApply(this::resetBackOffContext)
.exceptionally(this::handlePollingException)
.thenApply(msgs -> releaseUnusedPermits(acquiredPermits, msgs))
.thenApply(this::convertMessages)
.thenApply(msgs -> releaseUnusedPermits(acquiredPermits, msgs))
.thenCompose(this::emitMessagesToPipeline)
.exceptionally(this::handleSinkException);
// @formatter:on
Expand Down Expand Up @@ -251,7 +251,7 @@ private void handlePollBackOff() {

protected abstract CompletableFuture<Collection<S>> doPollForMessages(int messagesToRequest);

public Collection<S> releaseUnusedPermits(int permits, Collection<S> msgs) {
public Collection<Message<T>> releaseUnusedPermits(int permits, Collection<Message<T>> msgs) {
if (msgs.isEmpty() && permits == this.backPressureHandler.getBatchSize()) {
this.backPressureHandler.releaseBatch();
logger.trace("Released batch of unused permits for queue {}", this.pollingEndpointName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import io.awspring.cloud.sqs.listener.SqsContainerOptions;
import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementCallback;
import io.awspring.cloud.sqs.listener.acknowledgement.AcknowledgementProcessor;
import io.awspring.cloud.sqs.support.converter.MessageConversionContext;
import io.awspring.cloud.sqs.support.converter.SqsMessagingMessageConverter;
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
Expand All @@ -42,11 +44,14 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.task.TaskExecutor;
import org.springframework.lang.Nullable;
import org.springframework.retry.backoff.BackOffContext;
import org.springframework.retry.backoff.BackOffPolicy;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
Expand Down Expand Up @@ -127,7 +132,7 @@ else if (hasMadeSecondPoll.compareAndSet(false, true)) {
source.setId(testName + " source");
source.configure(SqsContainerOptions.builder().build());
source.setTaskExecutor(createTaskExecutor(testName));
source.setAcknowledgementProcessor(getAcknowledgementProcessor());
source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor());
source.start();
assertThat(doAwait(pollingCounter)).isTrue();
assertThat(doAwait(processingCounter)).isTrue();
Expand Down Expand Up @@ -225,14 +230,75 @@ else if (hasAcquired9.compareAndSet(false, true)) {
source.setId(testName + " source");
source.configure(SqsContainerOptions.builder().build());
source.setTaskExecutor(createTaskExecutor(testName));
source.setAcknowledgementProcessor(getAcknowledgementProcessor());
source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor());
source.start();
assertThat(doAwait(processingCounter)).isTrue();
assertThat(doAwait(pollingCounter)).isTrue();
source.stop();
assertThat(hasThrownError.get()).isFalse();
}

@Test
void shouldReleasePermitsOnConversionErrors() {
String testName = "shouldReleasePermitsOnConversionErrors";
SemaphoreBackPressureHandler backPressureHandler = SemaphoreBackPressureHandler.builder()
.acquireTimeout(Duration.ofMillis(150)).batchSize(10).totalPermits(10)
.throughputConfiguration(BackPressureMode.AUTO).build();

AtomicInteger convertedMessages = new AtomicInteger(0);
AtomicInteger messagesInSink = new AtomicInteger(0);
AtomicBoolean hasFailed = new AtomicBoolean(false);

var converter = new SqsMessagingMessageConverter() {
@Override
public org.springframework.messaging.Message<?> toMessagingMessage(Message source,
@Nullable MessageConversionContext context) {
var converted = convertedMessages.incrementAndGet();
logger.trace("Messages converted: {}", converted);
if (converted % 9 == 0) {
throw new RuntimeException("Expected error");
}
return super.toMessagingMessage(source, context);
}
};

AbstractPollingMessageSource<Object, Message> source = new AbstractPollingMessageSource<>() {

@Override
protected CompletableFuture<Collection<Message>> doPollForMessages(int messagesToRequest) {
if (messagesToRequest != 10) {
logger.error("Expected 10 messages to requesst, received {}", messagesToRequest);
hasFailed.set(true);
}
return convertedMessages.get() < 30 ? CompletableFuture.completedFuture(create10Messages())
: CompletableFuture.completedFuture(List.of());
}

private Collection<Message> create10Messages() {
return IntStream.range(0, 10).mapToObj(
index -> Message.builder().messageId(UUID.randomUUID().toString()).body("test-message").build())
.toList();
}
};

source.setBackPressureHandler(backPressureHandler);
source.setMessageSink((msgs, context) -> {
msgs.forEach(message -> messagesInSink.incrementAndGet());
msgs.forEach(msg -> context.runBackPressureReleaseCallback());
return CompletableFuture.completedFuture(null);
});
source.setId(testName + " source");
source.configure(SqsContainerOptions.builder().messageConverter(converter).build());
source.setPollingEndpointName("shouldReleasePermitsOnConversionErrors-queue");
source.setTaskExecutor(createTaskExecutor(testName));
source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor());
source.start();
Awaitility.waitAtMost(Duration.ofSeconds(10)).until(() -> convertedMessages.get() == 30);
assertThat(hasFailed).isFalse();
assertThat(messagesInSink).hasValue(27);
source.stop();
}

@Test
void shouldBackOffIfPollingThrowsAnError() {

Expand Down Expand Up @@ -277,7 +343,7 @@ else if (currentPoll.compareAndSet(2, 3)) {
source.configure(SqsContainerOptions.builder().pollBackOffPolicy(policy).build());

source.setTaskExecutor(createTaskExecutor(testName));
source.setAcknowledgementProcessor(getAcknowledgementProcessor());
source.setAcknowledgementProcessor(getNoOpsAcknowledgementProcessor());
source.start();

doAwait(waitThirdPollLatch);
Expand Down Expand Up @@ -315,6 +381,7 @@ private void assertAvailablePermitsLessThanOrEqualTo(SemaphoreBackPressureHandle
.isLessThanOrEqualTo(maxExpectedPermits);
}

// Used to slow down tests while developing
private void doSleep(int time) {
try {
Thread.sleep(time);
Expand Down Expand Up @@ -343,7 +410,7 @@ protected ThreadFactory createThreadFactory(String testName) {
return threadFactory;
}

private AcknowledgementProcessor<Object> getAcknowledgementProcessor() {
private AcknowledgementProcessor<Object> getNoOpsAcknowledgementProcessor() {
return new AcknowledgementProcessor<>() {
@Override
public AcknowledgementCallback<Object> getAcknowledgementCallback() {
Expand Down

0 comments on commit 4bd05af

Please sign in to comment.