Skip to content

Commit

Permalink
[FLINK-8208][network-tests] Reduce mockito usage in RecordWriterTest
Browse files Browse the repository at this point in the history
  • Loading branch information
pnowojski authored and StefanRRichter committed Jan 8, 2018
1 parent 409ea23 commit 97db0bf
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.flink.runtime.io.network.util.TestBufferFactory;
import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider;
import org.apache.flink.runtime.io.network.util.TestTaskEvent;
import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
import org.apache.flink.runtime.testutils.DiscardingRecycler;
import org.apache.flink.types.IntValue;
import org.apache.flink.util.XORShiftRandom;
Expand Down Expand Up @@ -172,12 +173,11 @@ public Void call() throws Exception {

@Test
public void testClearBuffersAfterExceptionInPartitionWriter() throws Exception {
NetworkBufferPool buffers = null;
NetworkBufferPool buffers = new NetworkBufferPool(1, 1024);
BufferPool bufferPool = null;

try {
buffers = new NetworkBufferPool(1, 1024);
bufferPool = spy(buffers.createBufferPool(1, Integer.MAX_VALUE));
bufferPool = buffers.createBufferPool(1, Integer.MAX_VALUE);

ResultPartitionWriter partitionWriter = mock(ResultPartitionWriter.class);
when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferPool));
Expand All @@ -190,12 +190,19 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
Buffer buffer = (Buffer) invocation.getArguments()[0];
buffer.recycle();

throw new RuntimeException("Expected test Exception");
throw new ExpectedTestException();
}
}).when(partitionWriter).writeBuffer(any(Buffer.class), anyInt());

RecordWriter<IntValue> recordWriter = new RecordWriter<>(partitionWriter);

// Validate that memory segment was assigned to recordWriter
assertEquals(1, buffers.getNumberOfAvailableMemorySegments());
assertEquals(0, bufferPool.getNumberOfAvailableMemorySegments());
recordWriter.emit(new IntValue(0));
assertEquals(0, buffers.getNumberOfAvailableMemorySegments());
assertEquals(0, bufferPool.getNumberOfAvailableMemorySegments());

