blob: 5cf303ce191f9e8541be1be090ecf6344d1b7461 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.graph;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import com.google.api.services.dataflow.model.InstructionOutput;
import com.google.api.services.dataflow.model.ParDoInstruction;
import com.google.api.services.dataflow.model.ParallelInstruction;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.SdkComponents;
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.PropertyNames;
import org.apache.beam.runners.dataflow.worker.NameContextsForTests;
import org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge;
import org.apache.beam.runners.dataflow.worker.graph.Edges.Edge;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.FetchAndFilterStreamingSideInputsNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.Node;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Equivalence;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Equivalence.Wrapper;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.ImmutableNetwork;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.MutableNetwork;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.Network;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.NetworkBuilder;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link InsertFetchAndFilterStreamingSideInputNodes}. */
@RunWith(JUnit4.class)
public class InsertFetchAndFilterStreamingSideInputNodesTest {

  @Test
  public void testWithoutPipeline() throws Exception {
    // Without a pipeline proto the rewrite has nothing to consult, so the
    // network must pass through unchanged.
    Node unknown = createParDoNode("parDoId");

    MutableNetwork<Node, Edge> network = createEmptyNetwork();
    network.addNode(unknown);

    Network<Node, Edge> inputNetwork = ImmutableNetwork.copyOf(network);
    network = InsertFetchAndFilterStreamingSideInputNodes.with(null).forNetwork(network);

    assertThatNetworksAreIdentical(inputNetwork, network);
  }

  @Test
  public void testSdkParDoWithSideInput() throws Exception {
    // Build a real pipeline containing a ParDo with a side input so the
    // translated proto carries a ParDoPayload with a non-empty side-input map.
    Pipeline p = Pipeline.create();
    PCollection<String> pc = p.apply(Create.of("a", "b", "c"));
    PCollectionView<Iterable<String>> pcView = pc.apply(View.asIterable());
    pc.apply(ParDo.of(new TestDoFn(pcView)).withSideInputs(pcView));
    RunnerApi.Pipeline pipeline = PipelineTranslation.toProto(p);

    // Input graph: predecessor -> mainInput -> sideInputParDo, where
    // sideInputParDo references the ParDo found in the pipeline proto.
    Node predecessor = createParDoNode("predecessor");
    InstructionOutputNode mainInput =
        InstructionOutputNode.create(new InstructionOutput(), "fakeId");
    Node sideInputParDo = createParDoNode(findParDoWithSideInput(pipeline));

    MutableNetwork<Node, Edge> network = createEmptyNetwork();
    network.addNode(predecessor);
    network.addNode(mainInput);
    network.addNode(sideInputParDo);
    network.addEdge(predecessor, mainInput, DefaultEdge.create());
    network.addEdge(mainInput, sideInputParDo, DefaultEdge.create());

    network = InsertFetchAndFilterStreamingSideInputNodes.with(pipeline).forNetwork(network);

    // Expected rewrite: the main input is split so that a clone feeds a new
    // fetch-and-filter node, whose output becomes the ParDo's main input:
    // predecessor -> clone -> fetchAndFilter -> mainInput -> sideInputParDo.
    Node mainInputClone = InstructionOutputNode.create(mainInput.getInstructionOutput(), "fakeId");
    Node fetchAndFilter =
        FetchAndFilterStreamingSideInputsNode.create(
            pcView.getWindowingStrategyInternal(),
            ImmutableMap.of(
                pcView,
                ParDoTranslation.translateWindowMappingFn(
                    pcView.getWindowMappingFn(),
                    SdkComponents.create(PipelineOptionsFactory.create()))),
            NameContextsForTests.nameContextForTest());

    MutableNetwork<Node, Edge> expectedNetwork = createEmptyNetwork();
    expectedNetwork.addNode(predecessor);
    expectedNetwork.addNode(mainInputClone);
    expectedNetwork.addNode(fetchAndFilter);
    expectedNetwork.addNode(mainInput);
    expectedNetwork.addNode(sideInputParDo);
    expectedNetwork.addEdge(predecessor, mainInputClone, DefaultEdge.create());
    expectedNetwork.addEdge(mainInputClone, fetchAndFilter, DefaultEdge.create());
    expectedNetwork.addEdge(fetchAndFilter, mainInput, DefaultEdge.create());
    expectedNetwork.addEdge(mainInput, sideInputParDo, DefaultEdge.create());

    assertThatNetworksAreIdentical(expectedNetwork, network);
  }

