| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.fnexecution.control; |
| |
| import com.google.auto.value.AutoValue; |
| import java.io.IOException; |
| import java.util.Map; |
| import java.util.concurrent.ExecutorService; |
| import java.util.concurrent.Executors; |
| import java.util.concurrent.TimeUnit; |
| import java.util.concurrent.atomic.AtomicInteger; |
| import javax.annotation.concurrent.ThreadSafe; |
| import org.apache.beam.model.pipeline.v1.RunnerApi.Environment; |
| import org.apache.beam.model.pipeline.v1.RunnerApi.StandardEnvironments; |
| import org.apache.beam.runners.core.construction.BeamUrns; |
| import org.apache.beam.runners.core.construction.Environments; |
| import org.apache.beam.runners.core.construction.PipelineOptionsTranslation; |
| import org.apache.beam.runners.core.construction.graph.ExecutableStage; |
| import org.apache.beam.runners.fnexecution.GrpcContextHeaderAccessorProvider; |
| import org.apache.beam.runners.fnexecution.GrpcFnServer; |
| import org.apache.beam.runners.fnexecution.ServerFactory; |
| import org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService; |
| import org.apache.beam.runners.fnexecution.artifact.BeamFileSystemArtifactRetrievalService; |
| import org.apache.beam.runners.fnexecution.artifact.ClassLoaderArtifactRetrievalService; |
| import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor; |
| import org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor; |
| import org.apache.beam.runners.fnexecution.data.GrpcDataService; |
| import org.apache.beam.runners.fnexecution.environment.DockerEnvironmentFactory; |
| import org.apache.beam.runners.fnexecution.environment.EmbeddedEnvironmentFactory; |
| import org.apache.beam.runners.fnexecution.environment.EnvironmentFactory; |
| import org.apache.beam.runners.fnexecution.environment.ExternalEnvironmentFactory; |
| import org.apache.beam.runners.fnexecution.environment.ProcessEnvironmentFactory; |
| import org.apache.beam.runners.fnexecution.environment.RemoteEnvironment; |
| import org.apache.beam.runners.fnexecution.logging.GrpcLoggingService; |
| import org.apache.beam.runners.fnexecution.logging.Slf4jLogWriter; |
| import org.apache.beam.runners.fnexecution.provisioning.JobInfo; |
| import org.apache.beam.runners.fnexecution.provisioning.StaticGrpcProvisionService; |
| import org.apache.beam.runners.fnexecution.state.GrpcStateService; |
| import org.apache.beam.runners.fnexecution.state.StateRequestHandler; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.fn.IdGenerator; |
| import org.apache.beam.sdk.fn.IdGenerators; |
| import org.apache.beam.sdk.fn.data.FnDataReceiver; |
| import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; |
| import org.apache.beam.sdk.function.ThrowingFunction; |
| import org.apache.beam.sdk.options.PipelineOptions; |
| import org.apache.beam.sdk.options.PortablePipelineOptions; |
| import org.apache.beam.sdk.options.PortablePipelineOptions.RetrievalServiceType; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheLoader; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LoadingCache; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalNotification; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| /** |
| * A {@link JobBundleFactory} for which the implementation can specify a custom {@link |
| * EnvironmentFactory} for environment management. Note that returned {@link StageBundleFactory |
| * stage bundle factories} are not thread-safe. Instead, a new stage factory should be created for |
| * each client. {@link DefaultJobBundleFactory} initializes the environment lazily, when {@link |
| * #forStage} is called for a stage. |
| */ |
| @ThreadSafe |
| public class DefaultJobBundleFactory implements JobBundleFactory { |
| /** Logger used by this factory and its nested classes. */ |
| private static final Logger LOG = LoggerFactory.getLogger(DefaultJobBundleFactory.class); |
| /** Static generator, shared by all instances, that supplies each factory's {@link #factoryId}. */ |
| private static final IdGenerator factoryIdGenerator = IdGenerators.incrementingLongs(); |
| |
| /** Id of this factory instance; the default stage id generator uses it as the id prefix. */ |
| private final String factoryId = factoryIdGenerator.getId(); |
| /** Environment caches; each stage bundle factory is pinned to one of them by index. */ |
| private final ImmutableList<LoadingCache<Environment, WrappedSdkHarnessClient>> environmentCaches; |
| /** Number of stage bundle factories created so far; used to spread them across the caches. */ |
| private final AtomicInteger stageBundleCount = new AtomicInteger(); |
| /** Environment factory providers, looked up by environment URN when a cache entry is loaded. */ |
| private final Map<String, EnvironmentFactory.Provider> environmentFactoryProviderMap; |
| /** Executor handed to the gRPC data service; shut down when this factory is closed. */ |
| private final ExecutorService executor; |
| /** Control client pool shared by the control service and the environment factories. */ |
| private final MapControlClientPool clientPool; |
| /** Generates the ids passed to environment factories and used for process bundle descriptors. */ |
| private final IdGenerator stageIdGenerator; |
| /** Environment expiration from the pipeline options; positive values enable cache expiry. */ |
| private final int environmentExpirationMillis; |
| |
| /** |
| * Creates a factory with providers registered for the Docker, process, external and embedded |
| * environment URNs. |
| * |
| * @param jobInfo job description; its pipeline options are decoded and handed to the providers |
| * that accept them |
| * @return a new {@link DefaultJobBundleFactory} |
| */ |
| public static DefaultJobBundleFactory create(JobInfo jobInfo) { |
| PipelineOptions options = PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions()); |
| |
| String dockerUrn = BeamUrns.getUrn(StandardEnvironments.Environments.DOCKER); |
| EnvironmentFactory.Provider dockerProvider = new DockerEnvironmentFactory.Provider(options); |
| |
| String processUrn = BeamUrns.getUrn(StandardEnvironments.Environments.PROCESS); |
| EnvironmentFactory.Provider processProvider = new ProcessEnvironmentFactory.Provider(options); |
| |
| String externalUrn = BeamUrns.getUrn(StandardEnvironments.Environments.EXTERNAL); |
| EnvironmentFactory.Provider externalProvider = new ExternalEnvironmentFactory.Provider(); |
| |
| // Non Public urn for testing. |
| String embeddedUrn = Environments.ENVIRONMENT_EMBEDDED; |
| EnvironmentFactory.Provider embeddedProvider = new EmbeddedEnvironmentFactory.Provider(options); |
| |
| Map<String, EnvironmentFactory.Provider> providersByUrn = |
| ImmutableMap.of( |
| dockerUrn, |
| dockerProvider, |
| processUrn, |
| processProvider, |
| externalUrn, |
| externalProvider, |
| embeddedUrn, |
| embeddedProvider); |
| return new DefaultJobBundleFactory(jobInfo, providersByUrn); |
| } |
| |
| /** |
| * Creates a factory that uses the supplied environment factory providers. |
| * |
| * @param jobInfo job description passed through to the constructor |
| * @param environmentFactoryProviderMap providers keyed by environment URN |
| * @return a new {@link DefaultJobBundleFactory} |
| */ |
| public static DefaultJobBundleFactory create( |
| JobInfo jobInfo, Map<String, EnvironmentFactory.Provider> environmentFactoryProviderMap) { |
| DefaultJobBundleFactory factory = |
| new DefaultJobBundleFactory(jobInfo, environmentFactoryProviderMap); |
| return factory; |
| } |
| |
| /** |
| * Builds a factory whose servers are created on demand via {@code createServerInfo} and whose |
| * stage ids have the form {@code <factoryId>-<n>}. |
| * |
| * @param jobInfo job description; supplies the pipeline options read during construction |
| * @param environmentFactoryMap providers keyed by environment URN |
| */ |
| DefaultJobBundleFactory( |
| JobInfo jobInfo, Map<String, EnvironmentFactory.Provider> environmentFactoryMap) { |
| this.environmentFactoryProviderMap = environmentFactoryMap; |
| this.executor = Executors.newCachedThreadPool(); |
| this.clientPool = MapControlClientPool.create(); |
| |
| // The numeric suffix comes from a generator private to this factory instance. |
| IdGenerator suffixes = IdGenerators.incrementingLongs(); |
| this.stageIdGenerator = () -> factoryId + "-" + suffixes.getId(); |
| |
| // The expiration has to be assigned before the caches are built, because |
| // createEnvironmentCaches reads it. |
| this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo); |
| ThrowingFunction<ServerFactory, ServerInfo> serverInfoCreator = |
| serverFactory -> createServerInfo(jobInfo, serverFactory); |
| this.environmentCaches = |
| createEnvironmentCaches(serverInfoCreator, getMaxEnvironmentClients(jobInfo)); |
| } |
| |
| /** |
| * Test-only constructor: uses the given stage id generator, and returns the given {@code |
| * serverInfo} for every environment instead of starting new servers. |
| * |
| * @param jobInfo job description; supplies the pipeline options read during construction |
| * @param environmentFactoryMap providers keyed by environment URN |
| * @param stageIdGenerator generator to use for stage ids |
| * @param serverInfo servers to use for every environment |
| */ |
| @VisibleForTesting |
| DefaultJobBundleFactory( |
| JobInfo jobInfo, |
| Map<String, EnvironmentFactory.Provider> environmentFactoryMap, |
| IdGenerator stageIdGenerator, |
| ServerInfo serverInfo) { |
| this.environmentFactoryProviderMap = environmentFactoryMap; |
| this.executor = Executors.newCachedThreadPool(); |
| this.clientPool = MapControlClientPool.create(); |
| this.stageIdGenerator = stageIdGenerator; |
| |
| // The expiration has to be assigned before the caches are built, because |
| // createEnvironmentCaches reads it. |
| this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo); |
| ThrowingFunction<ServerFactory, ServerInfo> fixedServerInfo = serverFactory -> serverInfo; |
| this.environmentCaches = |
| createEnvironmentCaches(fixedServerInfo, getMaxEnvironmentClients(jobInfo)); |
| } |
| |
| /** |
| * Creates {@code count} environment caches that all share the same configuration. |
| * |
| * <p>Each cache starts an environment the first time it is asked for one: the loader looks up the |
| * {@link EnvironmentFactory.Provider} registered for the environment's URN, obtains the servers |
| * from {@code serverInfoCreator}, and wraps the created environment together with those servers. |
| * The loader fails with a descriptive {@link NullPointerException} when no provider is registered |
| * for the URN. When an entry is removed from a cache, the reference the cache held on the wrapped |
| * client is released. Entries expire after {@link #environmentExpirationMillis} when that value |
| * is positive. |
| * |
| * @param serverInfoCreator creates the servers for an environment from the provider's {@link |
| * ServerFactory} |
| * @param count number of caches to create |
| * @return the caches, in creation order |
| */ |
| private ImmutableList<LoadingCache<Environment, WrappedSdkHarnessClient>> createEnvironmentCaches( |
| ThrowingFunction<ServerFactory, ServerInfo> serverInfoCreator, int count) { |
| // Fully parameterized (instead of raw) so the compiler checks the listener and loader types. |
| CacheBuilder<Environment, WrappedSdkHarnessClient> cacheBuilder = |
| CacheBuilder.newBuilder() |
| .removalListener( |
| notification -> { |
| int refCount = notification.getValue().unref(); |
| LOG.debug( |
| "Removed environment {} with {} remaining bundle references.", |
| notification.getKey(), |
| refCount); |
| }); |
| |
| if (environmentExpirationMillis > 0) { |
| cacheBuilder = |
| cacheBuilder.expireAfterWrite(environmentExpirationMillis, TimeUnit.MILLISECONDS); |
| } |
| |
| // The loader keeps no per-cache state, so a single instance serves every cache. |
| CacheLoader<Environment, WrappedSdkHarnessClient> environmentLoader = |
| new CacheLoader<Environment, WrappedSdkHarnessClient>() { |
| @Override |
| public WrappedSdkHarnessClient load(Environment environment) throws Exception { |
| EnvironmentFactory.Provider environmentFactoryProvider = |
| Preconditions.checkNotNull( |
| environmentFactoryProviderMap.get(environment.getUrn()), |
| "No EnvironmentFactory.Provider registered for environment URN %s", |
| environment.getUrn()); |
| ServerFactory serverFactory = environmentFactoryProvider.getServerFactory(); |
| ServerInfo serverInfo = serverInfoCreator.apply(serverFactory); |
| EnvironmentFactory environmentFactory = |
| environmentFactoryProvider.createEnvironmentFactory( |
| serverInfo.getControlServer(), |
| serverInfo.getLoggingServer(), |
| serverInfo.getRetrievalServer(), |
| serverInfo.getProvisioningServer(), |
| clientPool, |
| stageIdGenerator); |
| return WrappedSdkHarnessClient.wrapping( |
| environmentFactory.createEnvironment(environment), serverInfo); |
| } |
| }; |
| |
| ImmutableList.Builder<LoadingCache<Environment, WrappedSdkHarnessClient>> caches = |
| ImmutableList.builder(); |
| for (int i = 0; i < count; i++) { |
| LoadingCache<Environment, WrappedSdkHarnessClient> cache = |
| cacheBuilder.build(environmentLoader); |
| caches.add(cache); |
| } |
| return caches.build(); |
| } |
| |
| /** |
| * Reads the environment expiration setting from the portable pipeline options carried by {@code |
| * jobInfo}. |
| */ |
| private static int getEnvironmentExpirationMillis(JobInfo jobInfo) { |
| return PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions()) |
| .as(PortablePipelineOptions.class) |
| .getEnvironmentExpirationMillis(); |
| } |
| |
| /** |
| * Determines how many environment caches to create from the {@code sdk_worker_parallelism} |
| * pipeline option. |
| * |
| * <p>An unset option counts as 1. A value of 0 selects the automatic setting: one less than the |
| * number of available processors, but never less than 1. |
| * |
| * @throws IllegalArgumentException if the configured value is negative |
| */ |
| private static int getMaxEnvironmentClients(JobInfo jobInfo) { |
| PortablePipelineOptions options = |
| PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions()) |
| .as(PortablePipelineOptions.class); |
| int parallelism = MoreObjects.firstNonNull(options.getSdkWorkerParallelism(), 1L).intValue(); |
| Preconditions.checkArgument(parallelism >= 0, "sdk_worker_parallelism must be >= 0"); |
| if (parallelism != 0) { |
| return parallelism; |
| } |
| // Automatic setting: leave one core's worth of resources for the java process. |
| int spareCores = Runtime.getRuntime().availableProcessors() - 1; |
| return Math.max(spareCores, 1); |
| } |
| |
| @Override |
| public StageBundleFactory forStage(ExecutableStage executableStage) { |
| return new SimpleStageBundleFactory(executableStage); |
| } |
| |
| /** |
| * Empties every environment cache and then shuts down the executor. |
| * |
| * <p>Invalidating the entries runs the removal listener, which releases the reference each cache |
| * holds on its environments. Note that this may cause open calls to be cancelled by the peer. |
| */ |
| @Override |
| public void close() throws Exception { |
| environmentCaches.forEach( |
| cache -> { |
| cache.invalidateAll(); |
| cache.cleanUp(); |
| }); |
| executor.shutdown(); |
| } |
| |
| /** |
| * A {@link StageBundleFactory} for remotely processing bundles that supports environment |
| * expiration. |
| */ |
| private class SimpleStageBundleFactory implements StageBundleFactory { |
| |
| private final ExecutableStage executableStage; |
| private final int environmentIndex; |
| private BundleProcessor processor; |
| private ExecutableProcessBundleDescriptor processBundleDescriptor; |
| private WrappedSdkHarnessClient wrappedClient; |
| |
| private SimpleStageBundleFactory(ExecutableStage executableStage) { |
| this.executableStage = executableStage; |
| this.environmentIndex = stageBundleCount.getAndIncrement() % environmentCaches.size(); |
| prepare( |
| environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment())); |
| } |
| |
| private void prepare(WrappedSdkHarnessClient wrappedClient) { |
| try { |
| this.wrappedClient = wrappedClient; |
| this.processBundleDescriptor = |
| ProcessBundleDescriptors.fromExecutableStage( |
| stageIdGenerator.getId(), |
| executableStage, |
| wrappedClient.getServerInfo().getDataServer().getApiServiceDescriptor(), |
| wrappedClient.getServerInfo().getStateServer().getApiServiceDescriptor()); |
| } catch (IOException e) { |
| throw new RuntimeException("Failed to create ProcessBundleDescriptor.", e); |
| } |
| |
| this.processor = |
| wrappedClient |
| .getClient() |
| .getProcessor( |
| processBundleDescriptor.getProcessBundleDescriptor(), |
| processBundleDescriptor.getRemoteInputDestinations(), |
| wrappedClient.getServerInfo().getStateServer().getService()); |
| } |
| |
| @Override |
| public RemoteBundle getBundle( |
| OutputReceiverFactory outputReceiverFactory, |
| StateRequestHandler stateRequestHandler, |
| BundleProgressHandler progressHandler) |
| throws Exception { |
| // TODO: Consider having BundleProcessor#newBundle take in an OutputReceiverFactory rather |
| // than constructing the receiver map here. Every bundle factory will need this. |
| ImmutableMap.Builder<String, RemoteOutputReceiver<?>> outputReceivers = |
| ImmutableMap.builder(); |
| for (Map.Entry<String, Coder> remoteOutputCoder : |
| processBundleDescriptor.getRemoteOutputCoders().entrySet()) { |
| String outputTransform = remoteOutputCoder.getKey(); |
| Coder coder = remoteOutputCoder.getValue(); |
| String bundleOutputPCollection = |
| Iterables.getOnlyElement( |
| processBundleDescriptor |
| .getProcessBundleDescriptor() |
| .getTransformsOrThrow(outputTransform) |
| .getInputsMap() |
| .values()); |
| FnDataReceiver outputReceiver = outputReceiverFactory.create(bundleOutputPCollection); |
| outputReceivers.put(outputTransform, RemoteOutputReceiver.of(coder, outputReceiver)); |
| } |
| |
| if (environmentExpirationMillis == 0) { |
| return processor.newBundle(outputReceivers.build(), stateRequestHandler, progressHandler); |
| } |
| |
| final WrappedSdkHarnessClient client = |
| environmentCaches.get(environmentIndex).getUnchecked(executableStage.getEnvironment()); |
| client.ref(); |
| |
| if (client != wrappedClient) { |
| // reset after environment expired |
| prepare(client); |
| } |
| |
| final RemoteBundle bundle = |
| processor.newBundle(outputReceivers.build(), stateRequestHandler, progressHandler); |
| return new RemoteBundle() { |
| @Override |
| public String getId() { |
| return bundle.getId(); |
| } |
| |
| @Override |
| public Map<String, FnDataReceiver> getInputReceivers() { |
| return bundle.getInputReceivers(); |
| } |
| |
| @Override |
| public void close() throws Exception { |
| bundle.close(); |
| client.unref(); |
| } |
| }; |
| } |
| |
| @Override |
| public ExecutableProcessBundleDescriptor getProcessBundleDescriptor() { |
| return processBundleDescriptor; |
| } |
| |
| @Override |
| public void close() throws Exception { |
| // Clear reference to encourage cache eviction. Values are weakly referenced. |
| wrappedClient = null; |
| } |
| } |
| |
| /** |
| * Holder for an {@link SdkHarnessClient} along with its associated state and data servers. As of |
| * now, there is a 1:1 relationship between data services and harness clients. The servers are |
| * packaged here to tie server lifetimes to harness client lifetimes. |
| * |
| * <p>The holder is reference counted: it starts with one reference, and it closes itself when |
| * {@code unref()} brings the count to zero. |
| */ |
| protected static class WrappedSdkHarnessClient implements AutoCloseable { |
| |
| private final RemoteEnvironment environment; |
| private final SdkHarnessClient client; |
| private final ServerInfo serverInfo; |
| private final AtomicInteger bundleRefCount = new AtomicInteger(); |
| |
| /** |
| * Creates a holder for {@code environment}, with a harness client built from the environment's |
| * instruction request handler and the data service in {@code serverInfo}. |
| */ |
| static WrappedSdkHarnessClient wrapping(RemoteEnvironment environment, ServerInfo serverInfo) { |
| SdkHarnessClient client = |
| SdkHarnessClient.usingFnApiClient( |
| environment.getInstructionRequestHandler(), serverInfo.getDataServer().getService()); |
| return new WrappedSdkHarnessClient(environment, client, serverInfo); |
| } |
| |
| private WrappedSdkHarnessClient( |
| RemoteEnvironment environment, SdkHarnessClient client, ServerInfo serverInfo) { |
| this.environment = environment; |
| this.client = client; |
| this.serverInfo = serverInfo; |
| // Initial reference; released by whoever owns this holder (the cache's removal listener). |
| ref(); |
| } |
| |
| SdkHarnessClient getClient() { |
| return client; |
| } |
| |
| ServerInfo getServerInfo() { |
| return serverInfo; |
| } |
| |
| /** |
| * Closes the environment and then all of its servers. |
| * |
| * <p>Everything is declared in a single try-with-resources so that each resource gets closed |
| * even if closing an earlier one throws; later failures are attached to the first one as |
| * suppressed exceptions. Resources are closed in reverse declaration order, so the environment |
| * goes first, followed by the provisioning, retrieval, logging, control, data and state |
| * servers. |
| */ |
| @Override |
| public void close() throws Exception { |
| try (AutoCloseable stateServer = serverInfo.getStateServer(); |
| AutoCloseable dataServer = serverInfo.getDataServer(); |
| AutoCloseable controlServer = serverInfo.getControlServer(); |
| AutoCloseable loggingServer = serverInfo.getLoggingServer(); |
| AutoCloseable retrievalServer = serverInfo.getRetrievalServer(); |
| AutoCloseable provisioningServer = serverInfo.getProvisioningServer(); |
| AutoCloseable envCloser = environment) { |
| // Nothing to do; the resources are closed on exit from this block. |
| } |
| // TODO: Wait for executor shutdown? |
| } |
| |
| /** Adds a reference and returns the new reference count. */ |
| private int ref() { |
| return bundleRefCount.incrementAndGet(); |
| } |
| |
| /** |
| * Releases a reference and returns the new reference count. When the count reaches zero the |
| * holder is closed; failures while closing are logged, not thrown. |
| */ |
| private int unref() { |
| int count = bundleRefCount.decrementAndGet(); |
| if (count == 0) { |
| // Close environment after it was removed from cache and all bundles finished. |
| LOG.info("Closing environment {}", environment.getEnvironment()); |
| try { |
| close(); |
| } catch (Exception e) { |
| LOG.warn("Error cleaning up environment {}", environment.getEnvironment(), e); |
| } |
| } |
| return count; |
| } |
| } |
| |
| /** |
| * Starts the control, logging, artifact retrieval, provisioning, data and state servers for an |
| * environment and bundles them into a {@link ServerInfo}. |
| * |
| * <p>If starting any server fails, the servers that were already started are closed before the |
| * failure is rethrown, so a partial start does not leak servers. |
| * |
| * @param jobInfo job description; supplies the retrieval service type and the provision info |
| * @param serverFactory factory used to start each server; must not be null |
| * @return the started servers |
| * @throws IOException if a server cannot be started |
| */ |
| private ServerInfo createServerInfo(JobInfo jobInfo, ServerFactory serverFactory) |
| throws IOException { |
| Preconditions.checkNotNull(serverFactory, "serverFactory can not be null"); |
| |
| PortablePipelineOptions portableOptions = |
| PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions()) |
| .as(PortablePipelineOptions.class); |
| ArtifactRetrievalService artifactRetrievalService; |
| |
| if (portableOptions.getRetrievalServiceType() == RetrievalServiceType.CLASSLOADER) { |
| artifactRetrievalService = new ClassLoaderArtifactRetrievalService(); |
| } else { |
| artifactRetrievalService = BeamFileSystemArtifactRetrievalService.create(); |
| } |
| |
| GrpcFnServer<FnApiControlClientPoolService> controlServer = null; |
| GrpcFnServer<GrpcLoggingService> loggingServer = null; |
| GrpcFnServer<ArtifactRetrievalService> retrievalServer = null; |
| GrpcFnServer<StaticGrpcProvisionService> provisioningServer = null; |
| GrpcFnServer<GrpcDataService> dataServer = null; |
| GrpcFnServer<GrpcStateService> stateServer = null; |
| try { |
| controlServer = |
| GrpcFnServer.allocatePortAndCreateFor( |
| FnApiControlClientPoolService.offeringClientsToPool( |
| clientPool.getSink(), GrpcContextHeaderAccessorProvider.getHeaderAccessor()), |
| serverFactory); |
| loggingServer = |
| GrpcFnServer.allocatePortAndCreateFor( |
| GrpcLoggingService.forWriter(Slf4jLogWriter.getDefault()), serverFactory); |
| retrievalServer = |
| GrpcFnServer.allocatePortAndCreateFor(artifactRetrievalService, serverFactory); |
| provisioningServer = |
| GrpcFnServer.allocatePortAndCreateFor( |
| StaticGrpcProvisionService.create(jobInfo.toProvisionInfo()), serverFactory); |
| dataServer = |
| GrpcFnServer.allocatePortAndCreateFor( |
| GrpcDataService.create(executor, OutboundObserverFactory.serverDirect()), |
| serverFactory); |
| stateServer = |
| GrpcFnServer.allocatePortAndCreateFor(GrpcStateService.create(), serverFactory); |
| |
| return new AutoValue_DefaultJobBundleFactory_ServerInfo.Builder() |
| .setControlServer(controlServer) |
| .setLoggingServer(loggingServer) |
| .setRetrievalServer(retrievalServer) |
| .setProvisioningServer(provisioningServer) |
| .setDataServer(dataServer) |
| .setStateServer(stateServer) |
| .build(); |
| } catch (IOException | RuntimeException e) { |
| // Close whatever was started, newest first; servers not yet started are still null. |
| closeAfterFailure( |
| e, |
| stateServer, |
| dataServer, |
| provisioningServer, |
| retrievalServer, |
| loggingServer, |
| controlServer); |
| throw e; |
| } |
| } |
| |
| /** |
| * Closes every non-null server in the given order. Failures while closing are attached to |
| * {@code failure} as suppressed exceptions so they do not mask it. |
| */ |
| private static void closeAfterFailure(Throwable failure, AutoCloseable... servers) { |
| for (AutoCloseable server : servers) { |
| if (server == null) { |
| continue; |
| } |
| try { |
| server.close(); |
| } catch (Exception closeFailure) { |
| failure.addSuppressed(closeFailure); |
| } |
| } |
| } |
| |
| /** |
| * AutoValue container for the six gRPC servers (control, logging, artifact retrieval, |
| * provisioning, data and state) that are created together for an environment. |
| */ |
| @AutoValue |
| public abstract static class ServerInfo { |
| /** Server hosting the control client pool service. */ |
| abstract GrpcFnServer<FnApiControlClientPoolService> getControlServer(); |
| |
| /** Server hosting the logging service. */ |
| abstract GrpcFnServer<GrpcLoggingService> getLoggingServer(); |
| |
| /** Server hosting the artifact retrieval service. */ |
| abstract GrpcFnServer<ArtifactRetrievalService> getRetrievalServer(); |
| |
| /** Server hosting the provisioning service. */ |
| abstract GrpcFnServer<StaticGrpcProvisionService> getProvisioningServer(); |
| |
| /** Server hosting the data service. */ |
| abstract GrpcFnServer<GrpcDataService> getDataServer(); |
| |
| /** Server hosting the state service. */ |
| abstract GrpcFnServer<GrpcStateService> getStateServer(); |
| |
| /** Returns a builder pre-populated with this instance's servers. */ |
| abstract Builder toBuilder(); |
| |
| /** Builder for {@link ServerInfo}; one setter per server. */ |
| @AutoValue.Builder |
| abstract static class Builder { |
| abstract Builder setControlServer(GrpcFnServer<FnApiControlClientPoolService> server); |
| |
| abstract Builder setLoggingServer(GrpcFnServer<GrpcLoggingService> server); |
| |
| abstract Builder setRetrievalServer(GrpcFnServer<ArtifactRetrievalService> server); |
| |
| abstract Builder setProvisioningServer(GrpcFnServer<StaticGrpcProvisionService> server); |
| |
| abstract Builder setDataServer(GrpcFnServer<GrpcDataService> server); |
| |
| abstract Builder setStateServer(GrpcFnServer<GrpcStateService> server); |
| |
| abstract ServerInfo build(); |
| } |
| } |
| } |