blob: 1524cac3579737b879e479b870b61a79cfc4b010 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.runtime;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.anyList;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.only;
import static org.mockito.Mockito.verify;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.tez.common.TezExecutors;
import org.apache.tez.common.TezSharedExecutor;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.ProcessorDescriptor;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.hadoop.shim.DefaultHadoopShim;
import org.apache.tez.hadoop.shim.HadoopShim;
import org.apache.tez.runtime.api.AbstractLogicalIOProcessor;
import org.apache.tez.runtime.api.AbstractLogicalInput;
import org.apache.tez.runtime.api.AbstractLogicalOutput;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.ExecutionContext;
import org.apache.tez.runtime.api.LogicalInput;
import org.apache.tez.runtime.api.LogicalOutput;
import org.apache.tez.runtime.api.ObjectRegistry;
import org.apache.tez.runtime.api.Reader;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.OutputContext;
import org.apache.tez.runtime.api.ProcessorContext;
import org.apache.tez.runtime.api.Writer;
import org.apache.tez.runtime.api.events.CompositeDataMovementEvent;
import org.apache.tez.runtime.api.impl.ExecutionContextImpl;
import org.apache.tez.runtime.api.impl.InputSpec;
import org.apache.tez.runtime.api.impl.OutputSpec;
import org.apache.tez.runtime.api.impl.TaskSpec;
import org.apache.tez.runtime.api.impl.TezEvent;
import org.apache.tez.runtime.api.impl.TezUmbilical;
import org.apache.tez.runtime.common.resources.ScalingAllocator;
import org.apache.tez.runtime.task.TaskRunner2Callable;
import org.junit.Test;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
public class TestLogicalIOProcessorRuntimeTask {
@Test(timeout = 5000)
public void testAutoStart() throws Exception {
  TezDAGID dagId = createTezDagId();
  TezVertexID vertexId = createTezVertexId(dagId);
  Map<String, ByteBuffer> serviceConsumerMetadata = new HashMap<String, ByteBuffer>();
  Multimap<String, String> startedInputsMap = HashMultimap.create();
  TezUmbilical umbilical = mock(TezUmbilical.class);
  TezConfiguration tezConf = new TezConfiguration();
  tezConf.set(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS,
      ScalingAllocator.class.getName());

  TezTaskAttemptID taId1 = createTaskAttemptID(vertexId, 1);
  TaskSpec task1 = createTaskSpec(taId1, "dag1",
      "vertex1", 30, TestProcessor.class.getName(),
      TestOutput.class.getName());

  TezTaskAttemptID taId2 = createTaskAttemptID(vertexId, 2);
  TaskSpec task2 = createTaskSpec(taId2, "dag2",
      "vertex1", 10, TestProcessor.class.getName(),
      TestOutput.class.getName());

  TezSharedExecutor sharedExecutor = new TezSharedExecutor(tezConf);
  // The executor is shared by both runtime tasks; shut it down in the outer
  // finally so it is released even if the first task fails. (The original
  // code had an unreachable shutdownNow() placed after fail() in a catch
  // block, and swallowed the exception's stack trace by calling fail().)
  try {
    LogicalIOProcessorRuntimeTask lio1 = new LogicalIOProcessorRuntimeTask(task1, 0, tezConf, null,
        umbilical, serviceConsumerMetadata, new HashMap<String, String>(), startedInputsMap, null,
        "", new ExecutionContextImpl("localhost"), Runtime.getRuntime().maxMemory(), true,
        new DefaultHadoopShim(), sharedExecutor);
    try {
      lio1.initialize();
      lio1.run();
      lio1.close();

      // Input should've been started, Output should not have been started.
      // NOTE(review): these counters are static and cumulative, so this test
      // assumes a fresh JVM/classloader per test class run.
      assertEquals(1, TestProcessor.runCount);
      assertEquals(1, TestInput.startCount);
      assertEquals(0, TestOutput.startCount);
      // test that invocations of progress are counted correctly
      assertEquals(true, lio1.getAndClearProgressNotification());
      assertEquals(false, lio1.getAndClearProgressNotification()); // cleared after getting
      assertEquals(30, TestInput.vertexParallelism);
      assertEquals(0, TestOutput.vertexParallelism);
      assertEquals(30, lio1.getProcessorContext().getVertexParallelism());
      assertEquals(30, lio1.getInputContexts().iterator().next().getVertexParallelism());
      assertEquals(30, lio1.getOutputContexts().iterator().next().getVertexParallelism());
    } finally {
      cleanupAndTest(lio1);
    }

    // local mode
    tezConf.setBoolean(TezConfiguration.TEZ_LOCAL_MODE, true);
    LogicalIOProcessorRuntimeTask lio2 = new LogicalIOProcessorRuntimeTask(task2, 0, tezConf, null,
        umbilical, serviceConsumerMetadata, new HashMap<String, String>(), startedInputsMap, null,
        "", new ExecutionContextImpl("localhost"), Runtime.getRuntime().maxMemory(), true,
        new DefaultHadoopShim(), sharedExecutor);
    try {
      lio2.initialize();
      lio2.run();
      lio2.close();

      // Input should not have been started again (startedInputsMap already
      // records it), Output should not have been started.
      assertEquals(2, TestProcessor.runCount);
      assertEquals(1, TestInput.startCount);
      assertEquals(0, TestOutput.startCount);
      assertEquals(30, TestInput.vertexParallelism);
      assertEquals(0, TestOutput.vertexParallelism);
      // Check if parallelism is available in processor / input / output contexts
      assertEquals(10, lio2.getProcessorContext().getVertexParallelism());
      assertEquals(10, lio2.getInputContexts().iterator().next().getVertexParallelism());
      assertEquals(10, lio2.getOutputContexts().iterator().next().getVertexParallelism());
    } finally {
      cleanupAndTest(lio2);
    }
  } finally {
    sharedExecutor.shutdownNow();
  }
}
@Test
public void testEventsCantBeSentInCleanup() throws Exception {
  TezDAGID dagId = createTezDagId();
  TezVertexID vertexId = createTezVertexId(dagId);
  Map<String, ByteBuffer> serviceConsumerMetadata = new HashMap<>();
  Multimap<String, String> startedInputsMap = HashMultimap.create();
  TezUmbilical umbilical = mock(TezUmbilical.class);
  TezConfiguration tezConf = new TezConfiguration();
  tezConf.set(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS,
      ScalingAllocator.class.getName());

  TezTaskAttemptID taId1 = createTaskAttemptID(vertexId, 1);
  TaskSpec task1 = createTaskSpec(taId1, "dag1", "vertex1", 30,
      RunExceptionProcessor.class.getName(),
      TestOutputWithEvents.class.getName());

  TezSharedExecutor sharedExecutor = new TezSharedExecutor(tezConf);
  // Shut the executor down even if the runner throws; the original code
  // leaked it.
  try {
    // CleanupLogicalIOProcessorRuntimeTask tries to send events from its
    // cleanup() override; those must never reach the AM.
    LogicalIOProcessorRuntimeTask lio =
        new CleanupLogicalIOProcessorRuntimeTask(task1, 0, tezConf, null,
            umbilical, serviceConsumerMetadata, new HashMap<String, String>(),
            startedInputsMap, null, "", new ExecutionContextImpl("localhost"),
            Runtime.getRuntime().maxMemory(), true, new DefaultHadoopShim(),
            sharedExecutor);

    TaskRunner2Callable runner =
        new TaskRunner2Callable(lio, UserGroupInformation.getCurrentUser(), umbilical);
    runner.call();

    // We verify that no events were sent
    verify(umbilical, only()).addEvents(Collections.<TezEvent> emptyList());
  } finally {
    sharedExecutor.shutdownNow();
  }
}
/**
 * We should expect no events being sent to the AM if an
 * exception happens in the close method of the processor.
 */
@Test
@SuppressWarnings("unchecked")
public void testExceptionHappensInClose() throws Exception {
TezDAGID dagId = createTezDagId();
TezVertexID vertexId = createTezVertexId(dagId);
Map<String, ByteBuffer> serviceConsumerMetadata = new HashMap<>();
Multimap<String, String> startedInputsMap = HashMultimap.create();
TezUmbilical umbilical = mock(TezUmbilical.class);
TezConfiguration tezConf = new TezConfiguration();
tezConf.set(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS,
ScalingAllocator.class.getName());
TezTaskAttemptID taId1 = createTaskAttemptID(vertexId, 1);
// CloseExceptionProcessor throws from close(); TestOutputWithEvents would
// emit a CompositeDataMovementEvent from its close() if the runtime let
// output close-events through after the processor failure.
TaskSpec task1 = createTaskSpec(taId1, "dag1", "vertex1", 30,
CloseExceptionProcessor.class.getName(),
TestOutputWithEvents.class.getName());
TezSharedExecutor sharedExecutor = new TezSharedExecutor(tezConf);
LogicalIOProcessorRuntimeTask lio1 = new LogicalIOProcessorRuntimeTask(task1, 0, tezConf, null,
umbilical, serviceConsumerMetadata, new HashMap<String, String>(), startedInputsMap, null,
"", new ExecutionContextImpl("localhost"), Runtime.getRuntime().maxMemory(), true,
new DefaultHadoopShim(), sharedExecutor);
try {
lio1.initialize();
lio1.run();
try {
lio1.close();
fail("RuntimeException should have been thrown");
} catch (RuntimeException e) {
// No events should be sent through the umbilical protocol
verify(umbilical, never()).addEvents(anyList());
}
} finally {
sharedExecutor.shutdownNow();
cleanupAndTest(lio1);
}
}
/**
 * Invokes {@link LogicalIOProcessorRuntimeTask#cleanup()} on the given task
 * and asserts that all per-task state (payloads, registries, maps, the event
 * router thread) has been released. In local mode the input/output specs are
 * retained (size 1); otherwise they are cleared.
 *
 * @param lio the task to clean up; its contexts are captured before cleanup
 *            so the post-cleanup null checks can run against them
 */
private void cleanupAndTest(LogicalIOProcessorRuntimeTask lio) throws InterruptedException {
  // Capture the contexts before cleanup() nulls out the task's references.
  ProcessorContext procContext = lio.getProcessorContext();
  List<InputContext> inputContexts = new LinkedList<>(lio.getInputContexts());
  List<OutputContext> outputContexts = new LinkedList<>(lio.getOutputContexts());

  lio.cleanup();

  // assertNull instead of assertTrue(x == null): same check, better message.
  assertNull(procContext.getUserPayload());
  assertNull(procContext.getObjectRegistry());
  for (InputContext inputContext : inputContexts) {
    assertNull(inputContext.getUserPayload());
    assertNull(inputContext.getObjectRegistry());
  }
  for (OutputContext outputContext : outputContexts) {
    assertNull(outputContext.getUserPayload());
    assertNull(outputContext.getObjectRegistry());
  }

  boolean localMode = lio.tezConf.getBoolean(TezConfiguration.TEZ_LOCAL_MODE,
      TezConfiguration.TEZ_LOCAL_MODE_DEFAULT);
  if (localMode) {
    // Local mode keeps the specs around for potential reuse.
    assertEquals(1, lio.inputSpecs.size());
    assertEquals(1, lio.outputSpecs.size());
  } else {
    assertEquals(0, lio.inputSpecs.size());
    assertEquals(0, lio.outputSpecs.size());
  }
  assertTrue(lio.groupInputSpecs == null || lio.groupInputSpecs.size() == 0);

  assertEquals(0, lio.inputsMap.size());
  assertEquals(0, lio.inputContextMap.size());
  assertEquals(0, lio.outputsMap.size());
  assertEquals(0, lio.outputContextMap.size());
  assertNull(lio.groupInputsMap);
  assertNull(lio.processor);
  assertNull(lio.processorContext);
  assertEquals(0, lio.runInputMap.size());
  assertEquals(0, lio.runOutputMap.size());
  assertEquals(0, lio.eventsToBeProcessed.size());
  assertNull(lio.eventRouterThread);
}
/**
 * Builds a TaskSpec with a single TestInput edge and a single output of the
 * given class.
 */
private TaskSpec createTaskSpec(TezTaskAttemptID taskAttemptID,
    String dagName, String vertexName, int parallelism,
    String processorClassname, String outputClassName) {
  return new TaskSpec(taskAttemptID, dagName, vertexName, parallelism,
      createProcessorDescriptor(processorClassname),
      createInputSpecList(), createOutputSpecList(outputClassName),
      null, null);
}
/** Single-element input spec list: one physical TestInput on edge "inedge". */
private List<InputSpec> createInputSpecList() {
  InputDescriptor descriptor = InputDescriptor.create(TestInput.class.getName());
  return Lists.newArrayList(new InputSpec("inedge", descriptor, 1));
}
/** Single-element output spec list: one physical output on edge "outedge". */
private List<OutputSpec> createOutputSpecList(String outputClassName) {
  OutputDescriptor descriptor = OutputDescriptor.create(outputClassName);
  return Lists.newArrayList(new OutputSpec("outedge", descriptor, 1));
}
/** Wraps the class name in a ProcessorDescriptor. */
private ProcessorDescriptor createProcessorDescriptor(String className) {
  return ProcessorDescriptor.create(className);
}
/**
 * Builds a task attempt id under the given vertex.
 * NOTE(review): taskIndex is reused as the attempt number as well — the
 * tests only need distinct ids, so this appears intentional.
 */
private TezTaskAttemptID createTaskAttemptID(TezVertexID vertexId, int taskIndex) {
  return TezTaskAttemptID.getInstance(
      TezTaskID.getInstance(vertexId, taskIndex), taskIndex);
}
/** All tests share a single vertex (index 1) within the DAG. */
private TezVertexID createTezVertexId(TezDAGID dagId) {
  final int vertexIndex = 1;
  return TezVertexID.getInstance(dagId, vertexIndex);
}
/** Fixed DAG id (cluster timestamp "2000", app id 100, dag index 1). */
private TezDAGID createTezDagId() {
  final String clusterTimestamp = "2000";
  return TezDAGID.getInstance(clusterTimestamp, 100, 1);
}
/**
 * Runtime-task variant whose {@link #cleanup()} attempts to send events
 * through the output contexts. Used to verify that events emitted during
 * cleanup are never forwarded to the AM.
 */
private static class CleanupLogicalIOProcessorRuntimeTask
    extends LogicalIOProcessorRuntimeTask {
  // Parameter renamed from "ExecutionContext" to "executionContext":
  // the original name violated lowerCamelCase and shadowed the imported
  // type, forcing a fully-qualified type in the signature.
  CleanupLogicalIOProcessorRuntimeTask(TaskSpec taskSpec,
      int appAttemptNumber, Configuration tezConf, String[] localDirs,
      TezUmbilical tezUmbilical,
      Map<String, ByteBuffer> serviceConsumerMetadata,
      Map<String, String> envMap, Multimap<String, String> startedInputsMap,
      ObjectRegistry objectRegistry, String pid,
      ExecutionContext executionContext,
      long memAvailable, boolean updateSysCounters, HadoopShim hadoopShim,
      TezExecutors sharedExecutor) throws IOException {
    super(taskSpec, appAttemptNumber, tezConf, localDirs, tezUmbilical,
        serviceConsumerMetadata, envMap, startedInputsMap, objectRegistry,
        pid, executionContext, memAvailable, updateSysCounters, hadoopShim,
        sharedExecutor);
  }

  @Override
  public void cleanup() throws InterruptedException {
    // Deliberately try to send an event from cleanup; the runtime is
    // expected to drop it rather than pass it to the umbilical.
    getOutputContexts().forEach(context
        -> context.sendEvents(Arrays.asList(
            CompositeDataMovementEvent.create(0, 0, null))));
  }
}
/** Minimal processor stub that counts invocations of run(). */
public static class TestProcessor extends AbstractLogicalIOProcessor {
  // Cumulative across all instances; tests assert on this directly.
  public static volatile int runCount = 0;

