| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.direct; |
| |
| import static com.google.common.base.Preconditions.checkState; |
| |
| import com.google.auto.value.AutoValue; |
| import com.google.common.cache.CacheBuilder; |
| import com.google.common.cache.CacheLoader; |
| import com.google.common.cache.LoadingCache; |
| import com.google.common.collect.Lists; |
| import java.util.Collections; |
| import java.util.HashMap; |
| import java.util.Map; |
| import org.apache.beam.runners.core.KeyedWorkItem; |
| import org.apache.beam.runners.core.KeyedWorkItems; |
| import org.apache.beam.runners.core.StateNamespace; |
| import org.apache.beam.runners.core.StateNamespaces; |
| import org.apache.beam.runners.core.StateNamespaces.WindowNamespace; |
| import org.apache.beam.runners.core.StateTag; |
| import org.apache.beam.runners.core.StateTags; |
| import org.apache.beam.runners.core.TimerInternals.TimerData; |
| import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext; |
| import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; |
| import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.transforms.AppliedPTransform; |
| import org.apache.beam.sdk.transforms.DoFn; |
| import org.apache.beam.sdk.transforms.ParDo; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; |
| import org.apache.beam.sdk.transforms.windowing.BoundedWindow; |
| import org.apache.beam.sdk.util.WindowedValue; |
| import org.apache.beam.sdk.util.WindowingStrategy; |
| import org.apache.beam.sdk.util.state.StateSpec; |
| import org.apache.beam.sdk.values.KV; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.PCollectionTuple; |
| import org.apache.beam.sdk.values.TaggedPValue; |
| import org.apache.beam.sdk.values.TupleTag; |
| |
| /** A {@link TransformEvaluatorFactory} for stateful {@link ParDo}. */ |
| final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements TransformEvaluatorFactory { |
| |
| private final LoadingCache<AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT>, Runnable> |
| cleanupRegistry; |
| |
| private final ParDoEvaluatorFactory<KV<K, InputT>, OutputT> delegateFactory; |
| |
| StatefulParDoEvaluatorFactory(EvaluationContext evaluationContext) { |
| this.delegateFactory = new ParDoEvaluatorFactory<>(evaluationContext); |
| this.cleanupRegistry = |
| CacheBuilder.newBuilder() |
| .weakValues() |
| .build(new CleanupSchedulingLoader(evaluationContext)); |
| } |
| |
| @Override |
| public <T> TransformEvaluator<T> forApplication( |
| AppliedPTransform<?, ?, ?> application, CommittedBundle<?> inputBundle) throws Exception { |
| @SuppressWarnings({"unchecked", "rawtypes"}) |
| TransformEvaluator<T> evaluator = |
| (TransformEvaluator<T>) |
| createEvaluator((AppliedPTransform) application, (CommittedBundle) inputBundle); |
| return evaluator; |
| } |
| |
| @Override |
| public void cleanup() throws Exception { |
| delegateFactory.cleanup(); |
| } |
| |
| @SuppressWarnings({"unchecked", "rawtypes"}) |
| private TransformEvaluator<KeyedWorkItem<K, KV<K, InputT>>> createEvaluator( |
| AppliedPTransform< |
| PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple, |
| StatefulParDo<K, InputT, OutputT>> |
| application, |
| CommittedBundle<KeyedWorkItem<K, KV<K, InputT>>> inputBundle) |
| throws Exception { |
| |
| final DoFn<KV<K, InputT>, OutputT> doFn = |
| application.getTransform().getUnderlyingParDo().getFn(); |
| final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); |
| |
| // If the DoFn is stateful, schedule state clearing. |
| // It is semantically correct to schedule any number of redundant clear tasks; the |
| // cache is used to limit the number of tasks to avoid performance degradation. |
| if (signature.stateDeclarations().size() > 0) { |
| for (final WindowedValue<?> element : inputBundle.getElements()) { |
| for (final BoundedWindow window : element.getWindows()) { |
| cleanupRegistry.get( |
| AppliedPTransformOutputKeyAndWindow.create( |
| application, (StructuralKey<K>) inputBundle.getKey(), window)); |
| } |
| } |
| } |
| |
| DoFnLifecycleManagerRemovingTransformEvaluator<KV<K, InputT>> delegateEvaluator = |
| delegateFactory.createEvaluator( |
| (AppliedPTransform) application, |
| inputBundle.getKey(), |
| doFn, |
| application.getTransform().getUnderlyingParDo().getSideInputs(), |
| application.getTransform().getUnderlyingParDo().getMainOutputTag(), |
| application.getTransform().getUnderlyingParDo().getSideOutputTags().getAll()); |
| |
| return new StatefulParDoEvaluator<>(delegateEvaluator); |
| } |
| |
| private class CleanupSchedulingLoader |
| extends CacheLoader<AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT>, Runnable> { |
| |
| private final EvaluationContext evaluationContext; |
| |
| public CleanupSchedulingLoader(EvaluationContext evaluationContext) { |
| this.evaluationContext = evaluationContext; |
| } |
| |
| @Override |
| public Runnable load( |
| final AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> transformOutputWindow) { |
| String stepName = evaluationContext.getStepName(transformOutputWindow.getTransform()); |
| |
| Map<TupleTag<?>, PCollection<?>> taggedValues = new HashMap<>(); |
| for (TaggedPValue pv : transformOutputWindow.getTransform().getOutputs()) { |
| taggedValues.put(pv.getTag(), (PCollection<?>) pv.getValue()); |
| } |
| PCollection<?> pc = |
| taggedValues |
| .get( |
| transformOutputWindow |
| .getTransform() |
| .getTransform() |
| .getUnderlyingParDo() |
| .getMainOutputTag()); |
| WindowingStrategy<?, ?> windowingStrategy = pc.getWindowingStrategy(); |
| BoundedWindow window = transformOutputWindow.getWindow(); |
| final DoFn<?, ?> doFn = |
| transformOutputWindow.getTransform().getTransform().getUnderlyingParDo().getFn(); |
| final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); |
| |
| final DirectStepContext stepContext = |
| evaluationContext |
| .getExecutionContext( |
| transformOutputWindow.getTransform(), transformOutputWindow.getKey()) |
| .getOrCreateStepContext(stepName, stepName); |
| |
| final StateNamespace namespace = |
| StateNamespaces.window( |
| (Coder<BoundedWindow>) windowingStrategy.getWindowFn().windowCoder(), window); |
| |
| Runnable cleanup = |
| new Runnable() { |
| @Override |
| public void run() { |
| for (StateDeclaration stateDecl : signature.stateDeclarations().values()) { |
| StateTag<Object, ?> tag; |
| try { |
| tag = |
| StateTags.tagForSpec(stateDecl.id(), (StateSpec) stateDecl.field().get(doFn)); |
| } catch (IllegalAccessException e) { |
| throw new RuntimeException( |
| String.format( |
| "Error accessing %s for %s", |
| StateSpec.class.getName(), doFn.getClass().getName()), |
| e); |
| } |
| stepContext.stateInternals().state(namespace, tag).clear(); |
| } |
| cleanupRegistry.invalidate(transformOutputWindow); |
| } |
| }; |
| |
| evaluationContext.scheduleAfterWindowExpiration( |
| transformOutputWindow.getTransform(), window, windowingStrategy, cleanup); |
| return cleanup; |
| } |
| } |
| |
| @AutoValue |
| abstract static class AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> { |
| abstract AppliedPTransform< |
| PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple, |
| StatefulParDo<K, InputT, OutputT>> |
| getTransform(); |
| |
| abstract StructuralKey<K> getKey(); |
| |
| abstract BoundedWindow getWindow(); |
| |
| static <K, InputT, OutputT> AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> create( |
| AppliedPTransform< |
| PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple, |
| StatefulParDo<K, InputT, OutputT>> |
| transform, |
| StructuralKey<K> key, |
| BoundedWindow w) { |
| return new AutoValue_StatefulParDoEvaluatorFactory_AppliedPTransformOutputKeyAndWindow<>( |
| transform, key, w); |
| } |
| } |
| |
| private static class StatefulParDoEvaluator<K, InputT> |
| implements TransformEvaluator<KeyedWorkItem<K, KV<K, InputT>>> { |
| |
| private final DoFnLifecycleManagerRemovingTransformEvaluator<KV<K, InputT>> delegateEvaluator; |
| |
| public StatefulParDoEvaluator( |
| DoFnLifecycleManagerRemovingTransformEvaluator<KV<K, InputT>> delegateEvaluator) { |
| this.delegateEvaluator = delegateEvaluator; |
| } |
| |
| @Override |
| public void processElement(WindowedValue<KeyedWorkItem<K, KV<K, InputT>>> gbkResult) |
| throws Exception { |
| for (WindowedValue<KV<K, InputT>> windowedValue : gbkResult.getValue().elementsIterable()) { |
| delegateEvaluator.processElement(windowedValue); |
| } |
| |
| for (TimerData timer : gbkResult.getValue().timersIterable()) { |
| checkState( |
| timer.getNamespace() instanceof WindowNamespace, |
| "Expected Timer %s to be in a %s, but got %s", |
| timer, |
| WindowNamespace.class.getSimpleName(), |
| timer.getNamespace().getClass().getName()); |
| WindowNamespace<?> windowNamespace = (WindowNamespace) timer.getNamespace(); |
| BoundedWindow timerWindow = windowNamespace.getWindow(); |
| delegateEvaluator.onTimer(timer, timerWindow); |
| } |
| } |
| |
| @Override |
| public TransformResult<KeyedWorkItem<K, KV<K, InputT>>> finishBundle() throws Exception { |
| TransformResult<KV<K, InputT>> delegateResult = delegateEvaluator.finishBundle(); |
| |
| StepTransformResult.Builder<KeyedWorkItem<K, KV<K, InputT>>> regroupedResult = |
| StepTransformResult.<KeyedWorkItem<K, KV<K, InputT>>>withHold( |
| delegateResult.getTransform(), delegateResult.getWatermarkHold()) |
| .withTimerUpdate(delegateResult.getTimerUpdate()) |
| .withState(delegateResult.getState()) |
| .withAggregatorChanges(delegateResult.getAggregatorChanges()) |
| .withMetricUpdates(delegateResult.getLogicalMetricUpdates()) |
| .addOutput(Lists.newArrayList(delegateResult.getOutputBundles())); |
| |
| // The delegate may have pushed back unprocessed elements across multiple keys and windows. |
| // Since processing is single-threaded per key and window, we don't need to regroup the |
| // outputs, but just make a bunch of singletons |
| for (WindowedValue<?> untypedUnprocessed : delegateResult.getUnprocessedElements()) { |
| WindowedValue<KV<K, InputT>> windowedKv = (WindowedValue<KV<K, InputT>>) untypedUnprocessed; |
| WindowedValue<KeyedWorkItem<K, KV<K, InputT>>> pushedBack = |
| windowedKv.withValue( |
| KeyedWorkItems.elementsWorkItem( |
| windowedKv.getValue().getKey(), Collections.singleton(windowedKv))); |
| |
| regroupedResult.addUnprocessedElements(pushedBack); |
| } |
| |
| return regroupedResult.build(); |
| } |
| } |
| } |