| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.fn.harness; |
| |
| import com.google.auto.service.AutoService; |
| import java.io.IOException; |
| import java.util.Map; |
| import java.util.function.Supplier; |
| import org.apache.beam.fn.harness.control.BundleSplitListener; |
| import org.apache.beam.fn.harness.data.BeamFnDataClient; |
| import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry; |
| import org.apache.beam.fn.harness.data.PTransformFunctionRegistry; |
| import org.apache.beam.fn.harness.state.BeamFnStateClient; |
| import org.apache.beam.model.pipeline.v1.RunnerApi; |
| import org.apache.beam.model.pipeline.v1.RunnerApi.CombinePayload; |
| import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; |
| import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; |
| import org.apache.beam.model.pipeline.v1.RunnerApi.StandardPTransforms; |
| import org.apache.beam.runners.core.construction.BeamUrns; |
| import org.apache.beam.runners.core.construction.RehydratedComponents; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.coders.KvCoder; |
| import org.apache.beam.sdk.fn.data.FnDataReceiver; |
| import org.apache.beam.sdk.function.ThrowingFunction; |
| import org.apache.beam.sdk.options.PipelineOptions; |
| import org.apache.beam.sdk.transforms.Combine.CombineFn; |
| import org.apache.beam.sdk.util.SerializableUtils; |
| import org.apache.beam.sdk.util.WindowedValue; |
| import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; |
| import org.apache.beam.sdk.values.KV; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.annotations.VisibleForTesting; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables; |
| |
| /** Executes different components of Combine PTransforms. */ |
| public class CombineRunners { |
| |
| /** A registrar which provides a factory to handle combine component PTransforms. */ |
| @AutoService(PTransformRunnerFactory.Registrar.class) |
| public static class Registrar implements PTransformRunnerFactory.Registrar { |
| |
| @Override |
| public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() { |
| return ImmutableMap.of( |
| BeamUrns.getUrn(StandardPTransforms.CombineComponents.COMBINE_PER_KEY_PRECOMBINE), |
| new PrecombineFactory(), |
| BeamUrns.getUrn(StandardPTransforms.CombineComponents.COMBINE_PER_KEY_MERGE_ACCUMULATORS), |
| MapFnRunners.forValueMapFnFactory(CombineRunners::createMergeAccumulatorsMapFunction), |
| BeamUrns.getUrn(StandardPTransforms.CombineComponents.COMBINE_PER_KEY_EXTRACT_OUTPUTS), |
| MapFnRunners.forValueMapFnFactory(CombineRunners::createExtractOutputsMapFunction), |
| BeamUrns.getUrn(StandardPTransforms.CombineComponents.COMBINE_GROUPED_VALUES), |
| MapFnRunners.forValueMapFnFactory(CombineRunners::createCombineGroupedValuesMapFunction)); |
| } |
| } |
| |
| private static class PrecombineRunner<KeyT, InputT, AccumT> { |
| private PipelineOptions options; |
| private CombineFn<InputT, AccumT, ?> combineFn; |
| private FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> output; |
| private Coder<KeyT> keyCoder; |
| private GroupingTable<WindowedValue<KeyT>, InputT, AccumT> groupingTable; |
| private Coder<AccumT> accumCoder; |
| |
| PrecombineRunner( |
| PipelineOptions options, |
| CombineFn<InputT, AccumT, ?> combineFn, |
| FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> output, |
| Coder<KeyT> keyCoder, |
| Coder<AccumT> accumCoder) { |
| this.options = options; |
| this.combineFn = combineFn; |
| this.output = output; |
| this.keyCoder = keyCoder; |
| this.accumCoder = accumCoder; |
| } |
| |
| void startBundle() { |
| groupingTable = |
| PrecombineGroupingTable.combiningAndSampling( |
| options, combineFn, keyCoder, accumCoder, 0.001 /*sizeEstimatorSampleRate*/); |
| } |
| |
| void processElement(WindowedValue<KV<KeyT, InputT>> elem) throws Exception { |
| groupingTable.put( |
| elem, (Object outputElem) -> output.accept((WindowedValue<KV<KeyT, AccumT>>) outputElem)); |
| } |
| |
| void finishBundle() throws Exception { |
| groupingTable.flush( |
| (Object outputElem) -> output.accept((WindowedValue<KV<KeyT, AccumT>>) outputElem)); |
| } |
| } |
| |
| /** A factory for {@link PrecombineRunner}s. */ |
| @VisibleForTesting |
| public static class PrecombineFactory<KeyT, InputT, AccumT> |
| implements PTransformRunnerFactory<PrecombineRunner<KeyT, InputT, AccumT>> { |
| |
| @Override |
| public PrecombineRunner<KeyT, InputT, AccumT> createRunnerForPTransform( |
| PipelineOptions pipelineOptions, |
| BeamFnDataClient beamFnDataClient, |
| BeamFnStateClient beamFnStateClient, |
| String pTransformId, |
| PTransform pTransform, |
| Supplier<String> processBundleInstructionId, |
| Map<String, PCollection> pCollections, |
| Map<String, RunnerApi.Coder> coders, |
| Map<String, RunnerApi.WindowingStrategy> windowingStrategies, |
| PCollectionConsumerRegistry pCollectionConsumerRegistry, |
| PTransformFunctionRegistry startFunctionRegistry, |
| PTransformFunctionRegistry finishFunctionRegistry, |
| BundleSplitListener splitListener) |
| throws IOException { |
| // Get objects needed to create the runner. |
| RehydratedComponents rehydratedComponents = |
| RehydratedComponents.forComponents( |
| RunnerApi.Components.newBuilder() |
| .putAllCoders(coders) |
| .putAllWindowingStrategies(windowingStrategies) |
| .build()); |
| String mainInputTag = Iterables.getOnlyElement(pTransform.getInputsMap().keySet()); |
| RunnerApi.PCollection mainInput = pCollections.get(pTransform.getInputsOrThrow(mainInputTag)); |
| |
| // Input coder may sometimes be WindowedValueCoder depending on runner, instead of the |
| // expected KvCoder. |
| Coder<?> uncastInputCoder = rehydratedComponents.getCoder(mainInput.getCoderId()); |
| KvCoder<KeyT, InputT> inputCoder; |
| if (uncastInputCoder instanceof WindowedValueCoder) { |
| inputCoder = |
| (KvCoder<KeyT, InputT>) |
| ((WindowedValueCoder<KV<KeyT, InputT>>) uncastInputCoder).getValueCoder(); |
| } else { |
| inputCoder = (KvCoder<KeyT, InputT>) rehydratedComponents.getCoder(mainInput.getCoderId()); |
| } |
| Coder<KeyT> keyCoder = inputCoder.getKeyCoder(); |
| |
| CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload()); |
| CombineFn<InputT, AccumT, ?> combineFn = |
| (CombineFn) |
| SerializableUtils.deserializeFromByteArray( |
| combinePayload.getCombineFn().getSpec().getPayload().toByteArray(), "CombineFn"); |
| Coder<AccumT> accumCoder = |
| (Coder<AccumT>) rehydratedComponents.getCoder(combinePayload.getAccumulatorCoderId()); |
| |
| FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> consumer = |
| (FnDataReceiver) |
| pCollectionConsumerRegistry.getMultiplexingConsumer( |
| Iterables.getOnlyElement(pTransform.getOutputsMap().values())); |
| |
| PrecombineRunner<KeyT, InputT, AccumT> runner = |
| new PrecombineRunner<>(pipelineOptions, combineFn, consumer, keyCoder, accumCoder); |
| |
| // Register the appropriate handlers. |
| startFunctionRegistry.register(pTransformId, runner::startBundle); |
| pCollectionConsumerRegistry.register( |
| Iterables.getOnlyElement(pTransform.getInputsMap().values()), |
| pTransformId, |
| (FnDataReceiver) |
| (FnDataReceiver<WindowedValue<KV<KeyT, InputT>>>) runner::processElement); |
| finishFunctionRegistry.register(pTransformId, runner::finishBundle); |
| |
| return runner; |
| } |
| } |
| |
| static <KeyT, AccumT> |
| ThrowingFunction<KV<KeyT, Iterable<AccumT>>, KV<KeyT, AccumT>> |
| createMergeAccumulatorsMapFunction(String pTransformId, PTransform pTransform) |
| throws IOException { |
| CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload()); |
| CombineFn<?, AccumT, ?> combineFn = |
| (CombineFn) |
| SerializableUtils.deserializeFromByteArray( |
| combinePayload.getCombineFn().getSpec().getPayload().toByteArray(), "CombineFn"); |
| |
| return (KV<KeyT, Iterable<AccumT>> input) -> |
| KV.of(input.getKey(), combineFn.mergeAccumulators(input.getValue())); |
| } |
| |
| static <KeyT, AccumT, OutputT> |
| ThrowingFunction<KV<KeyT, AccumT>, KV<KeyT, OutputT>> createExtractOutputsMapFunction( |
| String pTransformId, PTransform pTransform) throws IOException { |
| CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload()); |
| CombineFn<?, AccumT, OutputT> combineFn = |
| (CombineFn) |
| SerializableUtils.deserializeFromByteArray( |
| combinePayload.getCombineFn().getSpec().getPayload().toByteArray(), "CombineFn"); |
| |
| return (KV<KeyT, AccumT> input) -> |
| KV.of(input.getKey(), combineFn.extractOutput(input.getValue())); |
| } |
| |
| static <KeyT, InputT, AccumT, OutputT> |
| ThrowingFunction<KV<KeyT, Iterable<InputT>>, KV<KeyT, OutputT>> |
| createCombineGroupedValuesMapFunction(String pTransformId, PTransform pTransform) |
| throws IOException { |
| CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload()); |
| CombineFn<InputT, AccumT, OutputT> combineFn = |
| (CombineFn) |
| SerializableUtils.deserializeFromByteArray( |
| combinePayload.getCombineFn().getSpec().getPayload().toByteArray(), "CombineFn"); |
| |
| return (KV<KeyT, Iterable<InputT>> input) -> { |
| return KV.of(input.getKey(), combineFn.apply(input.getValue())); |
| }; |
| } |
| } |