try {
// Verify that emit correctly clears the buffer. The infinite loop looks
// dangerous indeed, but the buffer will only be flushed after its full. Adding a
Expand All @@ -204,7 +211,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
recordWriter.emit(new IntValue(0));
}
}
catch (Exception e) {
catch (ExpectedTestException e) {
// Verify that the buffer is not part of the record writer state after a failure
// to flush it out. If the buffer is still part of the record writer state, this
// will fail, because the buffer has already been recycled. NOTE: The mock
Expand All @@ -214,87 +221,88 @@ public Void answer(InvocationOnMock invocation) throws Throwable {

// Verify expected methods have been called
verify(partitionWriter, times(1)).writeBuffer(any(Buffer.class), anyInt());
verify(bufferPool, times(1)).requestBufferBlocking();
assertEquals(1, bufferPool.getNumberOfAvailableMemorySegments());

try {
// Verify that manual flushing correctly clears the buffer.
recordWriter.emit(new IntValue(0));
assertEquals(0, bufferPool.getNumberOfAvailableMemorySegments());
recordWriter.flush();

Assert.fail("Did not throw expected test Exception");
}
catch (Exception e) {
catch (ExpectedTestException e) {
recordWriter.clearBuffers();
}

// Verify expected methods have been called
verify(partitionWriter, times(2)).writeBuffer(any(Buffer.class), anyInt());
verify(bufferPool, times(2)).requestBufferBlocking();
assertEquals(1, bufferPool.getNumberOfAvailableMemorySegments());

try {
// Verify that broadcast emit correctly clears the buffer.
recordWriter.broadcastEmit(new IntValue(0));
assertEquals(0, bufferPool.getNumberOfAvailableMemorySegments());

for (;;) {
recordWriter.broadcastEmit(new IntValue(0));
}
}
catch (Exception e) {
catch (ExpectedTestException e) {
recordWriter.clearBuffers();
}

// Verify expected methods have been called
verify(partitionWriter, times(3)).writeBuffer(any(Buffer.class), anyInt());
verify(bufferPool, times(3)).requestBufferBlocking();
assertEquals(1, bufferPool.getNumberOfAvailableMemorySegments());

try {
// Verify that end of super step correctly clears the buffer.
recordWriter.emit(new IntValue(0));
assertEquals(0, bufferPool.getNumberOfAvailableMemorySegments());
recordWriter.broadcastEvent(EndOfSuperstepEvent.INSTANCE);

Assert.fail("Did not throw expected test Exception");
}
catch (Exception e) {
catch (ExpectedTestException e) {
recordWriter.clearBuffers();
}

// Verify expected methods have been called
verify(partitionWriter, times(4)).writeBuffer(any(Buffer.class), anyInt());
verify(bufferPool, times(4)).requestBufferBlocking();
assertEquals(1, bufferPool.getNumberOfAvailableMemorySegments());

try {
// Verify that broadcasting and event correctly clears the buffer.
recordWriter.emit(new IntValue(0));
assertEquals(0, bufferPool.getNumberOfAvailableMemorySegments());
recordWriter.broadcastEvent(new TestTaskEvent());

Assert.fail("Did not throw expected test Exception");
}
catch (Exception e) {
catch (ExpectedTestException e) {
recordWriter.clearBuffers();
}

// Verify expected methods have been called
verify(partitionWriter, times(5)).writeBuffer(any(Buffer.class), anyInt());
verify(bufferPool, times(5)).requestBufferBlocking();
assertEquals(1, bufferPool.getNumberOfAvailableMemorySegments());
}
finally {
if (bufferPool != null) {
assertEquals(1, bufferPool.getNumberOfAvailableMemorySegments());
bufferPool.lazyDestroy();
}

if (buffers != null) {
assertEquals(1, buffers.getNumberOfAvailableMemorySegments());
buffers.destroy();
}
assertEquals(1, buffers.getNumberOfAvailableMemorySegments());
buffers.destroy();
}
}

@Test
public void testSerializerClearedAfterClearBuffers() throws Exception {

final Buffer buffer = TestBufferFactory.createBuffer(16);

ResultPartitionWriter partitionWriter = createResultPartitionWriter(
createBufferProvider(buffer));
new TestPooledBufferProvider(1, 16));

RecordWriter<IntValue> recordWriter = new RecordWriter<IntValue>(partitionWriter);

Expand Down Expand Up @@ -324,7 +332,7 @@ public void testBroadcastEventNoRecords() throws Exception {
queues[i] = new ArrayDeque<>();
}

BufferProvider bufferProvider = createBufferProvider(bufferSize);
TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(Integer.MAX_VALUE, bufferSize);

ResultPartitionWriter partitionWriter = createCollectingPartitionWriter(queues, bufferProvider);
RecordWriter<ByteArrayIO> writer = new RecordWriter<>(partitionWriter, new RoundRobin<ByteArrayIO>());
Expand All @@ -333,7 +341,7 @@ public void testBroadcastEventNoRecords() throws Exception {
// No records emitted yet, broadcast should not request a buffer
writer.broadcastEvent(barrier);

verify(bufferProvider, times(0)).requestBufferBlocking();
assertEquals(0, bufferProvider.getNumberOfCreatedBuffers());

for (Queue<BufferOrEvent> queue : queues) {
assertEquals(1, queue.size());
Expand All @@ -360,7 +368,7 @@ public void testBroadcastEventMixedRecords() throws Exception {
queues[i] = new ArrayDeque<>();
}

BufferProvider bufferProvider = createBufferProvider(bufferSize);
TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(Integer.MAX_VALUE, bufferSize);

ResultPartitionWriter partitionWriter = createCollectingPartitionWriter(queues, bufferProvider);
RecordWriter<ByteArrayIO> writer = new RecordWriter<>(partitionWriter, new RoundRobin<ByteArrayIO>());
Expand Down Expand Up @@ -393,7 +401,7 @@ public void testBroadcastEventMixedRecords() throws Exception {
// (v) Broadcast the event
writer.broadcastEvent(barrier);

verify(bufferProvider, times(4)).requestBufferBlocking();
assertEquals(4, bufferProvider.getNumberOfCreatedBuffers());

assertEquals(2, queues[0].size()); // 1 buffer + 1 event
assertEquals(3, queues[1].size()); // 2 buffers + 1 event
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@ public class TestPooledBufferProvider implements BufferProvider {
private final PooledBufferProviderRecycler bufferRecycler;

public TestPooledBufferProvider(int poolSize) {
this(poolSize, 32 * 1024);
}

public TestPooledBufferProvider(int poolSize, int bufferSize) {
checkArgument(poolSize > 0);

this.bufferRecycler = new PooledBufferProviderRecycler(buffers);
this.bufferFactory = new TestBufferFactory(poolSize, 32 * 1024, bufferRecycler);
this.bufferFactory = new TestBufferFactory(poolSize, bufferSize, bufferRecycler);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,14 @@
package org.apache.flink.streaming.runtime.io;

import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.api.writer.RoundRobinChannelSelector;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferProvider;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider;
import org.apache.flink.types.LongValue;

import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.io.IOException;

Expand Down Expand Up @@ -86,16 +82,7 @@ public void testPropagateAsyncFlushError() {
}

private static ResultPartitionWriter getMockWriter(int numPartitions) throws Exception {
BufferProvider mockProvider = mock(BufferProvider.class);
when(mockProvider.requestBufferBlocking()).thenAnswer(new Answer<Buffer>() {
@Override
public Buffer answer(InvocationOnMock invocation) {
return new Buffer(
MemorySegmentFactory.allocateUnpooledSegment(4096),
FreeingBufferRecycler.INSTANCE);
}
});

BufferProvider mockProvider = new TestPooledBufferProvider(Integer.MAX_VALUE, 4096);
ResultPartitionWriter mockWriter = mock(ResultPartitionWriter.class);
when(mockWriter.getBufferProvider()).thenReturn(mockProvider);
when(mockWriter.getNumberOfSubpartitions()).thenReturn(numPartitions);
Expand Down

0 comments on commit 97db0bf

Please sign in to comment.