  public TestProcessor(ProcessorContext context) {
    super(context);
  }

  @Override
  public void initialize() throws Exception {
    // Nothing to set up.
  }

  @Override
  public void run(Map<String, LogicalInput> inputs, Map<String, LogicalOutput> outputs)
      throws Exception {
    runCount += 1;
    getContext().notifyProgress();
  }

  @Override
  public void handleEvents(List<Event> processorEvents) {
    // Events are ignored by this stub.
  }

  @Override
  public void close() throws Exception {
    // Nothing to release.
  }
}
/**
 * Processor whose run() always throws. close() also throws, because close()
 * must not be invoked once run() has failed.
 */
public static class RunExceptionProcessor
    extends TestProcessor {
  public RunExceptionProcessor(ProcessorContext context) {
    super(context);
  }

  @Override // was missing; every other override in this file is annotated
  public void run(Map<String, LogicalInput> inputs,
      Map<String, LogicalOutput> outputs)
      throws Exception {
    // This exception is thrown on purpose because we want to test this
    throw new RuntimeException();
  }

  @Override
  public void close() throws Exception {
    // This exception is thrown because this method shouldn't be called
    // if run has thrown an exception.
    throw new RuntimeException();
  }
}
/**
 * Processor that runs normally but fails in close(); used to verify that no
 * events reach the AM when close() throws.
 */
public static class CloseExceptionProcessor extends TestProcessor {
  public CloseExceptionProcessor(ProcessorContext context) {
    super(context);
  }

