blob: 09d2205eaed115c7b4029ee03ce230ecc94ee5fd [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.graph;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.collection.IsEmptyCollection.emptyCollectionOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import com.google.api.services.dataflow.model.FlattenInstruction;
import com.google.api.services.dataflow.model.InstructionInput;
import com.google.api.services.dataflow.model.InstructionOutput;
import com.google.api.services.dataflow.model.MapTask;
import com.google.api.services.dataflow.model.MultiOutputInfo;
import com.google.api.services.dataflow.model.ParDoInstruction;
import com.google.api.services.dataflow.model.ParallelInstruction;
import com.google.api.services.dataflow.model.PartialGroupByKeyInstruction;
import com.google.api.services.dataflow.model.ReadInstruction;
import com.google.api.services.dataflow.model.WriteInstruction;
import java.util.Arrays;
import org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge;
import org.apache.beam.runners.dataflow.worker.graph.Edges.Edge;
import org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.Node;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode;
import org.apache.beam.sdk.extensions.gcp.util.Transport;
import org.apache.beam.sdk.fn.IdGenerators;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.Network;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link MapTaskToNetworkFunction}. */
@RunWith(JUnit4.class)
public class MapTaskToNetworkFunctionTest {
/** An empty {@link MapTask} must translate into a correctly configured network with no nodes. */
@Test
public void testEmptyMapTask() {
  MapTaskToNetworkFunction converter =
      new MapTaskToNetworkFunction(IdGenerators.decrementingLongs());
  Network<Node, Edge> emptyNetwork = converter.apply(new MapTask());

  // The network configuration is checked even though there is nothing in it.
  assertTrue(emptyNetwork.isDirected());
  assertTrue(emptyNetwork.allowsParallelEdges());
  assertFalse(emptyNetwork.allowsSelfLoops());

  // No instructions were supplied, so no nodes may have been created.
  assertThat(emptyNetwork.nodes(), emptyCollectionOf(Node.class));
}
/** A single Read instruction yields one instruction node wired to one output node. */
@Test
public void testRead() {
  InstructionOutput expectedOutput = createInstructionOutput("Read.out");
  ParallelInstruction readInstruction = createParallelInstruction("Read", expectedOutput);
  readInstruction.setRead(new ReadInstruction());

  MapTask task = new MapTask();
  task.setInstructions(ImmutableList.of(readInstruction));
  task.setFactory(Transport.getJsonFactory());

  MapTaskToNetworkFunction converter =
      new MapTaskToNetworkFunction(IdGenerators.decrementingLongs());
  Network<Node, Edge> graph = converter.apply(task);

  assertNetworkProperties(graph);
  // Nodes: Read, Read.out. Edges: Read -> Read.out.
  assertEquals(2, graph.nodes().size());
  assertEquals(1, graph.edges().size());

  ParallelInstructionNode readNode = get(graph, readInstruction);
  InstructionOutputNode outputNode = getOnlySuccessor(graph, readNode);
  assertEquals(expectedOutput, outputNode.getInstructionOutput());
}
/**
 * Read followed by a single-output ParDo: checks the chain of nodes and that the edge leaving
 * the ParDo carries the ParDo's {@link MultiOutputInfo}.
 */
@Test
public void testParDo() {
  // Read --> Read.out --> ParDo --> ParDo.out
  InstructionOutput readOutput = createInstructionOutput("Read.out");
  ParallelInstruction read = createParallelInstruction("Read", readOutput);
  read.setRead(new ReadInstruction());

  MultiOutputInfo expectedOutputInfo = createMultiOutputInfo("output");
  ParDoInstruction parDoSpec = new ParDoInstruction();
  parDoSpec.setInput(createInstructionInput(0, 0)); // Read.out
  parDoSpec.setMultiOutputInfos(ImmutableList.of(expectedOutputInfo));

  InstructionOutput parDoOutput = createInstructionOutput("ParDo.out");
  ParallelInstruction parDo = createParallelInstruction("ParDo", parDoOutput);
  parDo.setParDo(parDoSpec);

  MapTask task = new MapTask();
  task.setInstructions(ImmutableList.of(read, parDo));
  task.setFactory(Transport.getJsonFactory());

  MapTaskToNetworkFunction converter =
      new MapTaskToNetworkFunction(IdGenerators.decrementingLongs());
  Network<Node, Edge> graph = converter.apply(task);

  assertNetworkProperties(graph);
  assertEquals(4, graph.nodes().size());
  assertEquals(3, graph.edges().size());

  // Walk the chain from the Read instruction to the ParDo output.
  ParallelInstructionNode readNode = get(graph, read);
  InstructionOutputNode readOutputNode = getOnlySuccessor(graph, readNode);
  assertEquals(readOutput, readOutputNode.getInstructionOutput());

  ParallelInstructionNode parDoNode = getOnlySuccessor(graph, readOutputNode);
  InstructionOutputNode parDoOutputNode = getOnlySuccessor(graph, parDoNode);
  assertEquals(parDoOutput, parDoOutputNode.getInstructionOutput());

  // The single edge ParDo -> ParDo.out must expose the configured MultiOutputInfo.
  Edge parDoToOutput =
      Iterables.getOnlyElement(graph.edgesConnecting(parDoNode, parDoOutputNode));
  assertEquals(expectedOutputInfo, ((MultiOutputInfoEdge) parDoToOutput).getMultiOutputInfo());
}
/** Two independent Reads feeding one Flatten must converge on the same Flatten node. */
@Test
public void testFlatten() {
  // ReadA --\
  //          |--> Flatten
  // ReadB --/
  InstructionOutput outputA = createInstructionOutput("ReadA.out");
  ParallelInstruction readA = createParallelInstruction("ReadA", outputA);
  readA.setRead(new ReadInstruction());

  InstructionOutput outputB = createInstructionOutput("ReadB.out");
  ParallelInstruction readB = createParallelInstruction("ReadB", outputB);
  readB.setRead(new ReadInstruction());

  FlattenInstruction flattenSpec = new FlattenInstruction();
  flattenSpec.setInputs(
      ImmutableList.of(
          createInstructionInput(0, 0), // ReadA.out
          createInstructionInput(1, 0))); // ReadB.out

  InstructionOutput flattenOutput = createInstructionOutput("Flatten.out");
  ParallelInstruction flatten = createParallelInstruction("Flatten", flattenOutput);
  flatten.setFlatten(flattenSpec);

  MapTask task = new MapTask();
  task.setInstructions(ImmutableList.of(readA, readB, flatten));
  task.setFactory(Transport.getJsonFactory());

  MapTaskToNetworkFunction converter =
      new MapTaskToNetworkFunction(IdGenerators.decrementingLongs());
  Network<Node, Edge> graph = converter.apply(task);

  assertNetworkProperties(graph);
  assertEquals(6, graph.nodes().size());
  assertEquals(5, graph.edges().size());

  ParallelInstructionNode nodeA = get(graph, readA);
  InstructionOutputNode outputNodeA = getOnlySuccessor(graph, nodeA);
  assertEquals(outputA, outputNodeA.getInstructionOutput());

  ParallelInstructionNode nodeB = get(graph, readB);
  InstructionOutputNode outputNodeB = getOnlySuccessor(graph, nodeB);
  assertEquals(outputB, outputNodeB.getInstructionOutput());

  // Both read outputs must be consumed by exactly the same set of nodes (the Flatten).
  assertEquals(graph.successors(outputNodeA), graph.successors(outputNodeB));

  ParallelInstructionNode flattenNode = getOnlySuccessor(graph, outputNodeA);
  InstructionOutputNode flattenOutputNode = getOnlySuccessor(graph, flattenNode);
  assertEquals(flattenOutput, flattenOutputNode.getInstructionOutput());
}
/**
 * A Flatten that lists the same Read output three times must keep three distinct parallel edges
 * between that output and the Flatten.
 */
@Test
public void testParallelEdgeFlatten() {
  //                   /---\
  // Read --> Read.out --> Flatten
  //                   \---/
  InstructionOutput readOutput = createInstructionOutput("Read.out");
  ParallelInstruction read = createParallelInstruction("Read", readOutput);
  read.setRead(new ReadInstruction());

  FlattenInstruction flattenSpec = new FlattenInstruction();
  flattenSpec.setInputs(
      ImmutableList.of(
          createInstructionInput(0, 0), // Read.out
          createInstructionInput(0, 0), // Read.out
          createInstructionInput(0, 0))); // Read.out

  InstructionOutput flattenOutput = createInstructionOutput("Flatten.out");
  ParallelInstruction flatten = createParallelInstruction("Flatten", flattenOutput);
  flatten.setFlatten(flattenSpec);

  MapTask task = new MapTask();
  task.setInstructions(ImmutableList.of(read, flatten));
  task.setFactory(Transport.getJsonFactory());

  MapTaskToNetworkFunction converter =
      new MapTaskToNetworkFunction(IdGenerators.decrementingLongs());
  Network<Node, Edge> graph = converter.apply(task);

  assertNetworkProperties(graph);
  // 4 nodes; 5 edges = Read->Read.out, 3x Read.out->Flatten, Flatten->Flatten.out.
  assertEquals(4, graph.nodes().size());
  assertEquals(5, graph.edges().size());

  ParallelInstructionNode readNode = get(graph, read);
  InstructionOutputNode readOutputNode = getOnlySuccessor(graph, readNode);
  assertEquals(readOutput, readOutputNode.getInstructionOutput());

  ParallelInstructionNode flattenNode = getOnlySuccessor(graph, readOutputNode);
  // All three parallel edges between Read.out and Flatten must be kept.
  assertEquals(3, graph.edgesConnecting(readOutputNode, flattenNode).size());

  InstructionOutputNode flattenOutputNode = getOnlySuccessor(graph, flattenNode);
  assertEquals(flattenOutput, flattenOutputNode.getInstructionOutput());
}
/** Read followed by a Write (which declares no outputs of its own). */
@Test
public void testWrite() {
  // Read --> Read.out --> Write
  InstructionOutput readOutput = createInstructionOutput("Read.out");
  ParallelInstruction read = createParallelInstruction("Read", readOutput);
  read.setRead(new ReadInstruction());

  WriteInstruction writeSpec = new WriteInstruction();
  writeSpec.setInput(createInstructionInput(0, 0)); // Read.out
  ParallelInstruction write = createParallelInstruction("Write");
  write.setWrite(writeSpec);

  MapTask task = new MapTask();
  task.setInstructions(ImmutableList.of(read, write));
  task.setFactory(Transport.getJsonFactory());

  MapTaskToNetworkFunction converter =
      new MapTaskToNetworkFunction(IdGenerators.decrementingLongs());
  Network<Node, Edge> graph = converter.apply(task);

  assertNetworkProperties(graph);
  assertEquals(3, graph.nodes().size());
  assertEquals(2, graph.edges().size());

  ParallelInstructionNode readNode = get(graph, read);
  InstructionOutputNode readOutputNode = getOnlySuccessor(graph, readNode);
  assertEquals(readOutput, readOutputNode.getInstructionOutput());

  // The only consumer of Read.out is the Write instruction node.
  ParallelInstructionNode writeNode = getOnlySuccessor(graph, readOutputNode);
  assertNotNull(writeNode);
}
/**
 * Read, PartialGroupByKey and Write in sequence: checks that every hop of the chain is present
 * in the produced network and that the chain ends at the Write instruction.
 */
@Test
public void testPartialGroupByKey() {
  // Read --> PGBK --> Write
  InstructionOutput readOutput = createInstructionOutput("Read.out");
  ParallelInstruction read = createParallelInstruction("Read", readOutput);
  read.setRead(new ReadInstruction());

  PartialGroupByKeyInstruction pgbkInstruction = new PartialGroupByKeyInstruction();
  pgbkInstruction.setInput(createInstructionInput(0, 0)); // Read.out
  InstructionOutput pgbkOutput = createInstructionOutput("PGBK.out");
  ParallelInstruction pgbk = createParallelInstruction("PGBK", pgbkOutput);
  pgbk.setPartialGroupByKey(pgbkInstruction);

  WriteInstruction writeInstruction = new WriteInstruction();
  writeInstruction.setInput(createInstructionInput(1, 0)); // PGBK.out
  ParallelInstruction write = createParallelInstruction("Write");
  write.setWrite(writeInstruction);

  MapTask mapTask = new MapTask();
  mapTask.setInstructions(ImmutableList.of(read, pgbk, write));
  mapTask.setFactory(Transport.getJsonFactory());

  Network<Node, Edge> network =
      new MapTaskToNetworkFunction(IdGenerators.decrementingLongs()).apply(mapTask);

  assertNetworkProperties(network);
  assertEquals(5, network.nodes().size());
  assertEquals(4, network.edges().size());

  ParallelInstructionNode readNode = get(network, read);
  InstructionOutputNode readOutputNode = getOnlySuccessor(network, readNode);
  assertEquals(readOutput, readOutputNode.getInstructionOutput());

  ParallelInstructionNode pgbkNode = getOnlySuccessor(network, readOutputNode);
  InstructionOutputNode pgbkOutputNode = getOnlySuccessor(network, pgbkNode);
  assertEquals(pgbkOutput, pgbkOutputNode.getInstructionOutput());

  // Assert on the node found in the network, not on the locally constructed instruction
  // (which can never be null), and confirm that it really is the Write instruction.
  ParallelInstructionNode writeNode = getOnlySuccessor(network, pgbkOutputNode);
  assertNotNull(writeNode);
  assertEquals(write.getName(), writeNode.getParallelInstruction().getName());
}
/**
 * A ParDo with two outputs, each consumed by its own Write: checks that each output node is
 * attached to the right writer and that each ParDo edge carries the matching
 * {@link MultiOutputInfo}.
 */
@Test
public void testMultipleOutput() {
  //              /---> WriteA
  // Read ---> ParDo
  //              \---> WriteB
  InstructionOutput readOutput = createInstructionOutput("Read.out");
  ParallelInstruction read = createParallelInstruction("Read", readOutput);
  read.setRead(new ReadInstruction());

  MultiOutputInfo firstOutputInfo = createMultiOutputInfo("output1");
  MultiOutputInfo secondOutputInfo = createMultiOutputInfo("output2");
  ParDoInstruction parDoSpec = new ParDoInstruction();
  parDoSpec.setInput(createInstructionInput(0, 0)); // Read.out
  parDoSpec.setMultiOutputInfos(ImmutableList.of(firstOutputInfo, secondOutputInfo));

  InstructionOutput firstParDoOutput = createInstructionOutput("ParDo.out1");
  InstructionOutput secondParDoOutput = createInstructionOutput("ParDo.out2");
  ParallelInstruction parDo =
      createParallelInstruction("ParDo", firstParDoOutput, secondParDoOutput);
  parDo.setParDo(parDoSpec);

  WriteInstruction writeASpec = new WriteInstruction();
  writeASpec.setInput(createInstructionInput(1, 0)); // ParDo.out1
  ParallelInstruction writeA = createParallelInstruction("WriteA");
  writeA.setWrite(writeASpec);

  WriteInstruction writeBSpec = new WriteInstruction();
  writeBSpec.setInput(createInstructionInput(1, 1)); // ParDo.out2
  ParallelInstruction writeB = createParallelInstruction("WriteB");
  writeB.setWrite(writeBSpec);

  MapTask task = new MapTask();
  task.setInstructions(ImmutableList.of(read, parDo, writeA, writeB));
  task.setFactory(Transport.getJsonFactory());

  MapTaskToNetworkFunction converter =
      new MapTaskToNetworkFunction(IdGenerators.decrementingLongs());
  Network<Node, Edge> graph = converter.apply(task);

  assertNetworkProperties(graph);
  assertEquals(7, graph.nodes().size());
  assertEquals(6, graph.edges().size());

  ParallelInstructionNode parDoNode = get(graph, parDo);
  ParallelInstructionNode writeANode = get(graph, writeA);
  ParallelInstructionNode writeBNode = get(graph, writeB);

  // Work backwards from each writer to the ParDo output it consumes.
  InstructionOutputNode firstOutputNode = getOnlyPredecessor(graph, writeANode);
  assertEquals(firstParDoOutput, firstOutputNode.getInstructionOutput());

  InstructionOutputNode secondOutputNode = getOnlyPredecessor(graph, writeBNode);
  assertEquals(secondParDoOutput, secondOutputNode.getInstructionOutput());

  // The ParDo must feed exactly those two output nodes.
  assertThat(
      graph.successors(parDoNode),
      Matchers.<Node>containsInAnyOrder(firstOutputNode, secondOutputNode));

  // Each ParDo -> output edge must carry the MultiOutputInfo for that output.
  Edge edgeToFirstOutput =
      Iterables.getOnlyElement(graph.edgesConnecting(parDoNode, firstOutputNode));
  assertEquals(firstOutputInfo, ((MultiOutputInfoEdge) edgeToFirstOutput).getMultiOutputInfo());

  Edge edgeToSecondOutput =
      Iterables.getOnlyElement(graph.edgesConnecting(parDoNode, secondOutputNode));
  assertEquals(
      secondOutputInfo, ((MultiOutputInfoEdge) edgeToSecondOutput).getMultiOutputInfo());
}
/**
 * Asserts the structural invariants every translated network is expected to satisfy.
 *
 * <ul>
 *   <li>The network is directed, allows parallel edges and forbids self loops.
 *   <li>Every {@link ParallelInstructionNode} holds an instruction whose outputs are {@code null},
 *       and all of its successors are {@link InstructionOutputNode}s. Its outgoing edges are
 *       {@link MultiOutputInfoEdge}s when the instruction is a ParDo and {@link DefaultEdge}s
 *       otherwise.
 *   <li>Every {@link InstructionOutputNode} has at least one incoming edge, all of its successors
 *       are {@link ParallelInstructionNode}s, and its outgoing edges are {@link DefaultEdge}s.
 * </ul>
 *
 * @param network the network to validate
 */
private static void assertNetworkProperties(Network<Node, Edge> network) {
  assertTrue(network.isDirected());
  // Same configuration check as testEmptyMapTask; parallel edges are needed when an
  // instruction consumes the same output more than once.
  assertTrue(network.allowsParallelEdges());
  assertFalse(network.allowsSelfLoops());

  for (Node node : network.nodes()) {
    if (node instanceof ParallelInstructionNode) {
      ParallelInstruction instruction = ((ParallelInstructionNode) node).getParallelInstruction();
      assertNull(instruction.getOutputs());

      // Only ParDo instructions use MultiOutputInfoEdges; everything else uses DefaultEdges.
      Class<?> expectedEdgeType =
          instruction.getParDo() != null ? MultiOutputInfoEdge.class : DefaultEdge.class;

      for (Node successor : network.successors(node)) {
        assertThat(successor, instanceOf(InstructionOutputNode.class));
        for (Edge edge : network.edgesConnecting(node, successor)) {
          assertThat(edge, instanceOf(expectedEdgeType));
        }
      }
    } else if (node instanceof InstructionOutputNode) {
      // An output node must have a producer.
      assertThat(network.inDegree(node), greaterThanOrEqualTo(1));

      // Validate that all successors are instructions with DefaultEdge outgoing edges.
      for (Node successor : network.successors(node)) {
        assertThat(successor, instanceOf(ParallelInstructionNode.class));
        for (Edge edge : network.edgesConnecting(node, successor)) {
          assertThat(edge, instanceOf(DefaultEdge.class));
        }
      }
    }
  }
}
/**
 * Finds the instruction node in {@code network} whose instruction has the same name as
 * {@code instruction}.
 *
 * @param network the network to search
 * @param instruction the instruction whose name is looked up
 * @return the first matching {@link ParallelInstructionNode} encountered
 * @throws AssertionError if no instruction node with that name exists in the network
 */
private static ParallelInstructionNode get(
    Network<Node, Edge> network, ParallelInstruction instruction) {
  for (Node candidate : network.nodes()) {
    // Output nodes carry no instruction and can never match.
    if (!(candidate instanceof ParallelInstructionNode)) {
      continue;
    }
    ParallelInstructionNode instructionNode = (ParallelInstructionNode) candidate;
    boolean sameName =
        instructionNode.getParallelInstruction().getName().equals(instruction.getName());
    if (sameName) {
      return instructionNode;
    }
  }
  String message =
      String.format("Unable to find instruction %s in network %s", instruction, network);
  throw new AssertionError(message);
}
/**
 * Returns the sole successor of an output node, cast to {@link ParallelInstructionNode}.
 *
 * <p>Relies on {@code Iterables.getOnlyElement}, so the node is expected to have exactly one
 * successor.
 */
private static ParallelInstructionNode getOnlySuccessor(
    Network<Node, Edge> network, InstructionOutputNode node) {
  Node soleConsumer = Iterables.getOnlyElement(network.successors(node));
  return (ParallelInstructionNode) soleConsumer;
}
/**
 * Returns the sole successor of an instruction node, cast to {@link InstructionOutputNode}.
 *
 * <p>Relies on {@code Iterables.getOnlyElement}, so the node is expected to have exactly one
 * successor.
 */
private static InstructionOutputNode getOnlySuccessor(
    Network<Node, Edge> network, ParallelInstructionNode node) {
  Node soleOutput = Iterables.getOnlyElement(network.successors(node));
  return (InstructionOutputNode) soleOutput;
}
/**
 * Returns the sole predecessor of an instruction node, cast to {@link InstructionOutputNode}.
 *
 * <p>Relies on {@code Iterables.getOnlyElement}, so the node is expected to have exactly one
 * predecessor.
 */
private static InstructionOutputNode getOnlyPredecessor(
    Network<Node, Edge> network, ParallelInstructionNode node) {
  Node soleInput = Iterables.getOnlyElement(network.predecessors(node));
  return (InstructionOutputNode) soleInput;
}
/**
 * Builds an {@link InstructionOutput} with only its name set.
 *
 * @param name the name to assign to the output
 * @return the newly created output
 */
private static InstructionOutput createInstructionOutput(String name) {
  InstructionOutput output = new InstructionOutput();
  output.setName(name);
  return output;
}
/**
 * Builds a {@link ParallelInstruction} with the given name and outputs. The instruction kind
 * (read, ParDo, write, ...) is left for the caller to set.
 *
 * @param name the instruction name
 * @param outputs the outputs to attach, in order; may be empty
 * @return the newly created instruction
 */
private static ParallelInstruction createParallelInstruction(
    String name, InstructionOutput... outputs) {
  ParallelInstruction instruction = new ParallelInstruction();
  instruction.setName(name);
  // The varargs array is exposed as a list view rather than copied.
  instruction.setOutputs(Arrays.asList(outputs));
  return instruction;
}
/**
 * Builds a {@link MultiOutputInfo} with only its tag set.
 *
 * @param tag the tag to assign
 * @return the newly created output info
 */
private static MultiOutputInfo createMultiOutputInfo(String tag) {
  MultiOutputInfo outputInfo = new MultiOutputInfo();
  outputInfo.setTag(tag);
  return outputInfo;
}
/**
 * Builds an {@link InstructionInput} referring to one output of a producing instruction.
 *
 * @param instructionIndex value stored as the producer instruction index
 * @param outputNum value stored as the output number of that producer
 * @return the newly created input reference
 */
private static InstructionInput createInstructionInput(int instructionIndex, int outputNum) {
  InstructionInput input = new InstructionInput();
  input.setProducerInstructionIndex(instructionIndex);
  input.setOutputNum(outputNum);
  return input;
}
}