| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.direct; |
| |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.Map; |
| import java.util.Optional; |
| import java.util.concurrent.BlockingQueue; |
| import java.util.concurrent.ConcurrentLinkedQueue; |
| import java.util.concurrent.ExecutorService; |
| import java.util.concurrent.Executors; |
| import java.util.concurrent.LinkedBlockingQueue; |
| import java.util.concurrent.TimeUnit; |
| import java.util.concurrent.atomic.AtomicReference; |
| import java.util.stream.Collectors; |
| import javax.annotation.Nullable; |
| import org.apache.beam.runners.local.ExecutionDriver; |
| import org.apache.beam.runners.local.ExecutionDriver.DriverState; |
| import org.apache.beam.runners.local.PipelineMessageReceiver; |
| import org.apache.beam.sdk.Pipeline; |
| import org.apache.beam.sdk.PipelineResult.State; |
| import org.apache.beam.sdk.runners.AppliedPTransform; |
| import org.apache.beam.sdk.util.UserCodeException; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.PValue; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheLoader; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LoadingCache; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.RemovalListener; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.MoreExecutors; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; |
| import org.joda.time.Duration; |
| import org.joda.time.Instant; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| /** |
| * An {@link PipelineExecutor} that uses an underlying {@link ExecutorService} and {@link |
| * EvaluationContext} to execute a {@link Pipeline}. |
| */ |
| final class ExecutorServiceParallelExecutor |
| implements PipelineExecutor, |
| BundleProcessor<PCollection<?>, CommittedBundle<?>, AppliedPTransform<?, ?, ?>> { |
| private static final Logger LOG = LoggerFactory.getLogger(ExecutorServiceParallelExecutor.class); |
| |
| private final int targetParallelism; |
| private final ExecutorService executorService; |
| |
| private final TransformEvaluatorRegistry registry; |
| |
| private final EvaluationContext evaluationContext; |
| |
| private final TransformExecutorFactory executorFactory; |
| private final TransformExecutorService parallelExecutorService; |
| private final LoadingCache<StepAndKey, TransformExecutorService> serialExecutorServices; |
| |
| private final QueueMessageReceiver visibleUpdates; |
| |
| private final ExecutorService metricsExecutor; |
| |
| private AtomicReference<State> pipelineState = new AtomicReference<>(State.RUNNING); |
| |
| public static ExecutorServiceParallelExecutor create( |
| int targetParallelism, |
| TransformEvaluatorRegistry registry, |
| Map<String, Collection<ModelEnforcementFactory>> transformEnforcements, |
| EvaluationContext context, |
| ExecutorService metricsExecutor) { |
| return new ExecutorServiceParallelExecutor( |
| targetParallelism, registry, transformEnforcements, context, metricsExecutor); |
| } |
| |
| private ExecutorServiceParallelExecutor( |
| int targetParallelism, |
| TransformEvaluatorRegistry registry, |
| Map<String, Collection<ModelEnforcementFactory>> transformEnforcements, |
| EvaluationContext context, |
| ExecutorService metricsExecutor) { |
| this.targetParallelism = targetParallelism; |
| this.metricsExecutor = metricsExecutor; |
| // Don't use Daemon threads for workers. The Pipeline should continue to execute even if there |
| // are no other active threads (for example, because waitUntilFinish was not called) |
| this.executorService = |
| Executors.newFixedThreadPool( |
| targetParallelism, |
| new ThreadFactoryBuilder() |
| .setThreadFactory(MoreExecutors.platformThreadFactory()) |
| .setNameFormat("direct-runner-worker") |
| .build()); |
| this.registry = registry; |
| this.evaluationContext = context; |
| |
| // Weak Values allows TransformExecutorServices that are no longer in use to be reclaimed. |
| // Executing TransformExecutorServices have a strong reference to their TransformExecutorService |
| // which stops the TransformExecutorServices from being prematurely garbage collected |
| serialExecutorServices = |
| CacheBuilder.newBuilder() |
| .weakValues() |
| .removalListener(shutdownExecutorServiceListener()) |
| .build(serialTransformExecutorServiceCacheLoader()); |
| |
| this.visibleUpdates = new QueueMessageReceiver(); |
| |
| parallelExecutorService = TransformExecutorServices.parallel(executorService); |
| executorFactory = new DirectTransformExecutor.Factory(context, registry, transformEnforcements); |
| } |
| |
| private CacheLoader<StepAndKey, TransformExecutorService> |
| serialTransformExecutorServiceCacheLoader() { |
| return new CacheLoader<StepAndKey, TransformExecutorService>() { |
| @Override |
| public TransformExecutorService load(StepAndKey stepAndKey) throws Exception { |
| return TransformExecutorServices.serial(executorService); |
| } |
| }; |
| } |
| |
| private RemovalListener<StepAndKey, TransformExecutorService> shutdownExecutorServiceListener() { |
| return notification -> { |
| TransformExecutorService service = notification.getValue(); |
| if (service != null) { |
| service.shutdown(); |
| } |
| }; |
| } |
| |
| @Override |
| // TODO: [BEAM-4563] Pass Future back to consumer to check for async errors |
| @SuppressWarnings("FutureReturnValueIgnored") |
| public void start(DirectGraph graph, RootProviderRegistry rootProviderRegistry) { |
| int numTargetSplits = Math.max(3, targetParallelism); |
| ImmutableMap.Builder<AppliedPTransform<?, ?, ?>, ConcurrentLinkedQueue<CommittedBundle<?>>> |
| pendingRootBundles = ImmutableMap.builder(); |
| for (AppliedPTransform<?, ?, ?> root : graph.getRootTransforms()) { |
| ConcurrentLinkedQueue<CommittedBundle<?>> pending = new ConcurrentLinkedQueue<>(); |
| try { |
| Collection<CommittedBundle<?>> initialInputs = |
| rootProviderRegistry.getInitialInputs(root, numTargetSplits); |
| pending.addAll(initialInputs); |
| } catch (Exception e) { |
| throw UserCodeException.wrap(e); |
| } |
| pendingRootBundles.put(root, pending); |
| } |
| evaluationContext.initialize(pendingRootBundles.build()); |
| final ExecutionDriver executionDriver = |
| QuiescenceDriver.create( |
| evaluationContext, graph, this, visibleUpdates, pendingRootBundles.build()); |
| executorService.submit( |
| new Runnable() { |
| @Override |
| public void run() { |
| DriverState drive = executionDriver.drive(); |
| if (drive.isTermainal()) { |
| State newPipelineState = State.UNKNOWN; |
| switch (drive) { |
| case FAILED: |
| newPipelineState = State.FAILED; |
| break; |
| case SHUTDOWN: |
| newPipelineState = State.DONE; |
| break; |
| case CONTINUE: |
| throw new IllegalStateException( |
| String.format("%s should not be a terminal state", DriverState.CONTINUE)); |
| default: |
| throw new IllegalArgumentException( |
| String.format("Unknown %s %s", DriverState.class.getSimpleName(), drive)); |
| } |
| shutdownIfNecessary(newPipelineState); |
| } else { |
| executorService.submit(this); |
| } |
| } |
| }); |
| } |
| |
| @SuppressWarnings("unchecked") |
| @Override |
| public void process( |
| CommittedBundle<?> bundle, |
| AppliedPTransform<?, ?, ?> consumer, |
| CompletionCallback onComplete) { |
| evaluateBundle(consumer, bundle, onComplete); |
| } |
| |
| private <T> void evaluateBundle( |
| final AppliedPTransform<?, ?, ?> transform, |
| final CommittedBundle<T> bundle, |
| final CompletionCallback onComplete) { |
| TransformExecutorService transformExecutor; |
| |
| if (isKeyed(bundle.getPCollection())) { |
| final StepAndKey stepAndKey = StepAndKey.of(transform, bundle.getKey()); |
| // This executor will remain reachable until it has executed all scheduled transforms. |
| // The TransformExecutors keep a strong reference to the Executor, the ExecutorService keeps |
| // a reference to the scheduled DirectTransformExecutor callable. Follow-up TransformExecutors |
| // (scheduled due to the completion of another DirectTransformExecutor) are provided to the |
| // ExecutorService before the Earlier DirectTransformExecutor callable completes. |
| transformExecutor = serialExecutorServices.getUnchecked(stepAndKey); |
| } else { |
| transformExecutor = parallelExecutorService; |
| } |
| |
| TransformExecutor callable = |
| executorFactory.create(bundle, transform, onComplete, transformExecutor); |
| if (!pipelineState.get().isTerminal()) { |
| transformExecutor.schedule(callable); |
| } |
| } |
| |
| private boolean isKeyed(PValue pvalue) { |
| return evaluationContext.isKeyed(pvalue); |
| } |
| |
| @Override |
| public State waitUntilFinish(Duration duration) throws Exception { |
| Instant completionTime; |
| if (duration.equals(Duration.ZERO)) { |
| completionTime = new Instant(Long.MAX_VALUE); |
| } else { |
| completionTime = Instant.now().plus(duration); |
| } |
| |
| VisibleExecutorUpdate update = null; |
| while (Instant.now().isBefore(completionTime) |
| && (update == null || isTerminalStateUpdate(update))) { |
| // Get an update; don't block forever if another thread has handled it. The call to poll will |
| // wait the entire timeout; this call primarily exists to relinquish any core. |
| update = visibleUpdates.tryNext(Duration.millis(25L)); |
| if (update == null && pipelineState.get().isTerminal()) { |
| // there are no updates to process and no updates will ever be published because the |
| // executor is shutdown |
| return pipelineState.get(); |
| } else if (update != null && update.thrown.isPresent()) { |
| Throwable thrown = update.thrown.get(); |
| if (thrown instanceof Exception) { |
| throw (Exception) thrown; |
| } else if (thrown instanceof Error) { |
| throw (Error) thrown; |
| } else { |
| throw new Exception("Unknown Type of Throwable", thrown); |
| } |
| } |
| } |
| return pipelineState.get(); |
| } |
| |
| @Override |
| public State getPipelineState() { |
| return pipelineState.get(); |
| } |
| |
| private boolean isTerminalStateUpdate(VisibleExecutorUpdate update) { |
| return !(update.getNewState() == null && update.getNewState().isTerminal()); |
| } |
| |
| @Override |
| public void stop() { |
| shutdownIfNecessary(State.CANCELLED); |
| visibleUpdates.cancelled(); |
| } |
| |
| private void shutdownIfNecessary(State newState) { |
| if (!newState.isTerminal()) { |
| return; |
| } |
| LOG.debug("Pipeline has terminated. Shutting down."); |
| |
| final Collection<Exception> errors = new ArrayList<>(); |
| // Stop accepting new work before shutting down the executor. This ensures that thread don't try |
| // to add work to the shutdown executor. |
| try { |
| serialExecutorServices.invalidateAll(); |
| } catch (final RuntimeException re) { |
| errors.add(re); |
| } |
| try { |
| serialExecutorServices.cleanUp(); |
| } catch (final RuntimeException re) { |
| errors.add(re); |
| } |
| try { |
| parallelExecutorService.shutdown(); |
| } catch (final RuntimeException re) { |
| errors.add(re); |
| } |
| try { |
| executorService.shutdown(); |
| } catch (final RuntimeException re) { |
| errors.add(re); |
| } |
| try { |
| metricsExecutor.shutdown(); |
| } catch (final RuntimeException re) { |
| errors.add(re); |
| } |
| try { |
| registry.cleanup(); |
| } catch (final Exception e) { |
| errors.add(e); |
| } |
| pipelineState.compareAndSet(State.RUNNING, newState); // ensure we hit a terminal node |
| if (!errors.isEmpty()) { |
| final IllegalStateException exception = |
| new IllegalStateException( |
| "Error" |
| + (errors.size() == 1 ? "" : "s") |
| + " during executor shutdown:\n" |
| + errors.stream() |
| .map(Exception::getMessage) |
| .collect(Collectors.joining("\n- ", "- ", ""))); |
| visibleUpdates.failed(exception); |
| throw exception; |
| } |
| } |
| |
| /** |
| * An update of interest to the user. Used in {@link #waitUntilFinish} to decide whether to return |
| * normally or throw an exception. |
| */ |
| private static class VisibleExecutorUpdate { |
| private final Optional<? extends Throwable> thrown; |
| @Nullable private final State newState; |
| |
| public static VisibleExecutorUpdate fromException(Exception e) { |
| return new VisibleExecutorUpdate(null, e); |
| } |
| |
| public static VisibleExecutorUpdate fromError(Error err) { |
| return new VisibleExecutorUpdate(State.FAILED, err); |
| } |
| |
| public static VisibleExecutorUpdate finished() { |
| return new VisibleExecutorUpdate(State.DONE, null); |
| } |
| |
| public static VisibleExecutorUpdate cancelled() { |
| return new VisibleExecutorUpdate(State.CANCELLED, null); |
| } |
| |
| private VisibleExecutorUpdate(State newState, @Nullable Throwable exception) { |
| this.thrown = Optional.ofNullable(exception); |
| this.newState = newState; |
| } |
| |
| State getNewState() { |
| return newState; |
| } |
| } |
| |
| private static class QueueMessageReceiver implements PipelineMessageReceiver { |
| // If the type of BlockingQueue changes, ensure the findbugs filter is updated appropriately |
| private final BlockingQueue<VisibleExecutorUpdate> updates = new LinkedBlockingQueue<>(); |
| |
| @Override |
| public void failed(Exception e) { |
| updates.offer(VisibleExecutorUpdate.fromException(e)); |
| } |
| |
| @Override |
| public void failed(Error e) { |
| updates.offer(VisibleExecutorUpdate.fromError(e)); |
| } |
| |
| @Override |
| public void cancelled() { |
| updates.offer(VisibleExecutorUpdate.cancelled()); |
| } |
| |
| @Override |
| public void completed() { |
| updates.offer(VisibleExecutorUpdate.finished()); |
| } |
| |
| /** Try to get the next unconsumed message in this {@link QueueMessageReceiver}. */ |
| @Nullable |
| private VisibleExecutorUpdate tryNext(Duration timeout) throws InterruptedException { |
| return updates.poll(timeout.getMillis(), TimeUnit.MILLISECONDS); |
| } |
| } |
| } |