  @Override
  public void close() throws Exception {
    // Fail deliberately during close.
    throw new RuntimeException();
  }
}
/**
 * Input stub that records how many times start() ran and the vertex
 * parallelism observed at start time.
 */
public static class TestInput extends AbstractLogicalInput {
  // Cumulative across instances; asserted on directly by the tests.
  public static volatile int startCount = 0;
  // Parallelism reported by the context during the last start().
  public static volatile int vertexParallelism;

  public TestInput(InputContext inputContext, int numPhysicalInputs) {
    super(inputContext, numPhysicalInputs);
  }

  @Override
  public List<Event> initialize() throws Exception {
    getContext().requestInitialMemory(0, null);
    getContext().inputIsReady();
    return null;
  }

  @Override
  public void start() throws Exception {
    startCount += 1;
    vertexParallelism = getContext().getVertexParallelism();
    getContext().notifyProgress();
  }

  @Override
  public Reader getReader() throws Exception {
    return null;
  }

  @Override
  public void handleEvents(List<Event> inputEvents) throws Exception {
    // Events are ignored by this stub.
  }

  @Override
  public List<Event> close() throws Exception {
    return null;
  }
}
/**
 * Output stub that records how many times start() ran and the vertex
 * parallelism observed at start time. Tests assert start() is never invoked
 * (outputs are not auto-started).
 */
public static class TestOutput extends AbstractLogicalOutput {
  // Cumulative across instances; asserted on directly by the tests.
  public static volatile int startCount = 0;
  // Parallelism reported by the context during the last start().
  public static volatile int vertexParallelism;

