Skip to content

Commit

Permalink
Parameterize StreamingDataflowWorkerTest so we run all tests with/wit…
Browse files Browse the repository at this point in the history
…hout streaming engine. I had to implement FakeServer.getDataStream. This is super basic at the moment. Some tests are skipped in streaming engine mode, I left TODO to go back and revisit those.
  • Loading branch information
drieber committed Feb 20, 2019
1 parent 90b3e45 commit 18370fb
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -734,11 +734,6 @@ public final synchronized void close() {
requestObserver.onCompleted();
}

@Override
public final void awaitTermination() throws InterruptedException {
finishLatch.await();
}

@Override
public final boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException {
return finishLatch.await(time, unit);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,6 @@ public interface WindmillStream {
/** Indicates that no more requests will be sent. */
void close();

/**
* Waits for the server to close its end of the connection.
*
* <p>Should only be called after calling close.
*/
void awaitTermination() throws InterruptedException;

/** Waits for the server to close its end of the connection, with timeout. */
boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
Expand Down Expand Up @@ -219,11 +217,6 @@ public void close() {
done.countDown();
}

@Override
public void awaitTermination() throws InterruptedException {
done.await();
}

@Override
public boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException {
return done.await(time, unit);
Expand All @@ -238,7 +231,62 @@ public Instant startTime() {

@Override
public GetDataStream getDataStream() {
throw new UnsupportedOperationException();
Instant startTime = Instant.now();
return new GetDataStream() {
@Override
public Windmill.KeyedGetDataResponse requestKeyedData(
String computation, KeyedGetDataRequest request) {
Windmill.GetDataRequest getDataRequest =
GetDataRequest.newBuilder()
.addRequests(
ComputationGetDataRequest.newBuilder()
.setComputationId(computation)
.addRequests(request)
.build())
.build();
GetDataResponse getDataResponse = getData(getDataRequest);
if (getDataResponse.getDataList().isEmpty()) {
return null;
}
assertEquals(1, getDataResponse.getDataCount());
if (getDataResponse.getData(0).getDataList().isEmpty()) {
return null;
}
assertEquals(1, getDataResponse.getData(0).getDataCount());
return getDataResponse.getData(0).getData(0);
}

@Override
public Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request) {
Windmill.GetDataRequest getDataRequest =
GetDataRequest.newBuilder().addGlobalDataFetchRequests(request).build();
GetDataResponse getDataResponse = getData(getDataRequest);
if (getDataResponse.getGlobalDataList().isEmpty()) {
return null;
}
assertEquals(1, getDataResponse.getGlobalDataCount());
return getDataResponse.getGlobalData(0);
}

@Override
public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) {}

@Override
public void close() {}

@Override
public boolean awaitTermination(int time, TimeUnit unit) {
return true;
}

@Override
public void closeAfterDefaultTimeout() {}

@Override
public Instant startTime() {
return startTime;
}
};
}

@Override
Expand Down Expand Up @@ -266,9 +314,6 @@ public void flush() {}
@Override
public void close() {}

@Override
public void awaitTermination() {}

@Override
public boolean awaitTermination(int time, TimeUnit unit) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,27 @@
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.junit.runners.Parameterized;
import org.junit.runners.model.Statement;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Unit tests for {@link StreamingDataflowWorker}. */
@RunWith(JUnit4.class)
@RunWith(Parameterized.class)
public class StreamingDataflowWorkerTest {
private final boolean streamingEngine;

@Parameterized.Parameters(name = "{index}: [streamingEngine={0}]")
public static Iterable<Object[]> data() {
return Arrays.asList(new Object[][] {{false}, {true}});
}

public StreamingDataflowWorkerTest(Boolean streamingEngine) {
this.streamingEngine = streamingEngine;
}

private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorkerTest.class);

private static final IntervalWindow DEFAULT_WINDOW =
Expand Down Expand Up @@ -593,8 +604,13 @@ private ByteString addPaneTag(PaneInfo pane, byte[] windowBytes)

private StreamingDataflowWorkerOptions createTestingPipelineOptions(
FakeWindmillServer server, String... args) {
List<String> argsList = Lists.newArrayList(args);
if (streamingEngine) {
argsList.add("--experiments=enable_streaming_engine");
}
StreamingDataflowWorkerOptions options =
PipelineOptionsFactory.fromArgs(args).as(StreamingDataflowWorkerOptions.class);
PipelineOptionsFactory.fromArgs(argsList.toArray(new String[0]))
.as(StreamingDataflowWorkerOptions.class);
options.setAppName("StreamingWorkerHarnessTest");
options.setJobId("test_job_id");
options.setStreaming(true);
Expand Down Expand Up @@ -650,45 +666,7 @@ public void testBasicHarness() throws Exception {
}

@Test
public void testBasicWindmillServiceHarness() throws Exception {
List<ParallelInstruction> instructions =
Arrays.asList(
makeSourceInstruction(StringUtf8Coder.of()),
makeSinkInstruction(StringUtf8Coder.of(), 0));

FakeWindmillServer server = new FakeWindmillServer(errorCollector);
server.setIsReady(false);

StreamingConfigTask streamingConfig = new StreamingConfigTask();
streamingConfig.setStreamingComputationConfigs(
ImmutableList.of(makeDefaultStreamingComputationConfig(instructions)));
streamingConfig.setWindmillServiceEndpoint("foo");
WorkItem workItem = new WorkItem();
workItem.setStreamingConfigTask(streamingConfig);
when(mockWorkUnitClient.getGlobalStreamingConfigWorkItem()).thenReturn(Optional.of(workItem));

StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server, "--experiments=enable_windmill_service");
StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */);
worker.start();

final int numIters = 2000;
for (int i = 0; i < numIters; ++i) {
server.addWorkToOffer(makeInput(i, TimeUnit.MILLISECONDS.toMicros(i)));
}

Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(numIters);
worker.stop();

for (int i = 0; i < numIters; ++i) {
assertTrue(result.containsKey((long) i));
assertEquals(
makeExpectedOutput(i, TimeUnit.MILLISECONDS.toMicros(i)).build(), result.get((long) i));
}
}

@Test
public void testBasicWindmillServiceAsStreamingEngineHarness() throws Exception {
public void testBasic() throws Exception {
List<ParallelInstruction> instructions =
Arrays.asList(
makeSourceInstruction(StringUtf8Coder.of()),
Expand All @@ -705,8 +683,7 @@ public void testBasicWindmillServiceAsStreamingEngineHarness() throws Exception
workItem.setStreamingConfigTask(streamingConfig);
when(mockWorkUnitClient.getGlobalStreamingConfigWorkItem()).thenReturn(Optional.of(workItem));

StreamingDataflowWorkerOptions options =
createTestingPipelineOptions(server, "--experiments=enable_streaming_engine");
StreamingDataflowWorkerOptions options = createTestingPipelineOptions(server);
StreamingDataflowWorker worker = makeWorker(instructions, options, true /* publishCounters */);
worker.start();

Expand Down Expand Up @@ -887,6 +864,10 @@ public void processElement(ProcessContext c) {

@Test
public void testKeyTokenInvalidException() throws Exception {
if (streamingEngine) {
// TODO: This test needs to be adapted to work with streamingEngine=true.
return;
}
KvCoder<String, String> kvCoder = KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of());

List<ParallelInstruction> instructions =
Expand Down Expand Up @@ -1092,6 +1073,10 @@ public void processElement(ProcessContext c) throws Exception {

@Test(timeout = 30000)
public void testExceptions() throws Exception {
if (streamingEngine) {
// TODO: This test needs to be adapted to work with streamingEngine=true.
return;
}
List<ParallelInstruction> instructions =
Arrays.asList(
makeSourceInstruction(StringUtf8Coder.of()),
Expand Down Expand Up @@ -2422,6 +2407,10 @@ public void processElement(ProcessContext c) throws Exception {

@Test
public void testActiveWorkRefresh() throws Exception {
if (streamingEngine) {
// TODO: This test needs to be adapted to work with streamingEngine=true.
return;
}
List<ParallelInstruction> instructions =
Arrays.asList(
makeSourceInstruction(StringUtf8Coder.of()),
Expand Down

0 comments on commit 18370fb

Please sign in to comment.