Skip to content

Commit

Permalink
Merge pull request #10313: [BEAM-8816] Option to load balance bundle …
Browse files Browse the repository at this point in the history
…processing w/ multiple SDK workers
  • Loading branch information
tweise committed Dec 16, 2019
2 parents b816ae3 + c8595ec commit e486202
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.concurrent.ThreadSafe;
Expand Down Expand Up @@ -58,6 +61,7 @@
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
import org.apache.beam.sdk.function.ThrowingFunction;
import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PortablePipelineOptions;
import org.apache.beam.sdk.options.PortablePipelineOptions.RetrievalServiceType;
Expand Down Expand Up @@ -88,12 +92,16 @@ public class DefaultJobBundleFactory implements JobBundleFactory {

private final String factoryId = factoryIdGenerator.getId();
private final ImmutableList<LoadingCache<Environment, WrappedSdkHarnessClient>> environmentCaches;
private final AtomicInteger stageBundleCount = new AtomicInteger();
private final AtomicInteger stageBundleFactoryCount = new AtomicInteger();
private final Map<String, EnvironmentFactory.Provider> environmentFactoryProviderMap;
private final ExecutorService executor;
private final MapControlClientPool clientPool;
private final IdGenerator stageIdGenerator;
private final int environmentExpirationMillis;
private final Semaphore availableCachesSemaphore;
private final LinkedBlockingDeque<LoadingCache<Environment, WrappedSdkHarnessClient>>
availableCaches;
private final boolean loadBalanceBundles;

public static DefaultJobBundleFactory create(JobInfo jobInfo) {
PipelineOptions pipelineOptions =
Expand Down Expand Up @@ -124,10 +132,13 @@ public static DefaultJobBundleFactory create(
this.clientPool = MapControlClientPool.create();
this.stageIdGenerator = () -> factoryId + "-" + stageIdSuffixGenerator.getId();
this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo);
this.loadBalanceBundles = shouldLoadBalanceBundles(jobInfo);
this.environmentCaches =
createEnvironmentCaches(
serverFactory -> createServerInfo(jobInfo, serverFactory),
getMaxEnvironmentClients(jobInfo));
this.availableCachesSemaphore = new Semaphore(environmentCaches.size(), true);
this.availableCaches = new LinkedBlockingDeque<>(environmentCaches);
}

@VisibleForTesting
Expand All @@ -141,8 +152,11 @@ public static DefaultJobBundleFactory create(
this.clientPool = MapControlClientPool.create();
this.stageIdGenerator = stageIdGenerator;
this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo);
this.loadBalanceBundles = shouldLoadBalanceBundles(jobInfo);
this.environmentCaches =
createEnvironmentCaches(serverFactory -> serverInfo, getMaxEnvironmentClients(jobInfo));
this.availableCachesSemaphore = new Semaphore(environmentCaches.size(), true);
this.availableCaches = new LinkedBlockingDeque<>(environmentCaches);
}

private ImmutableList<LoadingCache<Environment, WrappedSdkHarnessClient>> createEnvironmentCaches(
Expand Down Expand Up @@ -211,6 +225,26 @@ private static int getMaxEnvironmentClients(JobInfo jobInfo) {
return maxEnvironments;
}

private static boolean shouldLoadBalanceBundles(JobInfo jobInfo) {
PipelineOptions pipelineOptions =
PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions());
boolean loadBalanceBundles =
pipelineOptions.as(PortablePipelineOptions.class).getLoadBalanceBundles();
if (loadBalanceBundles) {
int stateCacheSize =
Integer.parseInt(
MoreObjects.firstNonNull(
ExperimentalOptions.getExperimentValue(
pipelineOptions, ExperimentalOptions.STATE_CACHE_SIZE),
"0"));
Preconditions.checkArgument(
stateCacheSize == 0,
"%s must be 0 when using bundle load balancing",
ExperimentalOptions.STATE_CACHE_SIZE);
}
return loadBalanceBundles;
}