  public TestOutput(OutputContext outputContext, int numPhysicalOutputs) {
    super(outputContext, numPhysicalOutputs);
  }

  @Override
  public List<Event> initialize() throws Exception {
    getContext().requestInitialMemory(0, null);
    return null;
  }

  @Override
  public void start() throws Exception {
    // Removed leftover debug print to stderr ("Out started").
    startCount++;
    vertexParallelism = getContext().getVertexParallelism();
    getContext().notifyProgress();
  }

  @Override
  public Writer getWriter() throws Exception {
    return null;
  }

  @Override
  public void handleEvents(List<Event> outputEvents) {
    // Events are ignored by this stub.
  }

  @Override
  public List<Event> close() throws Exception {
    return null;
  }
}
/**
 * Output stub whose close() emits a CompositeDataMovementEvent; used to check
 * that such events are suppressed when the task fails.
 *
 * The shadowing static fields startCount/vertexParallelism were removed:
 * TestOutput.start() increments the parent's fields, so the copies here were
 * never written and only hid the real counters. References through this class
 * still resolve to the inherited fields, so the change is backward compatible.
 */
public static class TestOutputWithEvents extends TestOutput {
  public TestOutputWithEvents(OutputContext outputContext, int numPhysicalOutputs) {
    super(outputContext, numPhysicalOutputs);
  }

  @Override
  public List<Event> close() throws Exception {
    return Arrays.asList(
        CompositeDataMovementEvent.create(0, 0, null));
  }
}
}