/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.beam.runners.dataflow.worker;

import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;

import com.google.api.services.dataflow.model.InstructionOutput;
import com.google.api.services.dataflow.model.MapTask;
import com.google.api.services.dataflow.model.MultiOutputInfo;
import com.google.api.services.dataflow.model.ParDoInstruction;
import com.google.api.services.dataflow.model.ParallelInstruction;
import com.google.api.services.dataflow.model.PartialGroupByKeyInstruction;
import com.google.api.services.dataflow.model.ReadInstruction;
import com.google.api.services.dataflow.model.Source;
import com.google.api.services.dataflow.model.WriteInstruction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.runners.core.ElementByteSizeObservable;
import org.apache.beam.runners.dataflow.DataflowRunner;
import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions;
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.CloudObjects;
import org.apache.beam.runners.dataflow.worker.counters.CounterFactory;
import org.apache.beam.runners.dataflow.worker.counters.CounterSet;
import org.apache.beam.runners.dataflow.worker.counters.NameContext;
import org.apache.beam.runners.dataflow.worker.graph.Edges.Edge;
import org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge;
import org.apache.beam.runners.dataflow.worker.graph.Networks;
import org.apache.beam.runners.dataflow.worker.graph.Networks.TypeSafeNodeFunction;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.Node;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.OperationNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.OutputReceiverNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode;
import org.apache.beam.runners.dataflow.worker.util.CloudSourceUtils;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter;
import org.apache.beam.runners.dataflow.worker.util.common.worker.FlattenOperation;
import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader;
import org.apache.beam.runners.dataflow.worker.util.common.worker.Operation;
import org.apache.beam.runners.dataflow.worker.util.common.worker.OperationContext;
import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputReceiver;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ParDoFn;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ParDoOperation;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation;
import org.apache.beam.runners.dataflow.worker.util.common.worker.Receiver;
import org.apache.beam.runners.dataflow.worker.util.common.worker.Sink;
import org.apache.beam.runners.dataflow.worker.util.common.worker.WriteOperation;
import org.apache.beam.runners.fnexecution.GrpcFnServer;
import org.apache.beam.runners.fnexecution.control.InstructionRequestHandler;
import org.apache.beam.runners.fnexecution.data.GrpcDataService;
import org.apache.beam.runners.fnexecution.state.GrpcStateService;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.IdGenerator;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v20_0.com.google.common.graph.MutableNetwork;
import org.apache.beam.vendor.guava.v20_0.com.google.common.graph.Network;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
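
/*
 * Typical wiring, as a minimal sketch. The variable names below are illustrative assumptions, not
 * taken from any particular caller; "network" stands for the MutableNetwork<Node, Edge> built
 * elsewhere from the MapTask definition:
 *
 *   DataflowMapTaskExecutor executor =
 *       IntrinsicMapTaskExecutorFactory.defaultFactory()
 *           .create(
 *               instructionRequestHandler,
 *               grpcDataFnServer,
 *               dataApiServiceDescriptor,
 *               grpcStateFnServer,
 *               network,
 *               options,
 *               stageName,
 *               readerFactory,
 *               sinkFactory,
 *               executionContext,
 *               counterSet,
 *               idGenerator);
 */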
/** Creates a {@link DataflowMapTaskExecutor} from a {@link MapTask} definition. */
public class IntrinsicMapTaskExecutorFactory implements DataflowMapTaskExecutorFactory {
private static final Logger LOG = LoggerFactory.getLogger(IntrinsicMapTaskExecutorFactory.class);
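/** Returns a new default {@link IntrinsicMapTaskExecutorFactory}. */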
public static IntrinsicMapTaskExecutorFactory defaultFactory() {
return new IntrinsicMapTaskExecutorFactory();
}
private IntrinsicMapTaskExecutorFactory() {}
/**
* Creates a new {@link DataflowMapTaskExecutor} from the network of {@link Node}s derived from a
* {@link MapTask} definition, using the provided {@link ReaderFactory} and {@link SinkFactory} to
* construct readers and sinks.
*/
@Override
public DataflowMapTaskExecutor create(
InstructionRequestHandler instructionRequestHandler,
GrpcFnServer<GrpcDataService> grpcDataFnServer,
Endpoints.ApiServiceDescriptor dataApiServiceDescriptor,
GrpcFnServer<GrpcStateService> grpcStateFnServer,
MutableNetwork<Node, Edge> network,
PipelineOptions options,
String stageName,
ReaderFactory readerFactory,
SinkFactory sinkFactory,
DataflowExecutionContext<?> executionContext,
CounterSet counterSet,
IdGenerator idGenerator) {
// TODO: remove this once we trust the code paths
checkArgument(
!DataflowRunner.hasExperiment(
options.as(DataflowPipelineDebugOptions.class), "beam_fn_api"),
"experiment beam_fn_api turned on but non-Fn API MapTaskExecutorFactory invoked");
// Swap out all the InstructionOutput nodes with OutputReceiver nodes
Networks.replaceDirectedNetworkNodes(
network, createOutputReceiversTransform(stageName, counterSet));
// Swap out all the ParallelInstruction nodes with Operation nodes
Networks.replaceDirectedNetworkNodes(
network,
createOperationTransformForParallelInstructionNodes(
stageName, network, options, readerFactory, sinkFactory, executionContext));
// Collect all the operations within the network and attach all the operations as receivers
// to preceding output receivers.
List<Operation> topoSortedOperations = new ArrayList<>();
for (OperationNode node :
Iterables.filter(Networks.topologicalOrder(network), OperationNode.class)) {
topoSortedOperations.add(node.getOperation());
for (Node predecessor :
Iterables.filter(network.predecessors(node), OutputReceiverNode.class)) {
((OutputReceiverNode) predecessor)
.getOutputReceiver()
.addOutput((Receiver) node.getOperation());
}
}
if (LOG.isDebugEnabled()) {
LOG.debug("Map task network: {}", Networks.toDot(network));
}
return IntrinsicMapTaskExecutor.withSharedCounterSet(
topoSortedOperations, counterSet, executionContext.getExecutionStateTracker());
}
/**
* Returns a function that replaces each {@link ParallelInstructionNode} with an
* {@link OperationNode}, creating the appropriate {@link Operation} for the node's
* {@link ParallelInstruction} using the provided {@link ReaderFactory} and {@link SinkFactory}.
*/
Function<Node, Node> createOperationTransformForParallelInstructionNodes(
final String stageName,
final Network<Node, Edge> network,
final PipelineOptions options,
final ReaderFactory readerFactory,
final SinkFactory sinkFactory,
final DataflowExecutionContext<?> executionContext) {
return new TypeSafeNodeFunction<ParallelInstructionNode>(ParallelInstructionNode.class) {
@Override
public Node typedApply(ParallelInstructionNode node) {
ParallelInstruction instruction = node.getParallelInstruction();
NameContext nameContext =
NameContext.create(
stageName,
instruction.getOriginalName(),
instruction.getSystemName(),
instruction.getName());
try {
DataflowOperationContext context = executionContext.createOperationContext(nameContext);
if (instruction.getRead() != null) {
return createReadOperation(
network, node, options, readerFactory, executionContext, context);
} else if (instruction.getWrite() != null) {
return createWriteOperation(node, options, sinkFactory, executionContext, context);
} else if (instruction.getParDo() != null) {
return createParDoOperation(network, node, options, executionContext, context);
} else if (instruction.getPartialGroupByKey() != null) {
return createPartialGroupByKeyOperation(
network, node, options, executionContext, context);
} else if (instruction.getFlatten() != null) {
return createFlattenOperation(network, node, context);
} else {
throw new IllegalArgumentException(
String.format("Unexpected instruction: %s", instruction));
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
};
}
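/**
* Creates an {@link OperationNode} wrapping a {@link ReadOperation} for the given
* {@link ReadInstruction}, using the provided {@link ReaderFactory} to construct the
* {@link NativeReader}.
*/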
OperationNode createReadOperation(
Network<Node, Edge> network,
ParallelInstructionNode node,
PipelineOptions options,
ReaderFactory readerFactory,
DataflowExecutionContext<?> executionContext,
DataflowOperationContext operationContext)
throws Exception {
ParallelInstruction instruction = node.getParallelInstruction();
ReadInstruction read = instruction.getRead();
Source cloudSource = CloudSourceUtils.flattenBaseSpecs(read.getSource());
CloudObject sourceSpec = CloudObject.fromSpec(cloudSource.getSpec());
Coder<?> coder =
CloudObjects.coderFromCloudObject(CloudObject.fromSpec(cloudSource.getCodec()));
NativeReader<?> reader =
readerFactory.create(sourceSpec, coder, options, executionContext, operationContext);
OutputReceiver[] receivers = getOutputReceivers(network, node);
return OperationNode.create(ReadOperation.create(reader, receivers, operationContext));
}
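/**
* Creates an {@link OperationNode} wrapping a {@link WriteOperation} for the given
* {@link WriteInstruction}, using the provided {@link SinkFactory} to construct the {@link Sink}.
*/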
OperationNode createWriteOperation(
ParallelInstructionNode node,
PipelineOptions options,
SinkFactory sinkFactory,
DataflowExecutionContext executionContext,
DataflowOperationContext context)
throws Exception {
ParallelInstruction instruction = node.getParallelInstruction();
WriteInstruction write = instruction.getWrite();
Coder<?> coder =
CloudObjects.coderFromCloudObject(CloudObject.fromSpec(write.getSink().getCodec()));
CloudObject cloudSink = CloudObject.fromSpec(write.getSink().getSpec());
Sink<?> sink = sinkFactory.create(cloudSink, coder, options, executionContext, context);
return OperationNode.create(WriteOperation.create(sink, EMPTY_OUTPUT_RECEIVER_ARRAY, context));
}
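// Factory used to instantiate the ParDoFn described by a ParDoInstruction's user fn spec.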
private ParDoFnFactory parDoFnFactory = new DefaultParDoFnFactory();
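/**
* Creates an {@link OperationNode} wrapping a {@link ParDoOperation} for the given
* {@link ParDoInstruction}, mapping each output {@link TupleTag} to the index of the successor
* that receives it.
*/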
private OperationNode createParDoOperation(
Network<Node, Edge> network,
ParallelInstructionNode node,
PipelineOptions options,
DataflowExecutionContext<?> executionContext,
DataflowOperationContext operationContext)
throws Exception {
ParallelInstruction instruction = node.getParallelInstruction();
ParDoInstruction parDo = instruction.getParDo();
TupleTag<?> mainOutputTag = tupleTag(parDo.getMultiOutputInfos().get(0));
ImmutableMap.Builder<TupleTag<?>, Integer> outputTagsToReceiverIndicesBuilder =
ImmutableMap.builder();
int successorOffset = 0;
for (Node successor : network.successors(node)) {
for (Edge edge : network.edgesConnecting(node, successor)) {
outputTagsToReceiverIndicesBuilder.put(
tupleTag(((MultiOutputInfoEdge) edge).getMultiOutputInfo()), successorOffset);
}
successorOffset += 1;
}
ParDoFn fn =
parDoFnFactory.create(
options,
CloudObject.fromSpec(parDo.getUserFn()),
parDo.getSideInputs(),
mainOutputTag,
outputTagsToReceiverIndicesBuilder.build(),
executionContext,
operationContext);
OutputReceiver[] receivers = getOutputReceivers(network, node);
return OperationNode.create(new ParDoOperation(fn, receivers, operationContext));
}
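/** Wraps the tag of a {@link MultiOutputInfo} in a {@link TupleTag}. */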
private static <V> TupleTag<V> tupleTag(MultiOutputInfo multiOutputInfo) {
return new TupleTag<>(multiOutputInfo.getTag());
}
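/**
* Creates an {@link OperationNode} wrapping a partial group-by-key {@link ParDoOperation} for the
* given {@link PartialGroupByKeyInstruction}, validating that the input coder is a
* {@link WindowedValueCoder} of a {@link KvCoder}.
*/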
<K> OperationNode createPartialGroupByKeyOperation(
Network<Node, Edge> network,
ParallelInstructionNode node,
PipelineOptions options,
DataflowExecutionContext<?> executionContext,
DataflowOperationContext operationContext)
throws Exception {
ParallelInstruction instruction = node.getParallelInstruction();
PartialGroupByKeyInstruction pgbk = instruction.getPartialGroupByKey();
OutputReceiver[] receivers = getOutputReceivers(network, node);
Coder<?> windowedCoder =
CloudObjects.coderFromCloudObject(CloudObject.fromSpec(pgbk.getInputElementCodec()));
if (!(windowedCoder instanceof WindowedValueCoder)) {
throw new IllegalArgumentException(
String.format(
"unexpected kind of input coder for PartialGroupByKeyOperation: %s", windowedCoder));
}
Coder<?> elemCoder = ((WindowedValueCoder<?>) windowedCoder).getValueCoder();
if (!(elemCoder instanceof KvCoder)) {
throw new IllegalArgumentException(
String.format(
"unexpected kind of input element coder for PartialGroupByKeyOperation: %s",
elemCoder));
}
@SuppressWarnings("unchecked")
KvCoder<K, ?> keyedElementCoder = (KvCoder<K, ?>) elemCoder;
CloudObject cloudUserFn =
pgbk.getValueCombiningFn() != null
? CloudObject.fromSpec(pgbk.getValueCombiningFn())
: null;
ParDoFn fn =
PartialGroupByKeyParDoFns.create(
options,
keyedElementCoder,
cloudUserFn,
pgbk.getSideInputs(),
Arrays.<Receiver>asList(receivers),
executionContext,
operationContext);
return OperationNode.create(new ParDoOperation(fn, receivers, operationContext));
}
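/**
* Creates an {@link OperationNode} wrapping a {@link FlattenOperation} that forwards its input to
* the node's output receivers.
*/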
OperationNode createFlattenOperation(
Network<Node, Edge> network, ParallelInstructionNode node, OperationContext context) {
OutputReceiver[] receivers = getOutputReceivers(network, node);
return OperationNode.create(new FlattenOperation(receivers, context));
}
/**
* Returns a function that replaces each {@link InstructionOutputNode} with an
* {@link OutputReceiverNode}, attaching a {@link DataflowOutputCounter} to the new
* {@link OutputReceiver}.
*/
static Function<Node, Node> createOutputReceiversTransform(
final String stageName, final CounterFactory counterFactory) {
return new TypeSafeNodeFunction<InstructionOutputNode>(InstructionOutputNode.class) {
@Override
public Node typedApply(InstructionOutputNode input) {
InstructionOutput cloudOutput = input.getInstructionOutput();
OutputReceiver outputReceiver = new OutputReceiver();
Coder<?> coder =
CloudObjects.coderFromCloudObject(CloudObject.fromSpec(cloudOutput.getCodec()));
@SuppressWarnings("unchecked")
ElementCounter outputCounter =
new DataflowOutputCounter(
cloudOutput.getName(),
new ElementByteSizeObservableCoder<>(coder),
counterFactory,
NameContext.create(
stageName,
cloudOutput.getOriginalName(),
cloudOutput.getSystemName(),
cloudOutput.getName()));
outputReceiver.addOutputCounter(outputCounter);
return OutputReceiverNode.create(outputReceiver, coder, input.getPcollectionId());
}
};
}
private static final OutputReceiver[] EMPTY_OUTPUT_RECEIVER_ARRAY = new OutputReceiver[0];
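/**
* Returns the {@link OutputReceiver}s of the given node's successors, in successor iteration
* order.
*/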
static OutputReceiver[] getOutputReceivers(Network<Node, Edge> network, Node node) {
int outDegree = network.outDegree(node);
if (outDegree == 0) {
return EMPTY_OUTPUT_RECEIVER_ARRAY;
}
OutputReceiver[] receivers = new OutputReceiver[outDegree];
Iterator<Node> receiverNodes = network.successors(node).iterator();
int i = 0;
do {
receivers[i] = ((OutputReceiverNode) receiverNodes.next()).getOutputReceiver();
i += 1;
} while (receiverNodes.hasNext());
return receivers;
}
/** Adapts a {@link Coder} to the {@link ElementByteSizeObservable} interface. */
public static class ElementByteSizeObservableCoder<T> implements ElementByteSizeObservable<T> {
final Coder<T> coder;
public ElementByteSizeObservableCoder(Coder<T> coder) {
this.coder = coder;
}
@Override
public boolean isRegisterByteSizeObserverCheap(T value) {
return coder.isRegisterByteSizeObserverCheap(value);
}
@Override
public void registerByteSizeObserver(T value, ElementByteSizeObserver observer)
throws Exception {
coder.registerByteSizeObserver(value, observer);
}
}
}