blob: acdd034985e0036f59addd29398ea7694443fc72 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.graph;
import com.google.api.services.dataflow.model.ParallelInstruction;
import java.util.Set;
import java.util.function.Function;
import org.apache.beam.runners.dataflow.worker.graph.Edges.Edge;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.Node;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v20_0.com.google.common.graph.MutableNetwork;
/**
* A function which optimises the execution of {@link FlattenInstruction}s with ambiguous {@link
* ExecutionLocation}s by splitting each of them into two copies, with one executing on the SDK
* harness and one on the runner harness.
*
* <p>After performing this function, each of the new flattens will retain the predecessors of the
* original flatten, but only the successors that occur in the same {@link ExecutionLocation}. For
* example, the following graph:
*
* <pre><code>
* SdkPredecessor -----> out --\                          /--> SdkSuccessor
*                               AmbiguousFlatten --> out
* RunnerPredecessor --> out --/                          \--> RunnerSuccessor
* </code></pre>
*
* Should produce:
*
* <pre><code>
* SdkPredecessor -----> out --> SdkFlatten --> out --> SdkSuccessor
*                            X
* RunnerPredecessor --> out --> RunnerFlatten --> out --> RunnerSuccessor
* </code></pre>
*
* <p>The reason for performing this cloning is to prevent data from having to perform a "round
* trip" through gRPC ports, which is what would happen if ambiguous flattens were executed on only
* one harness. For example, if a flatten is executed in the runner harness, then the path from SDK
* predecessor to SDK successor will require data to cross to the runner harness, get flattened, and
* then cross back. With this optimization, the round trip will not occur on any paths.
*/
public class CloneAmbiguousFlattensFunction
    implements Function<MutableNetwork<Node, Edge>, MutableNetwork<Node, Edge>> {

  /**
   * Replaces every flatten instruction whose {@link ExecutionLocation} is {@code AMBIGUOUS} with
   * one copy located on the runner harness and one copy located on the SDK harness.
   *
   * @param network the network to rewrite; it is modified in place
   * @return the same {@code network} instance that was passed in
   */
  @Override
  public MutableNetwork<Node, Edge> apply(MutableNetwork<Node, Edge> network) {
    // Important: cloning is only valid for a flatten that has no ambiguous descendants, so the
    // nodes are visited from the end of the topological order back to its start. The order is
    // snapshotted up front because cloneFlatten adds and removes nodes as we go.
    Set<Node> topologicalOrder = Networks.topologicalOrder(network);
    for (Node candidate : ImmutableList.copyOf(topologicalOrder).reverse()) {
      if (isAmbiguousFlatten(candidate)) {
        cloneFlatten(candidate, network);
      }
    }
    return network;
  }

  /**
   * Returns true iff the given node is a {@link ParallelInstructionNode} that wraps a flatten
   * instruction and whose execution location is {@link ExecutionLocation#AMBIGUOUS}.
   */
  private static boolean isAmbiguousFlatten(Node node) {
    if (!(node instanceof ParallelInstructionNode)) {
      return false;
    }
    ParallelInstructionNode instructionNode = (ParallelInstructionNode) node;
    return instructionNode.getParallelInstruction().getFlatten() != null
        && instructionNode.getExecutionLocation() == ExecutionLocation.AMBIGUOUS;
  }

  /**
   * Performs the cloning of a single ambiguous flatten: builds runner and SDK copies of the
   * flatten and of its single output node, re-creates the original edges on the copies, and
   * finally removes the original flatten and its output from the network.
   *
   * <p>Both copies receive every predecessor edge of the original flatten. Each successor of the
   * original output is attached to exactly one of the copied outputs: the SDK one when the
   * successor is known to execute in the SDK harness, the runner one otherwise.
   *
   * @param flatten the ambiguous flatten node; it must have exactly one successor
   * @param network the network containing {@code flatten}; it is modified in place
   */
  private void cloneFlatten(Node flatten, MutableNetwork<Node, Edge> network) {
    // The flatten is expected to have exactly one output; getOnlyElement throws otherwise.
    InstructionOutputNode originalOutput =
        (InstructionOutputNode) Iterables.getOnlyElement(network.successors(flatten));
    ParallelInstruction instruction =
        ((ParallelInstructionNode) flatten).getParallelInstruction();

    // Build the runner-side and SDK-side copies of the flatten and of its output.
    Node runnerCopy = ParallelInstructionNode.create(instruction, ExecutionLocation.RUNNER_HARNESS);
    Node runnerOutput =
        InstructionOutputNode.create(
            originalOutput.getInstructionOutput(), originalOutput.getPcollectionId());
    Node sdkCopy = ParallelInstructionNode.create(instruction, ExecutionLocation.SDK_HARNESS);
    Node sdkOutput =
        InstructionOutputNode.create(
            originalOutput.getInstructionOutput(), originalOutput.getPcollectionId());

    network.addNode(runnerCopy);
    network.addNode(runnerOutput);
    network.addNode(sdkCopy);
    network.addNode(sdkOutput);

    // Re-create the flatten -> output edges on each copy.
    for (Edge original : ImmutableList.copyOf(network.edgesConnecting(flatten, originalOutput))) {
      network.addEdge(runnerCopy, runnerOutput, original.clone());
      network.addEdge(sdkCopy, sdkOutput, original.clone());
    }

    // Every predecessor feeds both copies.
    for (Node upstream : network.predecessors(flatten)) {
      for (Edge original : ImmutableList.copyOf(network.edgesConnecting(upstream, flatten))) {
        network.addEdge(upstream, runnerCopy, original.clone());
        network.addEdge(upstream, sdkCopy, original.clone());
      }
    }

    // Each successor is fed by exactly one of the copied outputs.
    for (Node downstream : network.successors(originalOutput)) {
      // Only route through the SDK copy when the successor definitely executes in the SDK.
      Node chosenOutput;
      if (executesInSdkHarness(downstream)) {
        chosenOutput = sdkOutput;
      } else {
        chosenOutput = runnerOutput;
      }
      for (Edge original :
          ImmutableList.copyOf(network.edgesConnecting(originalOutput, downstream))) {
        network.addEdge(chosenOutput, downstream, original.clone());
      }
    }

    // The originals are now fully replaced by their copies.
    network.removeNode(flatten);
    network.removeNode(originalOutput);
  }

  /**
   * Returns true iff the given node is a {@link ParallelInstructionNode} whose execution location
   * is {@link ExecutionLocation#SDK_HARNESS}. For details on how node locations are deduced refer
   * to {@link DeduceNodeLocationsFunction#executesInSdkHarness}.
   */
  private static boolean executesInSdkHarness(Node node) {
    if (!(node instanceof ParallelInstructionNode)) {
      return false;
    }
    ExecutionLocation location = ((ParallelInstructionNode) node).getExecutionLocation();
    return location == ExecutionLocation.SDK_HARNESS;
  }
}