diff --git a/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSink.java b/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSink.java index 5966713c55896..0218df5e51eb7 100644 --- a/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSink.java +++ b/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSink.java @@ -137,13 +137,28 @@ public void setLogFailuresOnly(boolean logFailuresOnly) { this.logFailuresOnly = logFailuresOnly; } + /** + * Initializes the connection to RMQ with a default connection factory. The user may override + * this method to setup and configure their own {@link ConnectionFactory}. + */ + protected ConnectionFactory setupConnectionFactory() throws Exception { + return rmqConnectionConfig.getConnectionFactory(); + } + + /** + * Initializes the connection to RMQ using the default connection factory from {@link #setupConnectionFactory()}. + * The user may override this method to setup and configure their own {@link Connection}. + */ + protected Connection setupConnection() throws Exception { + return setupConnectionFactory().newConnection(); + } + @Override public void open(Configuration config) throws Exception { - ConnectionFactory factory = rmqConnectionConfig.getConnectionFactory(); schema.open(() -> getRuntimeContext().getMetricGroup().addGroup("user")); try { - connection = factory.newConnection(); + connection = setupConnection(); channel = connection.createChannel(); if (channel == null) { throw new RuntimeException("None of RabbitMQ channels are available"); diff --git a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSinkTest.java b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSinkTest.java index ea126d01f94b9..72fe1af19f1da 100644 --- a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSinkTest.java +++ b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSinkTest.java @@ -43,6 +43,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -97,6 +98,24 @@ public void openCallDontDeclaresQueueInWithOptionsMode() throws Exception { verify(channel, never()).queueDeclare(null, true, false, false, null); } + @Test + public void testOverrideConnection() throws Exception { + final Connection mockConnection = mock(Connection.class); + Channel channel = mock(Channel.class); + when(mockConnection.createChannel()).thenReturn(channel); + + RMQSink rmqSink = new RMQSink(rmqConnectionConfig, QUEUE_NAME, serializationSchema) { + @Override + protected Connection setupConnection() throws Exception { + return mockConnection; + } + }; + + rmqSink.open(new Configuration()); + + verify(mockConnection, times(1)).createChannel(); + } + @Test public void throwExceptionIfChannelIsNull() throws Exception { when(connection.createChannel()).thenReturn(null);