  @Test
  public void testSdkParDoWithoutSideInput() throws Exception {
    // A ParDo with no side inputs must not trigger any graph rewriting.
    Pipeline p = Pipeline.create();
    PCollection<String> pc = p.apply(Create.of("a", "b", "c"));
    pc.apply(ParDo.of(new TestDoFn(null)));
    RunnerApi.Pipeline pipeline = PipelineTranslation.toProto(p);

    Node predecessor = createParDoNode("predecessor");
    Node mainInput = InstructionOutputNode.create(new InstructionOutput(), "fakeId");
    Node sideInputParDo = createParDoNode("noSideInput");

    MutableNetwork<Node, Edge> network = createEmptyNetwork();
    network.addNode(predecessor);
    network.addNode(mainInput);
    network.addNode(sideInputParDo);
    network.addEdge(predecessor, mainInput, DefaultEdge.create());
    network.addEdge(mainInput, sideInputParDo, DefaultEdge.create());

    Network<Node, Edge> inputNetwork = ImmutableNetwork.copyOf(network);
    network = InsertFetchAndFilterStreamingSideInputNodes.with(pipeline).forNetwork(network);

    assertThatNetworksAreIdentical(inputNetwork, network);
  }

  /**
   * Returns the id of the first ParDo PTransform in {@code pipeline} whose payload declares at
   * least one side input.
   *
   * @throws IllegalStateException if a ParDo payload cannot be parsed or no such transform exists
   */
  private String findParDoWithSideInput(RunnerApi.Pipeline pipeline) {
    for (Map.Entry<String, RunnerApi.PTransform> entry :
        pipeline.getComponents().getTransformsMap().entrySet()) {
      if (!PTransformTranslation.PAR_DO_TRANSFORM_URN.equals(entry.getValue().getSpec().getUrn())) {
        continue;
      }
      try {
        ParDoPayload payload = ParDoPayload.parseFrom(entry.getValue().getSpec().getPayload());
        if (!payload.getSideInputsMap().isEmpty()) {
          return entry.getKey();
        }
      } catch (InvalidProtocolBufferException e) {
        // Chain the cause so parse failures remain diagnosable.
        throw new IllegalStateException(String.format("Failed to parse PTransform %s", entry), e);
      }
    }
    throw new IllegalStateException("No side input ptransform found");
  }

  /** A no-op DoFn whose only purpose is to (optionally) carry a side input reference. */
  private static class TestDoFn extends DoFn<String, Iterable<String>> {
    @Nullable private final PCollectionView<Iterable<String>> pCollectionView;

    private TestDoFn(@Nullable PCollectionView<Iterable<String>> pCollectionView) {
      this.pCollectionView = pCollectionView;
    }

    @ProcessElement
    public void processElement(ProcessContext context) {}
  }

  /**
   * Structural equivalence for graph nodes: fetch-and-filter nodes are compared by their single
   * (view tag, window mapping fn) entry, instruction outputs by their payload, and all other nodes
   * by {@link Object#equals} to ensure they were not modified by the rewrite.
   */
  private static final class NodeEquivalence extends Equivalence<Node> {
    static final NodeEquivalence INSTANCE = new NodeEquivalence();

