Merge pull request apache#11364: [BEAM-9651] Prevent StreamPool and stream initialization livelock
reuvenlax committed Apr 9, 2020
2 parents 95a5944 + d36f873 commit 28b081f
Showing 4 changed files with 161 additions and 55 deletions.
@@ -18,6 +18,8 @@
package org.apache.beam.runners.dataflow.worker.windmill;

import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.CallStreamObserver;
import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.StreamObserver;
@@ -43,13 +45,20 @@ public DirectStreamObserver(Phaser phaser, CallStreamObserver<T> outboundObserve

@Override
public void onNext(T value) {
int phase = phaser.getPhase();
if (!outboundObserver.isReady()) {
int phase = phaser.getPhase(); // A negative phase indicates it has been terminated.
// The registered onReady may be blocked, so we periodically poll the observer directly.
// Additionally, to avoid becoming permanently stuck due to synchronization we fall back
// to queuing in the outbound observer after 1 minute; see BEAM-9651 for more context.
for (int waitLoops = 0;
phase >= 0 && !outboundObserver.isReady() && waitLoops < 600;
++waitLoops) {
try {
phaser.awaitAdvanceInterruptibly(phase);
phase = phaser.awaitAdvanceInterruptibly(phase, 100, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
} catch (TimeoutException e) {
// Polling isReady in case the callback is delayed
}
}
synchronized (outboundObserver) {
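To make the readiness change above easier to follow in isolation, here is a minimal, self-contained sketch of the same bounded-wait pattern. The names ReadinessSource and BoundedReadyWaiter are illustrative, not the actual Beam types: instead of blocking indefinitely on the Phaser advance driven by the onReady callback, the sender waits in 100 ms increments, re-checks isReady itself on each timeout, and gives up after roughly one minute so the caller can fall back to queuing.

import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// Illustrative sketch; ReadinessSource stands in for the gRPC CallStreamObserver.
interface ReadinessSource {
  boolean isReady();
}

class BoundedReadyWaiter {
  private final Phaser phaser;          // advanced by the stream's onReady callback
  private final ReadinessSource source;

  BoundedReadyWaiter(Phaser phaser, ReadinessSource source) {
    this.phaser = phaser;
    this.source = source;
  }

  // Waits until the source reports ready, the phaser terminates, or ~1 minute
  // (600 loops of 100 ms) elapses; after that the caller falls back to queuing.
  void awaitReady() {
    int phase = phaser.getPhase(); // A negative phase means the phaser terminated.
    for (int waitLoops = 0; phase >= 0 && !source.isReady() && waitLoops < 600; ++waitLoops) {
      try {
        // Timed wait: even if the onReady callback never runs, we still re-check isReady.
        phase = phaser.awaitAdvanceInterruptibly(phase, 100, TimeUnit.MILLISECONDS);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      } catch (TimeoutException e) {
        // Timed out waiting for onReady; loop around and poll isReady directly.
      }
    }
  }
}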
@@ -60,7 +60,7 @@ public void onCompleted() {
}

@Override
public void beforeStart(ClientCallStreamObserver<RespT> stream) {
stream.setOnReadyHandler(onReadyHandler);
public void beforeStart(ClientCallStreamObserver<RespT> requestStream) {
requestStream.setOnReadyHandler(onReadyHandler);
}
}
@@ -20,20 +20,22 @@
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Supplier;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.net.HostAndPort;
import org.joda.time.Duration;
import org.joda.time.Instant;
@@ -154,65 +156,113 @@ boolean commitWorkItem(
public static class StreamPool<S extends WindmillStream> {

private final Duration streamTimeout;
private final Supplier<S> supplier;

private final class StreamData {
final S stream = supplier.get();
int holds = 1;
final Supplier<S> lazyStream = Suppliers.memoize(supplier);
Instant streamCreated = Instant.now();
AtomicInteger holds = new AtomicInteger(1);
};

private final List<StreamData> streams;
private final Supplier<S> supplier;
private final HashMap<S, StreamData> holds;
private final ConcurrentHashMap<S, StreamData> holds;

public StreamPool(int numStreams, Duration streamTimeout, Supplier<S> supplier) {
this.streamTimeout = streamTimeout;
this.supplier = supplier;
this.streams = new ArrayList<>(numStreams);
for (int i = 0; i < numStreams; i++) {
streams.add(null);
}
this.streamTimeout = streamTimeout;
this.supplier = supplier;
this.holds = new HashMap<>();
this.holds = new ConcurrentHashMap<>();
}

// Returns a stream for use that may be cached from a previous call. Each call of getStream
// must be matched with a call of releaseStream.
public S getStream() {
int index = ThreadLocalRandom.current().nextInt(streams.size());
S result;
S closeStream = null;
Instant timeoutThreshold = Instant.now().minus(streamTimeout);
StreamData streamData = null;
StreamData closeStream = null;
synchronized (this) {
StreamData streamData = streams.get(index);
if (streamData == null
|| streamData.stream.startTime().isBefore(Instant.now().minus(streamTimeout))) {
if (streamData != null && --streamData.holds == 0) {
holds.remove(streamData.stream);
closeStream = streamData.stream;
streamData = streams.get(index);
if (streamData != null) {
if (streamData.streamCreated.isBefore(timeoutThreshold)) {
if (streamData.holds.decrementAndGet() <= 0) {
closeStream = streamData;
}
streamData = null; // Fall through to create a new stream
}
}
if (streamData == null) {
streamData = new StreamData();
streams.set(index, streamData);
holds.put(streamData.stream, streamData);
}
streamData.holds++;
result = streamData.stream;
// The hold is decremented by releaseStream.
streamData.holds.incrementAndGet();
}
// Close the previous stream if it was retired and there were no other holds.
if (closeStream != null) {
closeStream.close();
assert (closeStream.holds.intValue() == 0);
S stream = closeStream.lazyStream.get();
StreamData removed = holds.remove(stream);
assert (removed == closeStream);
stream.close();
}
return result;
// Initialize the stream outside the synchronized section so that slow initialization does
// not block other streams.
S stream = streamData.lazyStream.get();
holds.put(stream, streamData);
return stream;
}

// Releases a stream that was obtained with getStream.
// Releases a stream that was obtained with getStream. If the stream was retired and this was
// the final hold it is closed.
public void releaseStream(S stream) {
boolean closeStream = false;
StreamData streamData = holds.get(stream);
if (streamData.holds.decrementAndGet() <= 0) {
StreamData removed = holds.remove(stream);
assert (removed == streamData);
stream.close();
}
}

// Closes and awaits termination for all streams that do not have an active external hold,
// returning true if all streams were closed.
public boolean closeIdle(int duration, TimeUnit unit) throws InterruptedException {
boolean removedAll = true;
ArrayList<StreamData> streamsCopy = null;
synchronized (this) {
if (--holds.get(stream).holds == 0) {
closeStream = true;
holds.remove(stream);
streamsCopy = new ArrayList<>(streams.size());
for (int i = 0; i < streams.size(); ++i) {
StreamData streamData = streams.get(i);
streams.set(i, null);
streamsCopy.add(streamData);
}
}
if (closeStream) {
stream.close();
for (int i = 0; i < streamsCopy.size(); ++i) {
StreamData streamData = streamsCopy.get(i);
if (streamData == null) {
continue;
}

if (streamData.holds.decrementAndGet() <= 0) {
S stream = streamData.lazyStream.get();
StreamData removed = holds.remove(stream);
assert (removed == streamData);
stream.close();
} else {
removedAll = false;
streamsCopy.set(i, null);
}
}
for (StreamData streamData : streamsCopy) {
if (streamData == null) {
continue;
}
streamData.lazyStream.get().awaitTermination(duration, unit);
}
return removedAll;
}
}

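The pool half of the fix can be reduced to the following sketch, which assumes only a hypothetical Conn type with a close() method and (non-vendored) Guava on the classpath; LazyPool, getConn, and releaseConn are illustrative names. The key moves mirror the diff above: the possibly slow stream construction is wrapped in a memoizing supplier so the pool's synchronized section only picks a slot and bumps an atomic hold count, and the first caller materializes the connection outside the lock. The timeout-based retirement of the real StreamPool is omitted to keep the shape visible.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;

// Illustrative sketch; Conn stands in for WindmillStream. Stream retirement after a
// timeout (which drops the slot's own hold) is left out for brevity.
interface Conn {
  void close();
}

class LazyPool<C extends Conn> {
  private final class Slot {
    // memoize() defers the possibly slow construction to the first get(), which
    // getConn() performs outside the pool's synchronized section.
    final Supplier<C> lazyConn = Suppliers.memoize(factory);
    final AtomicInteger holds = new AtomicInteger(1); // one hold owned by the slot itself
  }

  private final Supplier<C> factory;
  private final List<Slot> slots;
  private final ConcurrentHashMap<C, Slot> holds = new ConcurrentHashMap<>();

  LazyPool(int numSlots, Supplier<C> factory) {
    this.factory = factory;
    this.slots = new ArrayList<>(numSlots);
    for (int i = 0; i < numSlots; i++) {
      slots.add(null);
    }
  }

  // Pick a slot under the lock, but materialize the connection outside it so a
  // single slow initialization cannot block every other caller of the pool.
  C getConn() {
    int index = ThreadLocalRandom.current().nextInt(slots.size());
    Slot slot;
    synchronized (this) {
      slot = slots.get(index);
      if (slot == null) {
        slot = new Slot();
        slots.set(index, slot);
      }
      slot.holds.incrementAndGet(); // matched by releaseConn()
    }
    C conn = slot.lazyConn.get();   // initialization happens here, outside the lock
    holds.put(conn, slot);
    return conn;
  }

  // Drop the caller's hold; the connection closes once nothing holds it, which in the
  // full version happens after the slot's own hold is dropped on retirement.
  void releaseConn(C conn) {
    Slot slot = holds.get(conn);
    if (slot != null && slot.holds.decrementAndGet() <= 0) {
      holds.remove(conn);
      conn.close();
    }
  }
}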
@@ -31,6 +31,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
@@ -59,6 +60,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.CommitWorkStream;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetDataStream;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.GetWorkStream;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.StreamPool;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.Status;
@@ -67,6 +69,7 @@
import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.util.MutableHandlerRegistry;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.After;
import org.junit.Before;
@@ -268,14 +271,11 @@ public void onCompleted() {
assertEquals(workItem.getKey(), ByteString.copyFromUtf8("somewhat_long_key"));
});
assertTrue(latch.await(30, TimeUnit.SECONDS));

stream.close();
assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
}

@Test
@SuppressWarnings("FutureReturnValueIgnored")
public void testStreamingGetData() throws Exception {
private void addGetDataService() {
// This server responds to GetDataRequests with responses that mirror the requests.
serviceRegistry.addService(
new CloudWindmillServiceV1Alpha1ImplBase() {
@@ -406,36 +406,83 @@ private void flushResponse() {
};
}
});
}

@Test
public void testStreamingGetData() throws Exception {
addGetDataService();
GetDataStream stream = client.getDataStream();

// Make requests of varying sizes to test chunking, and verify the responses.
ExecutorService executor = Executors.newFixedThreadPool(50);
final CountDownLatch done = new CountDownLatch(200);
List<Future> futures = new ArrayList<>(200);

for (int i = 0; i < 100; ++i) {
final String key = "key" + i;
final String s = i % 5 == 0 ? largeString(i) : "tag";
executor.submit(
() -> {
errorCollector.checkThat(
stream.requestKeyedData("computation", makeGetDataRequest(key, s)),
Matchers.equalTo(makeGetDataResponse(key, s)));
done.countDown();
});
executor.execute(
() -> {
errorCollector.checkThat(
stream.requestGlobalData(makeGlobalDataRequest(key)),
Matchers.equalTo(makeGlobalDataResponse(key)));
done.countDown();
});
futures.add(
executor.submit(
() -> {
errorCollector.checkThat(
stream.requestKeyedData("computation", makeGetDataRequest(key, s)),
Matchers.equalTo(makeGetDataResponse(key, s)));
}));
futures.add(
executor.submit(
() -> {
errorCollector.checkThat(
stream.requestGlobalData(makeGlobalDataRequest(key)),
Matchers.equalTo(makeGlobalDataResponse(key)));
}));
Thread.sleep((i * 17) % 50);
}
for (Future f : futures) {
f.get();
}
done.await();
stream.close();
assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS));
executor.shutdown();
}

@Test
public void testStreamingGetDataWithPool() throws Exception {
addGetDataService();

final StreamPool<GetDataStream> streamPool =
new StreamPool<GetDataStream>(4, Duration.standardSeconds(1), () -> client.getDataStream());

// Make requests of varying sizes to test chunking, and verify the responses.
ExecutorService executor = Executors.newFixedThreadPool(50);
List<Future> futures = new ArrayList<>(200);
for (int i = 0; i < 100; ++i) {
final String key = "key" + i;
final String s = i % 5 == 0 ? largeString(i) : "tag";
futures.add(
executor.submit(
() -> {
GetDataStream stream = streamPool.getStream();
errorCollector.checkThat(
stream.requestKeyedData("computation", makeGetDataRequest(key, s)),
Matchers.equalTo(makeGetDataResponse(key, s)));
streamPool.releaseStream(stream);
}));
futures.add(
executor.submit(
() -> {
GetDataStream stream = streamPool.getStream();
errorCollector.checkThat(
stream.requestGlobalData(makeGlobalDataRequest(key)),
Matchers.equalTo(makeGlobalDataResponse(key)));
streamPool.releaseStream(stream);
}));
Thread.sleep((i * 17) % 50);
}
for (Future f : futures) {
f.get();
}
assertTrue(streamPool.closeIdle(60, TimeUnit.SECONDS));
executor.shutdown();
}

private String largeString(int length) {
return String.join("", Collections.nCopies(length, "."));
}
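For completeness, a caller-side sketch of the pooled API the new test exercises, written as if it were another helper in the test class above. The helper name getDataViaPool and the request parameter are hypothetical, and the return type assumes requestKeyedData yields the KeyedGetDataResponse the imports suggest; the try/finally is there only to make the getStream/releaseStream pairing explicit.

// Caller-side usage sketch; `client` is the test's existing GrpcWindmillServer field.
private KeyedGetDataResponse getDataViaPool(KeyedGetDataRequest request)
    throws InterruptedException {
  StreamPool<GetDataStream> pool =
      new StreamPool<GetDataStream>(4, Duration.standardSeconds(1), () -> client.getDataStream());
  GetDataStream stream = pool.getStream(); // may lazily create or reuse a pooled stream
  try {
    return stream.requestKeyedData("computation", request);
  } finally {
    pool.releaseStream(stream);            // every getStream must be matched by a release
    pool.closeIdle(60, TimeUnit.SECONDS);  // close streams with no remaining external holds
  }
}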