@Override
public StageBundleFactory forStage(ExecutableStage executableStage) {
return new SimpleStageBundleFactory(executableStage);
Expand All @@ -227,6 +261,58 @@ public void close() throws Exception {
executor.shutdown();
}

private static ImmutableMap.Builder<String, RemoteOutputReceiver<?>> getOutputReceivers(
ExecutableProcessBundleDescriptor processBundleDescriptor,
OutputReceiverFactory outputReceiverFactory) {
ImmutableMap.Builder<String, RemoteOutputReceiver<?>> outputReceivers = ImmutableMap.builder();
for (Map.Entry<String, Coder> remoteOutputCoder :
processBundleDescriptor.getRemoteOutputCoders().entrySet()) {
String outputTransform = remoteOutputCoder.getKey();
Coder coder = remoteOutputCoder.getValue();
String bundleOutputPCollection =
Iterables.getOnlyElement(
processBundleDescriptor
.getProcessBundleDescriptor()
.getTransformsOrThrow(outputTransform)
.getInputsMap()
.values());
FnDataReceiver outputReceiver = outputReceiverFactory.create(bundleOutputPCollection);
outputReceivers.put(outputTransform, RemoteOutputReceiver.of(coder, outputReceiver));
}
return outputReceivers;
}

private static class PreparedClient {
private BundleProcessor processor;
private ExecutableProcessBundleDescriptor processBundleDescriptor;
private WrappedSdkHarnessClient wrappedClient;
}

private PreparedClient prepare(
WrappedSdkHarnessClient wrappedClient, ExecutableStage executableStage) {
PreparedClient preparedClient = new PreparedClient();
try {
preparedClient.wrappedClient = wrappedClient;
preparedClient.processBundleDescriptor =
ProcessBundleDescriptors.fromExecutableStage(
stageIdGenerator.getId(),
executableStage,
wrappedClient.getServerInfo().getDataServer().getApiServiceDescriptor(),
wrappedClient.getServerInfo().getStateServer().getApiServiceDescriptor());
} catch (IOException e) {
throw new RuntimeException("Failed to create ProcessBundleDescriptor.", e);
}

preparedClient.processor =
wrappedClient
.getClient()
.getProcessor(
preparedClient.processBundleDescriptor.getProcessBundleDescriptor(),
preparedClient.processBundleDescriptor.getRemoteInputDestinations(),
wrappedClient.getServerInfo().getStateServer().getService());
return preparedClient;
}

/**
* A {@link StageBundleFactory} for remotely processing bundles that supports environment
* expiration.
Expand All @@ -235,37 +321,16 @@ private class SimpleStageBundleFactory implements StageBundleFactory {

private final ExecutableStage executableStage;
private final int environmentIndex;
private BundleProcessor processor;
private ExecutableProcessBundleDescriptor processBundleDescriptor;
private WrappedSdkHarnessClient wrappedClient;
private final HashMap<WrappedSdkHarnessClient, PreparedClient> preparedClients = new HashMap();
private PreparedClient currentClient;

private SimpleStageBundleFactory(ExecutableStage executableStage) {
this.executableStage = executableStage;
this.environmentIndex = stageBundleCount.getAndIncrement() % environmentCaches.size();
prepare(
environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment()));
}

private void prepare(WrappedSdkHarnessClient wrappedClient) {
try {
this.wrappedClient = wrappedClient;
this.processBundleDescriptor =
ProcessBundleDescriptors.fromExecutableStage(
stageIdGenerator.getId(),
executableStage,
wrappedClient.getServerInfo().getDataServer().getApiServiceDescriptor(),
wrappedClient.getServerInfo().getStateServer().getApiServiceDescriptor());
} catch (IOException e) {
throw new RuntimeException("Failed to create ProcessBundleDescriptor.", e);
}

this.processor =
wrappedClient
.getClient()
.getProcessor(
processBundleDescriptor.getProcessBundleDescriptor(),
processBundleDescriptor.getRemoteInputDestinations(),
wrappedClient.getServerInfo().getStateServer().getService());
this.environmentIndex = stageBundleFactoryCount.getAndIncrement() % environmentCaches.size();
WrappedSdkHarnessClient client =
environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment());
this.currentClient = prepare(client, executableStage);
this.preparedClients.put(client, currentClient);
}

@Override
Expand All @@ -276,38 +341,53 @@ public RemoteBundle getBundle(
throws Exception {
// TODO: Consider having BundleProcessor#newBundle take in an OutputReceiverFactory rather
// than constructing the receiver map here. Every bundle factory will need this.
ImmutableMap.Builder<String, RemoteOutputReceiver<?>> outputReceivers =
ImmutableMap.builder();
for (Map.Entry<String, Coder> remoteOutputCoder :
processBundleDescriptor.getRemoteOutputCoders().entrySet()) {
String outputTransform = remoteOutputCoder.getKey();
Coder coder = remoteOutputCoder.getValue();
String bundleOutputPCollection =
Iterables.getOnlyElement(
processBundleDescriptor
.getProcessBundleDescriptor()
.getTransformsOrThrow(outputTransform)
.getInputsMap()
.values());
FnDataReceiver outputReceiver = outputReceiverFactory.create(bundleOutputPCollection);
outputReceivers.put(outputTransform, RemoteOutputReceiver.of(coder, outputReceiver));
}

if (environmentExpirationMillis == 0) {
return processor.newBundle(outputReceivers.build(), stateRequestHandler, progressHandler);
if (environmentExpirationMillis == 0 && !loadBalanceBundles) {
return currentClient.processor.newBundle(
getOutputReceivers(currentClient.processBundleDescriptor, outputReceiverFactory)
.build(),
stateRequestHandler,
progressHandler);
}

final WrappedSdkHarnessClient client =
environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment());
client.ref();
final LoadingCache<Environment, WrappedSdkHarnessClient> currentCache;
if (loadBalanceBundles) {
// The semaphore is used to ensure fairness, i.e. first stop first go.
availableCachesSemaphore.acquire();
// The blocking queue of caches for serving multiple bundles concurrently.
currentCache = availableCaches.take();
WrappedSdkHarnessClient client =
currentCache.getUnchecked(executableStage.getEnvironment());
client.ref();

currentClient = preparedClients.get(client);
if (currentClient == null) {
// we are using this client for the first time
preparedClients.put(client, currentClient = prepare(client, executableStage));
// cleanup any expired clients
preparedClients.keySet().removeIf(c -> c.bundleRefCount.get() == 0);
}

if (client != wrappedClient) {
// reset after environment expired
prepare(client);
} else {
currentCache = environmentCaches.get(environmentIndex);
WrappedSdkHarnessClient client =
currentCache.getUnchecked(executableStage.getEnvironment());
client.ref();

if (currentClient.wrappedClient != client) {
// reset after environment expired
preparedClients.clear();
currentClient = prepare(client, executableStage);
preparedClients.put(client, currentClient);
}
}

final RemoteBundle bundle =
processor.newBundle(outputReceivers.build(), stateRequestHandler, progressHandler);
currentClient.processor.newBundle(
getOutputReceivers(currentClient.processBundleDescriptor, outputReceiverFactory)
.build(),
stateRequestHandler,
progressHandler);
return new RemoteBundle() {
@Override
public String getId() {
Expand All @@ -322,20 +402,24 @@ public Map<String, FnDataReceiver> getInputReceivers() {
@Override
public void close() throws Exception {
bundle.close();
client.unref();
currentClient.wrappedClient.unref();
if (loadBalanceBundles) {
availableCaches.offer(currentCache);
availableCachesSemaphore.release();
}
}
};
}

@Override
public ExecutableProcessBundleDescriptor getProcessBundleDescriptor() {
return processBundleDescriptor;
return currentClient.processBundleDescriptor;
}

@Override
public void close() throws Exception {
// Clear reference to encourage cache eviction. Values are weakly referenced.
wrappedClient = null;
preparedClients.clear();
}
}

Expand Down
Loading

0 comments on commit e486202

Please sign in to comment.