    @Override
    protected boolean doEquivalent(Node a, Node b) {
      if (a instanceof FetchAndFilterStreamingSideInputsNode
          && b instanceof FetchAndFilterStreamingSideInputsNode) {
        FetchAndFilterStreamingSideInputsNode nodeA = (FetchAndFilterStreamingSideInputsNode) a;
        FetchAndFilterStreamingSideInputsNode nodeB = (FetchAndFilterStreamingSideInputsNode) b;
        Map.Entry<PCollectionView<?>, SdkFunctionSpec> nodeAEntry =
            Iterables.getOnlyElement(nodeA.getPCollectionViewsToWindowMappingFns().entrySet());
        Map.Entry<PCollectionView<?>, SdkFunctionSpec> nodeBEntry =
            Iterables.getOnlyElement(nodeB.getPCollectionViewsToWindowMappingFns().entrySet());
        return Objects.equals(
                nodeAEntry.getKey().getTagInternal(), nodeBEntry.getKey().getTagInternal())
            && Objects.equals(nodeAEntry.getValue(), nodeBEntry.getValue());
      } else if (a instanceof InstructionOutputNode && b instanceof InstructionOutputNode) {
        return Objects.equals(
            ((InstructionOutputNode) a).getInstructionOutput(),
            ((InstructionOutputNode) b).getInstructionOutput());
      } else {
        return a.equals(b); // Make sure that other nodes haven't been modified
      }
    }

    @Override
    protected int doHash(Node n) {
      // Equivalence requires equivalent values to hash equally. Nodes compared
      // structurally in doEquivalent must therefore hash structurally too,
      // rather than via their (identity-based) hashCode.
      if (n instanceof InstructionOutputNode) {
        return Objects.hashCode(((InstructionOutputNode) n).getInstructionOutput());
      }
      if (n instanceof FetchAndFilterStreamingSideInputsNode) {
        // All fetch-and-filter nodes share a bucket; doEquivalent disambiguates.
        return FetchAndFilterStreamingSideInputsNode.class.hashCode();
      }
      return n.hashCode();
    }
  }

  /**
   * Asserts that the structure and nodes of two graphs are identical except for the deduced
   * ExecutionLocations, and that all paths through the graph still exist.
   */
  private void assertThatNetworksAreIdentical(
      Network<Node, Edge> oldNetwork, Network<Node, Edge> newNetwork) {
    // Assert that both networks still have same number of nodes and edges.
    assertEquals(oldNetwork.nodes().size(), newNetwork.nodes().size());
    assertEquals(oldNetwork.edges().size(), newNetwork.edges().size());

    // Assert that all paths still exist with identical nodes in each path.
    List<List<Wrapper<Node>>> oldPaths = allPathsWithWrappedNodes(oldNetwork);
    List<List<Equivalence.Wrapper<Node>>> newPaths = allPathsWithWrappedNodes(newNetwork);
    assertThat(oldPaths, containsInAnyOrder(newPaths.toArray()));
  }

  /** Returns every root-to-leaf path with each node wrapped in {@link NodeEquivalence}. */
  private List<List<Equivalence.Wrapper<Node>>> allPathsWithWrappedNodes(
      Network<Node, Edge> network) {
    List<List<Node>> paths = Networks.allPathsFromRootsToLeaves(network);
    List<List<Equivalence.Wrapper<Node>>> wrappedPaths = new ArrayList<>();
    for (List<Node> path : paths) {
      List<Equivalence.Wrapper<Node>> wrappedPath = new ArrayList<>();
      for (Node node : path) {
        wrappedPath.add(NodeEquivalence.INSTANCE.wrap(node));
      }
      wrappedPaths.add(wrappedPath);
    }
    return wrappedPaths;
  }

  /** Creates an empty directed network that permits parallel edges but no self loops. */
  private static MutableNetwork<Node, Edge> createEmptyNetwork() {
    return NetworkBuilder.directed()
        .allowsSelfLoops(false)
        .allowsParallelEdges(true)
        .<Node, Edge>build();
  }

  /**
   * Creates an SDK-harness ParDo node whose serialized fn is {@code parDoId}, letting tests match
   * graph nodes back to pipeline-proto transform ids.
   */
  private static ParallelInstructionNode createParDoNode(String parDoId) {
    CloudObject userFn = CloudObject.forClassName("DoFn");
    userFn.put(PropertyNames.SERIALIZED_FN, parDoId);
    return ParallelInstructionNode.create(
        new ParallelInstruction().setParDo(new ParDoInstruction().setUserFn(userFn)),
        ExecutionLocation.SDK_HARNESS);
  }
}