diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java index a3f6f01fb9299..e5e4c391334dc 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java @@ -19,9 +19,12 @@ import com.google.auto.value.AutoValue; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.concurrent.ThreadSafe; @@ -58,6 +61,7 @@ import org.apache.beam.sdk.fn.data.FnDataReceiver; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; import org.apache.beam.sdk.function.ThrowingFunction; +import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PortablePipelineOptions; import org.apache.beam.sdk.options.PortablePipelineOptions.RetrievalServiceType; @@ -88,12 +92,16 @@ public class DefaultJobBundleFactory implements JobBundleFactory { private final String factoryId = factoryIdGenerator.getId(); private final ImmutableList> environmentCaches; - private final AtomicInteger stageBundleCount = new AtomicInteger(); + private final AtomicInteger stageBundleFactoryCount = new AtomicInteger(); private final Map environmentFactoryProviderMap; private final ExecutorService executor; private final MapControlClientPool clientPool; private final IdGenerator stageIdGenerator; private final int environmentExpirationMillis; + private final Semaphore availableCachesSemaphore; + private final LinkedBlockingDeque> + availableCaches; + private final boolean loadBalanceBundles; public static DefaultJobBundleFactory create(JobInfo jobInfo) { PipelineOptions pipelineOptions = @@ -124,10 +132,13 @@ public static DefaultJobBundleFactory create( this.clientPool = MapControlClientPool.create(); this.stageIdGenerator = () -> factoryId + "-" + stageIdSuffixGenerator.getId(); this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo); + this.loadBalanceBundles = shouldLoadBalanceBundles(jobInfo); this.environmentCaches = createEnvironmentCaches( serverFactory -> createServerInfo(jobInfo, serverFactory), getMaxEnvironmentClients(jobInfo)); + this.availableCachesSemaphore = new Semaphore(environmentCaches.size(), true); + this.availableCaches = new LinkedBlockingDeque<>(environmentCaches); } @VisibleForTesting @@ -141,8 +152,11 @@ public static DefaultJobBundleFactory create( this.clientPool = MapControlClientPool.create(); this.stageIdGenerator = stageIdGenerator; this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo); + this.loadBalanceBundles = shouldLoadBalanceBundles(jobInfo); this.environmentCaches = createEnvironmentCaches(serverFactory -> serverInfo, getMaxEnvironmentClients(jobInfo)); + this.availableCachesSemaphore = new Semaphore(environmentCaches.size(), true); + this.availableCaches = new LinkedBlockingDeque<>(environmentCaches); } private ImmutableList> createEnvironmentCaches( @@ -211,6 +225,26 @@ private static int getMaxEnvironmentClients(JobInfo jobInfo) { return maxEnvironments; } + private static boolean shouldLoadBalanceBundles(JobInfo jobInfo) { + PipelineOptions pipelineOptions = + PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions()); + boolean loadBalanceBundles = + pipelineOptions.as(PortablePipelineOptions.class).getLoadBalanceBundles(); + if (loadBalanceBundles) { + int stateCacheSize = + Integer.parseInt( + MoreObjects.firstNonNull( + ExperimentalOptions.getExperimentValue( + pipelineOptions, ExperimentalOptions.STATE_CACHE_SIZE), + "0")); + Preconditions.checkArgument( + stateCacheSize == 0, + "%s must be 0 when using bundle load balancing", + ExperimentalOptions.STATE_CACHE_SIZE); + } + return loadBalanceBundles; + } + @Override public StageBundleFactory forStage(ExecutableStage executableStage) { return new SimpleStageBundleFactory(executableStage); @@ -227,6 +261,58 @@ public void close() throws Exception { executor.shutdown(); } + private static ImmutableMap.Builder> getOutputReceivers( + ExecutableProcessBundleDescriptor processBundleDescriptor, + OutputReceiverFactory outputReceiverFactory) { + ImmutableMap.Builder> outputReceivers = ImmutableMap.builder(); + for (Map.Entry remoteOutputCoder : + processBundleDescriptor.getRemoteOutputCoders().entrySet()) { + String outputTransform = remoteOutputCoder.getKey(); + Coder coder = remoteOutputCoder.getValue(); + String bundleOutputPCollection = + Iterables.getOnlyElement( + processBundleDescriptor + .getProcessBundleDescriptor() + .getTransformsOrThrow(outputTransform) + .getInputsMap() + .values()); + FnDataReceiver outputReceiver = outputReceiverFactory.create(bundleOutputPCollection); + outputReceivers.put(outputTransform, RemoteOutputReceiver.of(coder, outputReceiver)); + } + return outputReceivers; + } + + private static class PreparedClient { + private BundleProcessor processor; + private ExecutableProcessBundleDescriptor processBundleDescriptor; + private WrappedSdkHarnessClient wrappedClient; + } + + private PreparedClient prepare( + WrappedSdkHarnessClient wrappedClient, ExecutableStage executableStage) { + PreparedClient preparedClient = new PreparedClient(); + try { + preparedClient.wrappedClient = wrappedClient; + preparedClient.processBundleDescriptor = + ProcessBundleDescriptors.fromExecutableStage( + stageIdGenerator.getId(), + executableStage, + wrappedClient.getServerInfo().getDataServer().getApiServiceDescriptor(), + wrappedClient.getServerInfo().getStateServer().getApiServiceDescriptor()); + } catch (IOException e) { + throw new RuntimeException("Failed to create ProcessBundleDescriptor.", e); + } + + preparedClient.processor = + wrappedClient + .getClient() + .getProcessor( + preparedClient.processBundleDescriptor.getProcessBundleDescriptor(), + preparedClient.processBundleDescriptor.getRemoteInputDestinations(), + wrappedClient.getServerInfo().getStateServer().getService()); + return preparedClient; + } + /** * A {@link StageBundleFactory} for remotely processing bundles that supports environment * expiration. @@ -235,37 +321,16 @@ private class SimpleStageBundleFactory implements StageBundleFactory { private final ExecutableStage executableStage; private final int environmentIndex; - private BundleProcessor processor; - private ExecutableProcessBundleDescriptor processBundleDescriptor; - private WrappedSdkHarnessClient wrappedClient; + private final HashMap preparedClients = new HashMap(); + private PreparedClient currentClient; private SimpleStageBundleFactory(ExecutableStage executableStage) { this.executableStage = executableStage; - this.environmentIndex = stageBundleCount.getAndIncrement() % environmentCaches.size(); - prepare( - environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment())); - } - - private void prepare(WrappedSdkHarnessClient wrappedClient) { - try { - this.wrappedClient = wrappedClient; - this.processBundleDescriptor = - ProcessBundleDescriptors.fromExecutableStage( - stageIdGenerator.getId(), - executableStage, - wrappedClient.getServerInfo().getDataServer().getApiServiceDescriptor(), - wrappedClient.getServerInfo().getStateServer().getApiServiceDescriptor()); - } catch (IOException e) { - throw new RuntimeException("Failed to create ProcessBundleDescriptor.", e); - } - - this.processor = - wrappedClient - .getClient() - .getProcessor( - processBundleDescriptor.getProcessBundleDescriptor(), - processBundleDescriptor.getRemoteInputDestinations(), - wrappedClient.getServerInfo().getStateServer().getService()); + this.environmentIndex = stageBundleFactoryCount.getAndIncrement() % environmentCaches.size(); + WrappedSdkHarnessClient client = + environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment()); + this.currentClient = prepare(client, executableStage); + this.preparedClients.put(client, currentClient); } @Override @@ -276,38 +341,53 @@ public RemoteBundle getBundle( throws Exception { // TODO: Consider having BundleProcessor#newBundle take in an OutputReceiverFactory rather // than constructing the receiver map here. Every bundle factory will need this. - ImmutableMap.Builder> outputReceivers = - ImmutableMap.builder(); - for (Map.Entry remoteOutputCoder : - processBundleDescriptor.getRemoteOutputCoders().entrySet()) { - String outputTransform = remoteOutputCoder.getKey(); - Coder coder = remoteOutputCoder.getValue(); - String bundleOutputPCollection = - Iterables.getOnlyElement( - processBundleDescriptor - .getProcessBundleDescriptor() - .getTransformsOrThrow(outputTransform) - .getInputsMap() - .values()); - FnDataReceiver outputReceiver = outputReceiverFactory.create(bundleOutputPCollection); - outputReceivers.put(outputTransform, RemoteOutputReceiver.of(coder, outputReceiver)); - } - if (environmentExpirationMillis == 0) { - return processor.newBundle(outputReceivers.build(), stateRequestHandler, progressHandler); + if (environmentExpirationMillis == 0 && !loadBalanceBundles) { + return currentClient.processor.newBundle( + getOutputReceivers(currentClient.processBundleDescriptor, outputReceiverFactory) + .build(), + stateRequestHandler, + progressHandler); } - final WrappedSdkHarnessClient client = - environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment()); - client.ref(); + final LoadingCache currentCache; + if (loadBalanceBundles) { + // The semaphore is used to ensure fairness, i.e. first stop first go. + availableCachesSemaphore.acquire(); + // The blocking queue of caches for serving multiple bundles concurrently. + currentCache = availableCaches.take(); + WrappedSdkHarnessClient client = + currentCache.getUnchecked(executableStage.getEnvironment()); + client.ref(); + + currentClient = preparedClients.get(client); + if (currentClient == null) { + // we are using this client for the first time + preparedClients.put(client, currentClient = prepare(client, executableStage)); + // cleanup any expired clients + preparedClients.keySet().removeIf(c -> c.bundleRefCount.get() == 0); + } - if (client != wrappedClient) { - // reset after environment expired - prepare(client); + } else { + currentCache = environmentCaches.get(environmentIndex); + WrappedSdkHarnessClient client = + currentCache.getUnchecked(executableStage.getEnvironment()); + client.ref(); + + if (currentClient.wrappedClient != client) { + // reset after environment expired + preparedClients.clear(); + currentClient = prepare(client, executableStage); + preparedClients.put(client, currentClient); + } } final RemoteBundle bundle = - processor.newBundle(outputReceivers.build(), stateRequestHandler, progressHandler); + currentClient.processor.newBundle( + getOutputReceivers(currentClient.processBundleDescriptor, outputReceiverFactory) + .build(), + stateRequestHandler, + progressHandler); return new RemoteBundle() { @Override public String getId() { @@ -322,20 +402,24 @@ public Map getInputReceivers() { @Override public void close() throws Exception { bundle.close(); - client.unref(); + currentClient.wrappedClient.unref(); + if (loadBalanceBundles) { + availableCaches.offer(currentCache); + availableCachesSemaphore.release(); + } } }; } @Override public ExecutableProcessBundleDescriptor getProcessBundleDescriptor() { - return processBundleDescriptor; + return currentClient.processBundleDescriptor; } @Override public void close() throws Exception { // Clear reference to encourage cache eviction. Values are weakly referenced. - wrappedClient = null; + preparedClients.clear(); } } diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactoryTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactoryTest.java index 1b43154d60797..b5ac3c6cef22d 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactoryTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactoryTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.fnexecution.control; +import static org.hamcrest.Matchers.containsString; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -26,7 +27,10 @@ import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.Timer; +import java.util.TimerTask; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; import org.apache.beam.model.pipeline.v1.RunnerApi.Coder; @@ -55,6 +59,7 @@ import org.apache.beam.sdk.fn.IdGenerator; import org.apache.beam.sdk.fn.IdGenerators; import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; +import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.PortablePipelineOptions; import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString; @@ -362,7 +367,7 @@ public void cachesEnvironment() throws Exception { StageBundleFactory bf1 = bundleFactory.forStage(getExecutableStage(environment)); StageBundleFactory bf2 = bundleFactory.forStage(getExecutableStage(environment)); // NOTE: We hang on to stage bundle references to ensure their underlying environments are not - // garbage collected. For additional safety, we print the factories to ensure the referernces + // garbage collected. For additional safety, we print the factories to ensure the references // are not optimized away. System.out.println("bundle factory 1:" + bf1); System.out.println("bundle factory 1:" + bf2); @@ -395,6 +400,78 @@ public void doesNotCacheDifferentEnvironments() throws Exception { } } + @Test + public void loadBalancesBundles() throws Exception { + PortablePipelineOptions portableOptions = + PipelineOptionsFactory.as(PortablePipelineOptions.class); + portableOptions.setSdkWorkerParallelism(2); + portableOptions.setLoadBalanceBundles(true); + Struct pipelineOptions = PipelineOptionsTranslation.toProto(portableOptions); + + try (DefaultJobBundleFactory bundleFactory = + new DefaultJobBundleFactory( + JobInfo.create("testJob", "testJob", "token", pipelineOptions), + envFactoryProviderMap, + stageIdGenerator, + serverInfo)) { + OutputReceiverFactory orf = mock(OutputReceiverFactory.class); + StateRequestHandler srh = mock(StateRequestHandler.class); + when(srh.getCacheTokens()).thenReturn(Collections.emptyList()); + StageBundleFactory sbf = bundleFactory.forStage(getExecutableStage(environment)); + RemoteBundle b1 = sbf.getBundle(orf, srh, BundleProgressHandler.ignored()); + verify(envFactory, Mockito.times(1)).createEnvironment(environment); + final RemoteBundle b2 = sbf.getBundle(orf, srh, BundleProgressHandler.ignored()); + verify(envFactory, Mockito.times(2)).createEnvironment(environment); + + long tms = System.currentTimeMillis(); + AtomicBoolean closed = new AtomicBoolean(); + // close to free up environment for another bundle + TimerTask closeBundleTask = + new TimerTask() { + @Override + public void run() { + try { + b2.close(); + closed.set(true); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }; + new Timer().schedule(closeBundleTask, 100); + + RemoteBundle b3 = sbf.getBundle(orf, srh, BundleProgressHandler.ignored()); + // ensure we waited for close + Assert.assertTrue(System.currentTimeMillis() - tms >= 100 && closed.get()); + + verify(envFactory, Mockito.times(2)).createEnvironment(environment); + b3.close(); + b1.close(); + } + } + + @Test + public void rejectsStateCachingWithLoadBalancing() throws Exception { + PortablePipelineOptions portableOptions = + PipelineOptionsFactory.as(PortablePipelineOptions.class); + portableOptions.setLoadBalanceBundles(true); + ExperimentalOptions options = portableOptions.as(ExperimentalOptions.class); + ExperimentalOptions.addExperiment(options, "state_cache_size=1"); + Struct pipelineOptions = PipelineOptionsTranslation.toProto(options); + + Exception e = + Assert.assertThrows( + IllegalArgumentException.class, + () -> + new DefaultJobBundleFactory( + JobInfo.create("testJob", "testJob", "token", pipelineOptions), + envFactoryProviderMap, + stageIdGenerator, + serverInfo) + .close()); + Assert.assertThat(e.getMessage(), containsString("state_cache_size")); + } + private DefaultJobBundleFactory createDefaultJobBundleFactory( Map envFactoryProviderMap) { return new DefaultJobBundleFactory( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ExperimentalOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ExperimentalOptions.java index b9825caea83ed..017b0d4436e36 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ExperimentalOptions.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/ExperimentalOptions.java @@ -30,6 +30,9 @@ @Experimental @Hidden public interface ExperimentalOptions extends PipelineOptions { + + String STATE_CACHE_SIZE = "state_cache_size"; + @Description( "[Experimental] Apache Beam provides a number of experimental features that can " + "be enabled with this flag. If executing against a managed service, please contact the " @@ -60,4 +63,22 @@ static void addExperiment(ExperimentalOptions options, String experiment) { } options.setExperiments(experiments); } + + /** Return the value for the specified experiment or null if not present. */ + static String getExperimentValue(PipelineOptions options, String experiment) { + if (options == null) { + return null; + } + List experiments = options.as(ExperimentalOptions.class).getExperiments(); + if (experiments == null) { + return null; + } + for (String experimentEntry : experiments) { + String[] tokens = experimentEntry.split(experiment + "=", -1); + if (tokens.length > 1) { + return tokens[1]; + } + } + return null; + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PortablePipelineOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PortablePipelineOptions.java index a531e8605ba50..b67a0b8137971 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PortablePipelineOptions.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PortablePipelineOptions.java @@ -88,6 +88,13 @@ public interface PortablePipelineOptions extends PipelineOptions { void setEnvironmentExpirationMillis(int environmentExpirationMillis); + @Description( + "Specifies if bundles should be distributed to the next available free SDK worker. By default SDK workers are pinned to runner tasks for the duration of the pipeline. This option can help for pipelines with long and skewed bundle execution times to increase throughput and improve worker utilization.") + @Default.Boolean(false) + boolean getLoadBalanceBundles(); + + void setLoadBalanceBundles(boolean loadBalanceBundles); + @Description("The output path for the executable file to be created.") @Nullable String getOutputExecutablePath(); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ExperimentalOptionsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ExperimentalOptionsTest.java index c60007aaf0210..eebf6216fcd59 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ExperimentalOptionsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/ExperimentalOptionsTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.options; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -36,4 +37,14 @@ public void testExperimentIsSet() { assertTrue(ExperimentalOptions.hasExperiment(options, "experimentB")); assertFalse(ExperimentalOptions.hasExperiment(options, "experimentC")); } + + @Test + public void testExperimentGetValue() { + ExperimentalOptions options = + PipelineOptionsFactory.fromArgs( + "--experiments=experimentA=0,state_cache_size=1,experimentB=0") + .as(ExperimentalOptions.class); + assertEquals( + "1", ExperimentalOptions.getExperimentValue(options, ExperimentalOptions.STATE_CACHE_SIZE)); + } }