Skip to content

Commit

Permalink
Merge pull request apache#10124: [BEAM-8670] Manage environment parallelism in DefaultJobBundleFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
tweise committed Nov 16, 2019
2 parents 7beb9ee + 45dc280 commit 1386b94
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 194 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.apache.beam.runners.core.construction.PipelineOptionsTranslation;
import org.apache.beam.runners.fnexecution.control.DefaultExecutableStageContext.MultiInstanceFactory;
import org.apache.beam.runners.fnexecution.control.DefaultExecutableStageContext;
import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
import org.apache.beam.runners.fnexecution.control.ReferenceCountingExecutableStageContextFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.sdk.options.PortablePipelineOptions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
import org.apache.flink.api.java.ExecutionEnvironment;

/** Singleton class that contains one {@link MultiInstanceFactory} per job. */
/** Singleton class that contains one {@link ExecutableStageContext.Factory} per job. */
public class FlinkExecutableStageContextFactory implements ExecutableStageContext.Factory {

private static final FlinkExecutableStageContextFactory instance =
Expand All @@ -36,7 +34,7 @@ public class FlinkExecutableStageContextFactory implements ExecutableStageContex
// classloader and therefore its own instance of FlinkExecutableStageContextFactory. This
// code supports multiple JobInfos in order to provide a sensible implementation of
// Factory.get(JobInfo), which in theory could be called with different JobInfos.
private static final ConcurrentMap<String, MultiInstanceFactory> jobFactories =
private static final ConcurrentMap<String, ExecutableStageContext.Factory> jobFactories =
new ConcurrentHashMap<>();

// Private constructor: this class is a singleton, obtained only via getInstance().
private FlinkExecutableStageContextFactory() {}
Expand All @@ -47,17 +45,12 @@ public static FlinkExecutableStageContextFactory getInstance() {

@Override
public ExecutableStageContext get(JobInfo jobInfo) {
MultiInstanceFactory jobFactory =
ExecutableStageContext.Factory jobFactory =
jobFactories.computeIfAbsent(
jobInfo.jobId(),
k -> {
PortablePipelineOptions portableOptions =
PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions())
.as(PortablePipelineOptions.class);

return new MultiInstanceFactory(
MoreObjects.firstNonNull(portableOptions.getSdkWorkerParallelism(), 1L)
.intValue(),
return ReferenceCountingExecutableStageContextFactory.create(
DefaultExecutableStageContext::create,
// Clean up context immediately if its class is not loaded on Flink parent
// classloader.
(caller) ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,14 @@
*/
package org.apache.beam.runners.fnexecution.control;

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;

/** Implementation of a {@link ExecutableStageContext}. */
public class DefaultExecutableStageContext implements ExecutableStageContext, AutoCloseable {
private final JobBundleFactory jobBundleFactory;

private static DefaultExecutableStageContext create(JobInfo jobInfo) {
/**
 * Creates a context backed by a fresh {@link DefaultJobBundleFactory} for the given job.
 *
 * @param jobInfo provisioning information for the job the context serves
 * @return a new {@link DefaultExecutableStageContext}
 */
public static DefaultExecutableStageContext create(JobInfo jobInfo) {
  return new DefaultExecutableStageContext(DefaultJobBundleFactory.create(jobInfo));
}
Expand All @@ -46,54 +42,4 @@ public StageBundleFactory getStageBundleFactory(ExecutableStage executableStage)
/**
 * Closes the wrapped {@link JobBundleFactory}, releasing the resources it holds.
 *
 * @throws Exception if the underlying factory fails to close
 */
public void close() throws Exception {
jobBundleFactory.close();
}

/**
 * {@link ExecutableStageContext.Factory} that lazily grows a bounded pool of child {@link
 * ExecutableStageContext.Factory} instances and hands them out in round-robin order.
 */
public static class MultiInstanceFactory implements ExecutableStageContext.Factory {

  // Round-robin cursor into the pool below.
  private int index = 0;
  // Lazily grown pool of child factories; never exceeds maxFactories entries.
  private final List<ReferenceCountingExecutableStageContextFactory> factories =
      new ArrayList<>();
  private final int maxFactories;
  private final SerializableFunction<Object, Boolean> isReleaseSynchronous;

  /**
   * @param maxFactories pool size; {@code 0} selects an automatic size (see below), negative is
   *     rejected
   * @param isReleaseSynchronous passed through to each child factory to decide whether a release
   *     must happen synchronously
   */
  public MultiInstanceFactory(
      int maxFactories, SerializableFunction<Object, Boolean> isReleaseSynchronous) {
    this.isReleaseSynchronous = isReleaseSynchronous;
    Preconditions.checkArgument(maxFactories >= 0, "sdk_worker_parallelism must be >= 0");
    // 0 means "auto": num_cores - 1, leaving some resources available for the Java process.
    this.maxFactories =
        (maxFactories == 0)
            ? Math.max(Runtime.getRuntime().availableProcessors() - 1, 1)
            : maxFactories;
  }

  /** Returns the next child factory, creating a new one while the pool is below capacity. */
  private synchronized ExecutableStageContext.Factory getFactory() {
    final ReferenceCountingExecutableStageContextFactory selected;
    if (factories.size() >= maxFactories) {
      // Pool is full: reuse the factory under the cursor.
      selected = factories.get(index);
    } else {
      // Still growing: create and register a fresh child factory.
      selected =
          ReferenceCountingExecutableStageContextFactory.create(
              DefaultExecutableStageContext::create, isReleaseSynchronous);
      factories.add(selected);
    }
    index = (index + 1) % maxFactories;
    return selected;
  }

  @Override
  public ExecutableStageContext get(JobInfo jobInfo) {
    return getFactory().get(jobInfo);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@
import org.apache.beam.sdk.options.PortablePipelineOptions;
import org.apache.beam.sdk.options.PortablePipelineOptions.RetrievalServiceType;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheLoader;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LoadingCache;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalNotification;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.slf4j.Logger;
Expand All @@ -85,7 +87,8 @@ public class DefaultJobBundleFactory implements JobBundleFactory {
private static final IdGenerator factoryIdGenerator = IdGenerators.incrementingLongs();

private final String factoryId = factoryIdGenerator.getId();
private final LoadingCache<Environment, WrappedSdkHarnessClient> environmentCache;
private final ImmutableList<LoadingCache<Environment, WrappedSdkHarnessClient>> environmentCaches;
private final AtomicInteger stageBundleCount = new AtomicInteger();
private final Map<String, EnvironmentFactory.Provider> environmentFactoryProviderMap;
private final ExecutorService executor;
private final MapControlClientPool clientPool;
Expand Down Expand Up @@ -121,8 +124,10 @@ public static DefaultJobBundleFactory create(
this.clientPool = MapControlClientPool.create();
this.stageIdGenerator = () -> factoryId + "-" + stageIdSuffixGenerator.getId();
this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo);
this.environmentCache =
createEnvironmentCache(serverFactory -> createServerInfo(jobInfo, serverFactory));
this.environmentCaches =
createEnvironmentCaches(
serverFactory -> createServerInfo(jobInfo, serverFactory),
getMaxEnvironmentClients(jobInfo));
}

@VisibleForTesting
Expand All @@ -136,17 +141,12 @@ public static DefaultJobBundleFactory create(
this.clientPool = MapControlClientPool.create();
this.stageIdGenerator = stageIdGenerator;
this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo);
this.environmentCache = createEnvironmentCache(serverFactory -> serverInfo);
this.environmentCaches =
createEnvironmentCaches(serverFactory -> serverInfo, getMaxEnvironmentClients(jobInfo));
}

// Reads environment_expiration_millis from the portable pipeline options carried in JobInfo.
private static int getEnvironmentExpirationMillis(JobInfo jobInfo) {
PipelineOptions pipelineOptions =
PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions());
return pipelineOptions.as(PortablePipelineOptions.class).getEnvironmentExpirationMillis();
}

private LoadingCache<Environment, WrappedSdkHarnessClient> createEnvironmentCache(
ThrowingFunction<ServerFactory, ServerInfo> serverInfoCreator) {
private ImmutableList<LoadingCache<Environment, WrappedSdkHarnessClient>> createEnvironmentCaches(
ThrowingFunction<ServerFactory, ServerInfo> serverInfoCreator, int count) {
CacheBuilder builder =
CacheBuilder.newBuilder()
.removalListener(
Expand All @@ -161,26 +161,55 @@ private LoadingCache<Environment, WrappedSdkHarnessClient> createEnvironmentCach
if (environmentExpirationMillis > 0) {
builder = builder.expireAfterWrite(environmentExpirationMillis, TimeUnit.MILLISECONDS);
}
return builder.build(
new CacheLoader<Environment, WrappedSdkHarnessClient>() {
@Override
public WrappedSdkHarnessClient load(Environment environment) throws Exception {
EnvironmentFactory.Provider environmentFactoryProvider =
environmentFactoryProviderMap.get(environment.getUrn());
ServerFactory serverFactory = environmentFactoryProvider.getServerFactory();
ServerInfo serverInfo = serverInfoCreator.apply(serverFactory);
EnvironmentFactory environmentFactory =
environmentFactoryProvider.createEnvironmentFactory(
serverInfo.getControlServer(),
serverInfo.getLoggingServer(),
serverInfo.getRetrievalServer(),
serverInfo.getProvisioningServer(),
clientPool,
stageIdGenerator);
return WrappedSdkHarnessClient.wrapping(
environmentFactory.createEnvironment(environment), serverInfo);
}
});

ImmutableList.Builder<LoadingCache<Environment, WrappedSdkHarnessClient>> caches =
ImmutableList.builder();
for (int i = 0; i < count; i++) {
LoadingCache<Environment, WrappedSdkHarnessClient> cache =
builder.build(
new CacheLoader<Environment, WrappedSdkHarnessClient>() {
@Override
public WrappedSdkHarnessClient load(Environment environment) throws Exception {
EnvironmentFactory.Provider environmentFactoryProvider =
environmentFactoryProviderMap.get(environment.getUrn());
ServerFactory serverFactory = environmentFactoryProvider.getServerFactory();
ServerInfo serverInfo = serverInfoCreator.apply(serverFactory);
EnvironmentFactory environmentFactory =
environmentFactoryProvider.createEnvironmentFactory(
serverInfo.getControlServer(),
serverInfo.getLoggingServer(),
serverInfo.getRetrievalServer(),
serverInfo.getProvisioningServer(),
clientPool,
stageIdGenerator);
return WrappedSdkHarnessClient.wrapping(
environmentFactory.createEnvironment(environment), serverInfo);
}
});
caches.add(cache);
}
return caches.build();
}

/** Reads environment_expiration_millis from the portable pipeline options carried in JobInfo. */
private static int getEnvironmentExpirationMillis(JobInfo jobInfo) {
  return PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions())
      .as(PortablePipelineOptions.class)
      .getEnvironmentExpirationMillis();
}

/**
 * Determines how many SDK harness environments to run concurrently, from the
 * sdk_worker_parallelism pipeline option carried in {@code jobInfo}.
 *
 * <p>An unset option defaults to 1; a value of 0 selects automatic sizing.
 */
private static int getMaxEnvironmentClients(JobInfo jobInfo) {
  PortablePipelineOptions options =
      PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions())
          .as(PortablePipelineOptions.class);
  Long configured = options.getSdkWorkerParallelism();
  int maxEnvironments = (configured == null) ? 1 : configured.intValue();
  Preconditions.checkArgument(maxEnvironments >= 0, "sdk_worker_parallelism must be >= 0");
  if (maxEnvironments > 0) {
    return maxEnvironments;
  }
  // 0 means "auto": num_cores - 1, leaving some resources available for the Java process.
  return Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
}

@Override
Expand All @@ -192,9 +221,10 @@ public StageBundleFactory forStage(ExecutableStage executableStage) {
public void close() throws Exception {
// Clear the cache. This closes all active environments.
// note this may cause open calls to be cancelled by the peer
environmentCache.invalidateAll();
environmentCache.cleanUp();

for (LoadingCache<Environment, WrappedSdkHarnessClient> environmentCache : environmentCaches) {
environmentCache.invalidateAll();
environmentCache.cleanUp();
}
executor.shutdown();
}

Expand All @@ -205,13 +235,16 @@ public void close() throws Exception {
private class SimpleStageBundleFactory implements StageBundleFactory {

private final ExecutableStage executableStage;
private final int environmentIndex;
private BundleProcessor processor;
private ExecutableProcessBundleDescriptor processBundleDescriptor;
private WrappedSdkHarnessClient wrappedClient;

private SimpleStageBundleFactory(ExecutableStage executableStage) {
this.executableStage = executableStage;
prepare(environmentCache.getUnchecked(executableStage.getEnvironment()));
this.environmentIndex = stageBundleCount.getAndIncrement() % environmentCaches.size();
prepare(
environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment()));
}

private void prepare(WrappedSdkHarnessClient wrappedClient) {
Expand Down Expand Up @@ -266,7 +299,7 @@ public RemoteBundle getBundle(
}

final WrappedSdkHarnessClient client =
environmentCache.getUnchecked(executableStage.getEnvironment());
environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment());
client.ref();

if (client != wrappedClient) {
Expand Down

This file was deleted.

Loading

0 comments on commit 1386b94

Please sign in to comment.