/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.beam.runners.dataflow.worker;

import static org.apache.beam.runners.dataflow.util.Structs.addString;
import static org.apache.beam.runners.dataflow.worker.DataflowOutputCounter.getElementCounterName;
import static org.apache.beam.runners.dataflow.worker.DataflowOutputCounter.getMeanByteCounterName;
import static org.apache.beam.runners.dataflow.worker.DataflowOutputCounter.getObjectCounterName;
import static org.apache.beam.runners.dataflow.worker.counters.CounterName.named;
import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray;
import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString;
import static org.hamcrest.Matchers.hasItems;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.api.services.dataflow.model.FlattenInstruction;
import com.google.api.services.dataflow.model.InstructionInput;
import com.google.api.services.dataflow.model.InstructionOutput;
import com.google.api.services.dataflow.model.MapTask;
import com.google.api.services.dataflow.model.MultiOutputInfo;
import com.google.api.services.dataflow.model.ParDoInstruction;
import com.google.api.services.dataflow.model.ParallelInstruction;
import com.google.api.services.dataflow.model.PartialGroupByKeyInstruction;
import com.google.api.services.dataflow.model.ReadInstruction;
import com.google.api.services.dataflow.model.Source;
import com.google.api.services.dataflow.model.WriteInstruction;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.CloudObjects;
import org.apache.beam.runners.dataflow.util.PropertyNames;
import org.apache.beam.runners.dataflow.worker.apiary.FixMultiOutputInfosOnParDoInstructions;
import org.apache.beam.runners.dataflow.worker.counters.Counter;
import org.apache.beam.runners.dataflow.worker.counters.Counter.CounterUpdateExtractor;
import org.apache.beam.runners.dataflow.worker.counters.CounterFactory.CounterMean;
import org.apache.beam.runners.dataflow.worker.counters.CounterSet;
import org.apache.beam.runners.dataflow.worker.graph.Edges.Edge;
import org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge;
import org.apache.beam.runners.dataflow.worker.graph.MapTaskToNetworkFunction;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.Node;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.OperationNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode;
import org.apache.beam.runners.dataflow.worker.util.common.worker.FlattenOperation;
import org.apache.beam.runners.dataflow.worker.util.common.worker.Operation;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ParDoOperation;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ReadOperation;
import org.apache.beam.runners.dataflow.worker.util.common.worker.Sink;
import org.apache.beam.runners.dataflow.worker.util.common.worker.WriteOperation;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.extensions.gcp.util.Transport;
import org.apache.beam.sdk.fn.IdGenerator;
import org.apache.beam.sdk.fn.IdGenerators;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
import org.apache.beam.sdk.util.AppliedCombineFn;
import org.apache.beam.sdk.util.DoFnInfo;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.StringUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v20_0.com.google.common.graph.MutableNetwork;
import org.apache.beam.vendor.guava.v20_0.com.google.common.graph.Network;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

/** Tests for {@link IntrinsicMapTaskExecutorFactory}. */
@RunWith(JUnit4.class)
public class IntrinsicMapTaskExecutorFactoryTest {
private static final String STAGE = "test";
private static final IdGenerator idGenerator = IdGenerators.decrementingLongs();
private static final Function<MapTask, MutableNetwork<Node, Edge>> mapTaskToNetwork =
new FixMultiOutputInfosOnParDoInstructions(idGenerator)
.andThen(new MapTaskToNetworkFunction(idGenerator));
private static final CloudObject windowedStringCoder =
CloudObjects.asCloudObject(
WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()), /*sdkComponents=*/ null);
private IntrinsicMapTaskExecutorFactory mapTaskExecutorFactory;
private PipelineOptions options;
private ReaderRegistry readerRegistry;
private SinkRegistry sinkRegistry;
private static final String PCOLLECTION_ID = "fakeId";
@Mock private Network<Node, Edge> network;
@Mock private CounterUpdateExtractor<?> updateExtractor;
private final CounterSet counterSet = new CounterSet();

@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
mapTaskExecutorFactory = IntrinsicMapTaskExecutorFactory.defaultFactory();
options = PipelineOptionsFactory.create();
readerRegistry =
ReaderRegistry.defaultRegistry()
.register(
ReaderFactoryTest.TestReaderFactory.class.getName(),
new ReaderFactoryTest.TestReaderFactory())
.register(
ReaderFactoryTest.SingletonTestReaderFactory.class.getName(),
new ReaderFactoryTest.SingletonTestReaderFactory());
sinkRegistry =
SinkRegistry.defaultRegistry()
.register(TestSinkFactory.class.getName(), new TestSinkFactory());
}
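
// Builds a MapTask of read -> two ParDos -> flatten -> write and checks that the factory
// produces one operation per instruction, wires the receivers between them, and reports the
// expected output and byte counters.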
@Test
public void testCreateMapTaskExecutor() throws Exception {
List<ParallelInstruction> instructions =
Arrays.asList(
createReadInstruction("Read"),
createParDoInstruction(0, 0, "DoFn1"),
createParDoInstruction(0, 0, "DoFnWithContext"),
createFlattenInstruction(1, 0, 2, 0, "Flatten"),
createWriteInstruction(3, 0, "Write"));
MapTask mapTask = new MapTask();
mapTask.setStageName(STAGE);
mapTask.setSystemName("systemName");
mapTask.setInstructions(instructions);
mapTask.setFactory(Transport.getJsonFactory());
try (DataflowMapTaskExecutor executor =
mapTaskExecutorFactory.create(
null /* beamFnControlClientHandler */,
null /* GrpcFnServer<GrpcDataService> */,
null /* ApiServiceDescriptor */,
null /* GrpcFnServer<GrpcStateService> */,
mapTaskToNetwork.apply(mapTask),
options,
STAGE,
readerRegistry,
sinkRegistry,
BatchModeExecutionContext.forTesting(options, counterSet, "testStage"),
counterSet,
idGenerator)) {
// Safe covariant cast not expressible without rawtypes.
@SuppressWarnings({"rawtypes", "unchecked"})
List<Object> operations = (List) executor.operations;
assertThat(
operations,
hasItems(
instanceOf(ReadOperation.class),
instanceOf(ParDoOperation.class),
instanceOf(ParDoOperation.class),
instanceOf(FlattenOperation.class),
instanceOf(WriteOperation.class)));
// Verify that the inputs are attached.
ReadOperation readOperation =
Iterables.getOnlyElement(Iterables.filter(operations, ReadOperation.class));
assertEquals(2, readOperation.receivers[0].getReceiverCount());
FlattenOperation flattenOperation =
Iterables.getOnlyElement(Iterables.filter(operations, FlattenOperation.class));
for (ParDoOperation operation : Iterables.filter(operations, ParDoOperation.class)) {
assertSame(flattenOperation, operation.receivers[0].getOnlyReceiver());
}
WriteOperation writeOperation =
Iterables.getOnlyElement(Iterables.filter(operations, WriteOperation.class));
assertSame(writeOperation, flattenOperation.receivers[0].getOnlyReceiver());
}
@SuppressWarnings("unchecked")
Counter<Long, ?> otherMsecCounter =
(Counter<Long, ?>) counterSet.getExistingCounter("test-other-msecs");
// "other" state only got created upon MapTaskExecutor.execute().
assertNull(otherMsecCounter);
counterSet.extractUpdates(false, updateExtractor);
verifyOutputCounters(
updateExtractor,
"read_output_name",
"DoFn1_output",
"DoFnWithContext_output",
"flatten_output_name");
verify(updateExtractor).longSum(eq(named("Read-ByteCount")), anyBoolean(), anyLong());
verify(updateExtractor).longSum(eq(named("Write-ByteCount")), anyBoolean(), anyLong());
verifyNoMoreInteractions(updateExtractor);
}
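
/** Asserts that element-count, object-count, and mean-byte counters were reported for each output. */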
private static void verifyOutputCounters(
CounterUpdateExtractor<?> updateExtractor, String... outputNames) {
for (String outputName : outputNames) {
verify(updateExtractor)
.longSum(eq(named(getElementCounterName(outputName))), anyBoolean(), anyLong());
verify(updateExtractor)
.longSum(eq(named(getObjectCounterName(outputName))), anyBoolean(), anyLong());
verify(updateExtractor)
.longMean(
eq(named(getMeanByteCounterName(outputName))),
anyBoolean(),
Mockito.<CounterMean<Long>>any());
}
}
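
// Executes a small read/ParDo/ParDo MapTask and verifies that a step context was created for
// each ParDo instruction.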
@Test
public void testExecutionContextPlumbing() throws Exception {
List<ParallelInstruction> instructions =
Arrays.asList(
createReadInstruction("Read", ReaderFactoryTest.SingletonTestReaderFactory.class),
createParDoInstruction(0, 0, "DoFn1", "DoFnUserName"),
createParDoInstruction(1, 0, "DoFnWithContext", "DoFnWithContextUserName"));
MapTask mapTask = new MapTask();
mapTask.setStageName(STAGE);
mapTask.setInstructions(instructions);
mapTask.setFactory(Transport.getJsonFactory());
BatchModeExecutionContext context =
BatchModeExecutionContext.forTesting(options, counterSet, "testStage");
try (DataflowMapTaskExecutor executor =
mapTaskExecutorFactory.create(
null /* beamFnControlClientHandler */,
null /* GrpcFnServer<GrpcDataService> */,
null /* ApiServiceDescriptor */,
null /* GrpcFnServer<GrpcStateService> */,
mapTaskToNetwork.apply(mapTask),
options,
STAGE,
readerRegistry,
sinkRegistry,
context,
counterSet,
idGenerator)) {
executor.execute();
}
List<String> stepNames = new ArrayList<>();
for (BatchModeExecutionContext.StepContext stepContext : context.getAllStepContexts()) {
stepNames.add(stepContext.getNameContext().systemName());
}
assertThat(stepNames, hasItems("DoFn1", "DoFnWithContext"));
}
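
/** Creates a read instruction whose source uses {@link ReaderFactoryTest.TestReaderFactory}. */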
static ParallelInstruction createReadInstruction(String name) {
return createReadInstruction(name, ReaderFactoryTest.TestReaderFactory.class);
}
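
/** Creates a read instruction whose source spec references the given {@link ReaderFactory} class. */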
static ParallelInstruction createReadInstruction(
String name, Class<? extends ReaderFactory> readerFactoryClass) {
CloudObject spec = CloudObject.forClass(readerFactoryClass);
Source cloudSource = new Source();
cloudSource.setSpec(spec);
cloudSource.setCodec(windowedStringCoder);
ReadInstruction readInstruction = new ReadInstruction();
readInstruction.setSource(cloudSource);
InstructionOutput output = new InstructionOutput();
output.setName("read_output_name");
output.setCodec(windowedStringCoder);
output.setOriginalName("originalName");
output.setSystemName("systemName");
ParallelInstruction instruction = new ParallelInstruction();
instruction.setSystemName(name);
instruction.setOriginalName(name + "OriginalName");
instruction.setRead(readInstruction);
instruction.setOutputs(Arrays.asList(output));
return instruction;
}
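
// Verifies that a read instruction becomes an unstarted ReadOperation with a single, still
// empty receiver, and that its output and byte counters are reported.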
@Test
public void testCreateReadOperation() throws Exception {
ParallelInstructionNode instructionNode =
ParallelInstructionNode.create(createReadInstruction("Read"), ExecutionLocation.UNKNOWN);
when(network.successors(instructionNode))
.thenReturn(
ImmutableSet.<Node>of(
IntrinsicMapTaskExecutorFactory.createOutputReceiversTransform(STAGE, counterSet)
.apply(
InstructionOutputNode.create(
instructionNode.getParallelInstruction().getOutputs().get(0),
PCOLLECTION_ID))));
when(network.outDegree(instructionNode)).thenReturn(1);
Node operationNode =
mapTaskExecutorFactory
.createOperationTransformForParallelInstructionNodes(
STAGE,
network,
PipelineOptionsFactory.create(),
readerRegistry,
sinkRegistry,
BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
.apply(instructionNode);
assertThat(operationNode, instanceOf(OperationNode.class));
assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ReadOperation.class));
ReadOperation readOperation = (ReadOperation) ((OperationNode) operationNode).getOperation();
assertEquals(1, readOperation.receivers.length);
assertEquals(0, readOperation.receivers[0].getReceiverCount());
assertEquals(Operation.InitializationState.UNSTARTED, readOperation.initializationState);
assertThat(readOperation.reader, instanceOf(ReaderFactoryTest.TestReader.class));
counterSet.extractUpdates(false, updateExtractor);
verifyOutputCounters(updateExtractor, "read_output_name");
verify(updateExtractor).longSum(eq(named("Read-ByteCount")), anyBoolean(), anyLong());
verifyNoMoreInteractions(updateExtractor);
}
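
/** Creates a write instruction that consumes the given producer output via a {@link TestSinkFactory} sink. */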
static ParallelInstruction createWriteInstruction(
int producerIndex, int producerOutputNum, String systemName) {
InstructionInput cloudInput = new InstructionInput();
cloudInput.setProducerInstructionIndex(producerIndex);
cloudInput.setOutputNum(producerOutputNum);
CloudObject spec =
CloudObject.forClass(IntrinsicMapTaskExecutorFactoryTest.TestSinkFactory.class);
com.google.api.services.dataflow.model.Sink cloudSink =
new com.google.api.services.dataflow.model.Sink();
cloudSink.setSpec(spec);
cloudSink.setCodec(windowedStringCoder);
WriteInstruction writeInstruction = new WriteInstruction();
writeInstruction.setInput(cloudInput);
writeInstruction.setSink(cloudSink);
ParallelInstruction instruction = new ParallelInstruction();
instruction.setWrite(writeInstruction);
instruction.setSystemName(systemName);
instruction.setOriginalName(systemName + "OriginalName");
return instruction;
}
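
// Verifies that a write instruction becomes an unstarted WriteOperation with no receivers,
// whose sink is a size-reporting wrapper around a TestSink.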
@SuppressWarnings("unchecked")
@Test
public void testCreateWriteOperation() throws Exception {
int producerIndex = 1;
int producerOutputNum = 2;
ParallelInstructionNode instructionNode =
ParallelInstructionNode.create(
createWriteInstruction(producerIndex, producerOutputNum, "WriteOperation"),
ExecutionLocation.UNKNOWN);
Node operationNode =
mapTaskExecutorFactory
.createOperationTransformForParallelInstructionNodes(
STAGE,
network,
options,
readerRegistry,
sinkRegistry,
BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
.apply(instructionNode);
assertThat(operationNode, instanceOf(OperationNode.class));
assertThat(((OperationNode) operationNode).getOperation(), instanceOf(WriteOperation.class));
WriteOperation writeOperation = (WriteOperation) ((OperationNode) operationNode).getOperation();
assertEquals(0, writeOperation.receivers.length);
assertEquals(Operation.InitializationState.UNSTARTED, writeOperation.initializationState);
assertThat(writeOperation.sink, instanceOf(SizeReportingSinkWrapper.class));
assertThat(
((SizeReportingSinkWrapper<?>) writeOperation.sink).getUnderlyingSink(),
instanceOf(TestSink.class));
counterSet.extractUpdates(false, updateExtractor);
verify(updateExtractor).longSum(eq(named("WriteOperation-ByteCount")), anyBoolean(), anyLong());
verifyNoMoreInteractions(updateExtractor);
}
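
/** A DoFn that passes each input element straight through to its output, for testing. */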
static class TestDoFn extends DoFn<String, String> {
@ProcessElement
public void processElement(ProcessContext c) {
c.output(c.element());
}
}
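
/** A sink that discards all elements written to it, for testing. */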
static class TestSink extends Sink<Integer> {
@Override
public SinkWriter<Integer> writer() {
return new TestSinkWriter();
}

/** A sink writer that drops its input values, for testing. */
static class TestSinkWriter implements SinkWriter<Integer> {
@Override
public long add(Integer outputElem) {
// Pretend that writing each element took four bytes.
return 4;
}

@Override
public void close() {}

@Override
public void abort() throws IOException {
close();
}
}
}
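
/** A SinkFactory that ignores its arguments and always returns a new {@link TestSink}. */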
static class TestSinkFactory implements SinkFactory {
@Override
public TestSink create(
CloudObject o,
@Nullable Coder<?> coder,
@Nullable PipelineOptions options,
@Nullable DataflowExecutionContext executionContext,
DataflowOperationContext operationContext) {
return new TestSink();
}
}
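
/** Creates a ParDo instruction that runs a serialized {@link TestDoFn} with a single main output. */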
static ParallelInstruction createParDoInstruction(
int producerIndex, int producerOutputNum, String systemName) {
return createParDoInstruction(producerIndex, producerOutputNum, systemName, "");
}

static ParallelInstruction createParDoInstruction(
int producerIndex, int producerOutputNum, String systemName, String userName) {
InstructionInput cloudInput = new InstructionInput();
cloudInput.setProducerInstructionIndex(producerIndex);
cloudInput.setOutputNum(producerOutputNum);
TestDoFn fn = new TestDoFn();
String serializedFn =
StringUtils.byteArrayToJsonString(
SerializableUtils.serializeToByteArray(
DoFnInfo.forFn(
fn,
WindowingStrategy.globalDefault(),
null /* side input views */,
null /* input coder */,
new TupleTag<>(PropertyNames.OUTPUT) /* main output id */,
DoFnSchemaInformation.create())));
CloudObject cloudUserFn = CloudObject.forClassName("DoFn");
addString(cloudUserFn, PropertyNames.SERIALIZED_FN, serializedFn);
MultiOutputInfo mainOutputTag = new MultiOutputInfo();
mainOutputTag.setTag("1");
ParDoInstruction parDoInstruction = new ParDoInstruction();
parDoInstruction.setInput(cloudInput);
parDoInstruction.setNumOutputs(1);
parDoInstruction.setMultiOutputInfos(ImmutableList.of(mainOutputTag));
parDoInstruction.setUserFn(cloudUserFn);
InstructionOutput output = new InstructionOutput();
output.setName(systemName + "_output");
output.setCodec(windowedStringCoder);
output.setOriginalName("originalName");
output.setSystemName("systemName");
ParallelInstruction instruction = new ParallelInstruction();
instruction.setParDo(parDoInstruction);
instruction.setOutputs(Arrays.asList(output));
instruction.setSystemName(systemName);
instruction.setOriginalName(systemName + "OriginalName");
instruction.setName(userName);
return instruction;
}
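
// Verifies that a ParDo instruction becomes an unstarted ParDoOperation with a single, still
// empty receiver.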
@Test
public void testCreateParDoOperation() throws Exception {
int producerIndex = 1;
int producerOutputNum = 2;
BatchModeExecutionContext context =
BatchModeExecutionContext.forTesting(options, counterSet, "testStage");
ParallelInstructionNode instructionNode =
ParallelInstructionNode.create(
createParDoInstruction(producerIndex, producerOutputNum, "DoFn"),
ExecutionLocation.UNKNOWN);
Node outputReceiverNode =
IntrinsicMapTaskExecutorFactory.createOutputReceiversTransform(STAGE, counterSet)
.apply(
InstructionOutputNode.create(
instructionNode.getParallelInstruction().getOutputs().get(0), PCOLLECTION_ID));
when(network.successors(instructionNode)).thenReturn(ImmutableSet.of(outputReceiverNode));
when(network.outDegree(instructionNode)).thenReturn(1);
when(network.edgesConnecting(instructionNode, outputReceiverNode))
.thenReturn(
ImmutableSet.<Edge>of(
MultiOutputInfoEdge.create(
instructionNode
.getParallelInstruction()
.getParDo()
.getMultiOutputInfos()
.get(0))));
Node operationNode =
mapTaskExecutorFactory
.createOperationTransformForParallelInstructionNodes(
STAGE, network, options, readerRegistry, sinkRegistry, context)
.apply(instructionNode);
assertThat(operationNode, instanceOf(OperationNode.class));
assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
ParDoOperation parDoOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
assertEquals(1, parDoOperation.receivers.length);
assertEquals(0, parDoOperation.receivers[0].getReceiverCount());
assertEquals(Operation.InitializationState.UNSTARTED, parDoOperation.initializationState);
}
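
/** Creates a partial group-by-key instruction over windowed {@code KV<String, Integer>} elements. */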
static ParallelInstruction createPartialGroupByKeyInstruction(
int producerIndex, int producerOutputNum) {
InstructionInput cloudInput = new InstructionInput();
cloudInput.setProducerInstructionIndex(producerIndex);
cloudInput.setOutputNum(producerOutputNum);
PartialGroupByKeyInstruction pgbkInstruction = new PartialGroupByKeyInstruction();
pgbkInstruction.setInput(cloudInput);
pgbkInstruction.setInputElementCodec(
CloudObjects.asCloudObject(
FullWindowedValueCoder.of(
KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()),
IntervalWindowCoder.of()),
/*sdkComponents=*/ null));
InstructionOutput output = new InstructionOutput();
output.setName("pgbk_output_name");
output.setCodec(
CloudObjects.asCloudObject(
KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(BigEndianIntegerCoder.of())),
/*sdkComponents=*/ null));
output.setOriginalName("originalName");
output.setSystemName("systemName");
ParallelInstruction instruction = new ParallelInstruction();
instruction.setOriginalName("pgbk_original_name");
instruction.setSystemName("pgbk_system_name");
instruction.setPartialGroupByKey(pgbkInstruction);
instruction.setOutputs(Arrays.asList(output));
return instruction;
}
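
// Verifies that a partial group-by-key instruction is executed as an unstarted ParDoOperation.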
@Test
public void testCreatePartialGroupByKeyOperation() throws Exception {
int producerIndex = 1;
int producerOutputNum = 2;
ParallelInstructionNode instructionNode =
ParallelInstructionNode.create(
createPartialGroupByKeyInstruction(producerIndex, producerOutputNum),
ExecutionLocation.UNKNOWN);
when(network.successors(instructionNode))
.thenReturn(
ImmutableSet.<Node>of(
IntrinsicMapTaskExecutorFactory.createOutputReceiversTransform(STAGE, counterSet)
.apply(
InstructionOutputNode.create(
instructionNode.getParallelInstruction().getOutputs().get(0),
PCOLLECTION_ID))));
when(network.outDegree(instructionNode)).thenReturn(1);
Node operationNode =
mapTaskExecutorFactory
.createOperationTransformForParallelInstructionNodes(
STAGE,
network,
PipelineOptionsFactory.create(),
readerRegistry,
sinkRegistry,
BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
.apply(instructionNode);
assertThat(operationNode, instanceOf(OperationNode.class));
assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
assertEquals(1, pgbkOperation.receivers.length);
assertEquals(0, pgbkOperation.receivers[0].getReceiverCount());
assertEquals(Operation.InitializationState.UNSTARTED, pgbkOperation.initializationState);
}
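
// Same as above, but with a Sum.ofIntegers() combine function attached to the instruction.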
@Test
public void testCreatePartialGroupByKeyOperationWithCombine() throws Exception {
int producerIndex = 1;
int producerOutputNum = 2;
ParallelInstruction instruction =
createPartialGroupByKeyInstruction(producerIndex, producerOutputNum);
AppliedCombineFn<?, ?, ?, ?> combineFn =
AppliedCombineFn.withInputCoder(
Sum.ofIntegers(),
CoderRegistry.createDefault(),
KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of()));
CloudObject cloudCombineFn = CloudObject.forClassName("CombineFn");
addString(
cloudCombineFn,
PropertyNames.SERIALIZED_FN,
byteArrayToJsonString(serializeToByteArray(combineFn)));
instruction.getPartialGroupByKey().setValueCombiningFn(cloudCombineFn);
ParallelInstructionNode instructionNode =
ParallelInstructionNode.create(instruction, ExecutionLocation.UNKNOWN);
when(network.successors(instructionNode))
.thenReturn(
ImmutableSet.<Node>of(
IntrinsicMapTaskExecutorFactory.createOutputReceiversTransform(STAGE, counterSet)
.apply(
InstructionOutputNode.create(
instructionNode.getParallelInstruction().getOutputs().get(0),
PCOLLECTION_ID))));
when(network.outDegree(instructionNode)).thenReturn(1);
Node operationNode =
mapTaskExecutorFactory
.createOperationTransformForParallelInstructionNodes(
STAGE,
network,
options,
readerRegistry,
sinkRegistry,
BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
.apply(instructionNode);
assertThat(operationNode, instanceOf(OperationNode.class));
assertThat(((OperationNode) operationNode).getOperation(), instanceOf(ParDoOperation.class));
ParDoOperation pgbkOperation = (ParDoOperation) ((OperationNode) operationNode).getOperation();
assertEquals(1, pgbkOperation.receivers.length);
assertEquals(0, pgbkOperation.receivers[0].getReceiverCount());
assertEquals(Operation.InitializationState.UNSTARTED, pgbkOperation.initializationState);
}
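
/** Creates a flatten instruction over the two given producer outputs. */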
static ParallelInstruction createFlattenInstruction(
int producerIndex1,
int producerOutputNum1,
int producerIndex2,
int producerOutputNum2,
String systemName) {
List<InstructionInput> cloudInputs = new ArrayList<>();
InstructionInput cloudInput1 = new InstructionInput();
cloudInput1.setProducerInstructionIndex(producerIndex1);
cloudInput1.setOutputNum(producerOutputNum1);
cloudInputs.add(cloudInput1);
InstructionInput cloudInput2 = new InstructionInput();
cloudInput2.setProducerInstructionIndex(producerIndex2);
cloudInput2.setOutputNum(producerOutputNum2);
cloudInputs.add(cloudInput2);
FlattenInstruction flattenInstruction = new FlattenInstruction();
flattenInstruction.setInputs(cloudInputs);
InstructionOutput output = new InstructionOutput();
output.setName("flatten_output_name");
output.setCodec(CloudObjects.asCloudObject(StringUtf8Coder.of(), /*sdkComponents=*/ null));
output.setOriginalName("originalName");
output.setSystemName("systemName");
ParallelInstruction instruction = new ParallelInstruction();
instruction.setFlatten(flattenInstruction);
instruction.setOutputs(Arrays.asList(output));
instruction.setSystemName(systemName);
instruction.setOriginalName(systemName + "OriginalName");
return instruction;
}
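
// Verifies that a flatten instruction becomes an unstarted FlattenOperation with a single,
// still empty receiver.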
@Test
public void testCreateFlattenOperation() throws Exception {
int producerIndex1 = 1;
int producerOutputNum1 = 2;
int producerIndex2 = 0;
int producerOutputNum2 = 1;
ParallelInstructionNode instructionNode =
ParallelInstructionNode.create(
createFlattenInstruction(
producerIndex1, producerOutputNum1, producerIndex2, producerOutputNum2, "Flatten"),
ExecutionLocation.UNKNOWN);
when(network.successors(instructionNode))
.thenReturn(
ImmutableSet.<Node>of(
IntrinsicMapTaskExecutorFactory.createOutputReceiversTransform(STAGE, counterSet)
.apply(
InstructionOutputNode.create(
instructionNode.getParallelInstruction().getOutputs().get(0),
PCOLLECTION_ID))));
when(network.outDegree(instructionNode)).thenReturn(1);
Node operationNode =
mapTaskExecutorFactory
.createOperationTransformForParallelInstructionNodes(
STAGE,
network,
options,
readerRegistry,
sinkRegistry,
BatchModeExecutionContext.forTesting(options, counterSet, "testStage"))
.apply(instructionNode);
assertThat(operationNode, instanceOf(OperationNode.class));
assertThat(((OperationNode) operationNode).getOperation(), instanceOf(FlattenOperation.class));
FlattenOperation flattenOperation =
(FlattenOperation) ((OperationNode) operationNode).getOperation();
assertEquals(1, flattenOperation.receivers.length);
assertEquals(0, flattenOperation.receivers[0].getReceiverCount());
assertEquals(Operation.InitializationState.UNSTARTED, flattenOperation.initializationState);
}
}