/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.runtime;
| |
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import org.apache.tez.common.TezSharedExecutor;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.ProcessorDescriptor;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.hadoop.shim.DefaultHadoopShim;
import org.apache.tez.runtime.api.AbstractLogicalIOProcessor;
import org.apache.tez.runtime.api.AbstractLogicalInput;
import org.apache.tez.runtime.api.AbstractLogicalOutput;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.LogicalInput;
import org.apache.tez.runtime.api.LogicalOutput;
import org.apache.tez.runtime.api.OutputContext;
import org.apache.tez.runtime.api.ProcessorContext;
import org.apache.tez.runtime.api.Reader;
import org.apache.tez.runtime.api.Writer;
import org.apache.tez.runtime.api.events.CompositeDataMovementEvent;
import org.apache.tez.runtime.api.impl.ExecutionContextImpl;
import org.apache.tez.runtime.api.impl.InputSpec;
import org.apache.tez.runtime.api.impl.OutputSpec;
import org.apache.tez.runtime.api.impl.TaskSpec;
import org.apache.tez.runtime.api.impl.TezUmbilical;
import org.apache.tez.runtime.common.resources.ScalingAllocator;
import org.junit.Test;
import org.mockito.Mockito;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
| |
| public class TestLogicalIOProcessorRuntimeTask { |
| |
| @Test(timeout = 5000) |
| public void testAutoStart() throws Exception { |
| TezDAGID dagId = createTezDagId(); |
| TezVertexID vertexId = createTezVertexId(dagId); |
| Map<String, ByteBuffer> serviceConsumerMetadata = new HashMap<String, ByteBuffer>(); |
| Multimap<String, String> startedInputsMap = HashMultimap.create(); |
| TezUmbilical umbilical = mock(TezUmbilical.class); |
| TezConfiguration tezConf = new TezConfiguration(); |
| tezConf.set(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS, |
| ScalingAllocator.class.getName()); |
| |
| TezTaskAttemptID taId1 = createTaskAttemptID(vertexId, 1); |
| TaskSpec task1 = createTaskSpec(taId1, "dag1", |
| "vertex1", 30, TestProcessor.class.getName(), |
| TestOutput.class.getName()); |
| |
| TezTaskAttemptID taId2 = createTaskAttemptID(vertexId, 2); |
| TaskSpec task2 = createTaskSpec(taId2, "dag2", |
| "vertex1", 10, TestProcessor.class.getName(), |
| TestOutput.class.getName()); |
| |
| TezSharedExecutor sharedExecutor = new TezSharedExecutor(tezConf); |
| LogicalIOProcessorRuntimeTask lio1 = new LogicalIOProcessorRuntimeTask(task1, 0, tezConf, null, |
| umbilical, serviceConsumerMetadata, new HashMap<String, String>(), startedInputsMap, null, |
| "", new ExecutionContextImpl("localhost"), Runtime.getRuntime().maxMemory(), true, |
| new DefaultHadoopShim(), sharedExecutor); |
| |
| try { |
| lio1.initialize(); |
| lio1.run(); |
| lio1.close(); |
| |
| // Input should've been started, Output should not have been started |
| assertEquals(1, TestProcessor.runCount); |
| assertEquals(1, TestInput.startCount); |
| assertEquals(0, TestOutput.startCount); |
| // test that invocations of progress are counted correctly |
| assertEquals(true, lio1.getAndClearProgressNotification()); |
| assertEquals(false, lio1.getAndClearProgressNotification()); // cleared after getting |
| assertEquals(30, TestInput.vertexParallelism); |
| assertEquals(0, TestOutput.vertexParallelism); |
| assertEquals(30, lio1.getProcessorContext().getVertexParallelism()); |
| assertEquals(30, lio1.getInputContexts().iterator().next().getVertexParallelism()); |
| assertEquals(30, lio1.getOutputContexts().iterator().next().getVertexParallelism()); |
| } catch(Exception e) { |
| fail(); |
| sharedExecutor.shutdownNow(); |
| } finally { |
| cleanupAndTest(lio1); |
| } |
| |
| // local mode |
| tezConf.setBoolean(TezConfiguration.TEZ_LOCAL_MODE, true); |
| LogicalIOProcessorRuntimeTask lio2 = new LogicalIOProcessorRuntimeTask(task2, 0, tezConf, null, |
| umbilical, serviceConsumerMetadata, new HashMap<String, String>(), startedInputsMap, null, |
| "", new ExecutionContextImpl("localhost"), Runtime.getRuntime().maxMemory(), true, |
| new DefaultHadoopShim(), sharedExecutor); |
| try { |
| lio2.initialize(); |
| lio2.run(); |
| lio2.close(); |
| |
| // Input should not have been started again, Output should not have been started |
| assertEquals(2, TestProcessor.runCount); |
| assertEquals(1, TestInput.startCount); |
| assertEquals(0, TestOutput.startCount); |
| assertEquals(30, TestInput.vertexParallelism); |
| assertEquals(0, TestOutput.vertexParallelism); |
| //Check if parallelism is available in processor/ i/p / o/p contexts |
| assertEquals(10, lio2.getProcessorContext().getVertexParallelism()); |
| assertEquals(10, lio2.getInputContexts().iterator().next().getVertexParallelism()); |
| assertEquals(10, lio2.getOutputContexts().iterator().next().getVertexParallelism()); |
| } catch(Exception e) { |
| fail(); |
| } finally { |
| cleanupAndTest(lio2); |
| sharedExecutor.shutdownNow(); |
| } |
| |
| } |
| |
| /** |
| * We should expect no events being sent to the AM if an |
| * exception happens in the close method of the processor |
| */ |
| @Test |
| @SuppressWarnings("unchecked") |
| public void testExceptionHappensInClose() throws Exception { |
| TezDAGID dagId = createTezDagId(); |
| TezVertexID vertexId = createTezVertexId(dagId); |
| Map<String, ByteBuffer> serviceConsumerMetadata = new HashMap<>(); |
| Multimap<String, String> startedInputsMap = HashMultimap.create(); |
| TezUmbilical umbilical = mock(TezUmbilical.class); |
| TezConfiguration tezConf = new TezConfiguration(); |
| tezConf.set(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS, |
| ScalingAllocator.class.getName()); |
| |
| TezTaskAttemptID taId1 = createTaskAttemptID(vertexId, 1); |
| TaskSpec task1 = createTaskSpec(taId1, "dag1", "vertex1", 30, |
| FaultyTestProcessor.class.getName(), |
| TestOutputWithEvents.class.getName()); |
| |
| TezSharedExecutor sharedExecutor = new TezSharedExecutor(tezConf); |
| LogicalIOProcessorRuntimeTask lio1 = new LogicalIOProcessorRuntimeTask(task1, 0, tezConf, null, |
| umbilical, serviceConsumerMetadata, new HashMap<String, String>(), startedInputsMap, null, |
| "", new ExecutionContextImpl("localhost"), Runtime.getRuntime().maxMemory(), true, |
| new DefaultHadoopShim(), sharedExecutor); |
| |
| try { |
| lio1.initialize(); |
| lio1.run(); |
| |
| try { |
| lio1.close(); |
| fail("RuntimeException should have been thrown"); |
| } catch (RuntimeException e) { |
| // No events should be sent thorught the umbilical protocol |
| Mockito.verify(umbilical, Mockito.never()).addEvents(Mockito.anyList()); |
| } |
| } finally { |
| sharedExecutor.shutdownNow(); |
| cleanupAndTest(lio1); |
| } |
| } |
| |
| private void cleanupAndTest(LogicalIOProcessorRuntimeTask lio) throws InterruptedException { |
| |
| ProcessorContext procContext = lio.getProcessorContext(); |
| List<InputContext> inputContexts = new LinkedList<InputContext>(); |
| inputContexts.addAll(lio.getInputContexts()); |
| List<OutputContext> outputContexts = new LinkedList<OutputContext>(); |
| outputContexts.addAll(lio.getOutputContexts()); |
| |
| lio.cleanup(); |
| |
| assertTrue(procContext.getUserPayload() == null); |
| assertTrue(procContext.getObjectRegistry() == null); |
| |
| for (InputContext inputContext : inputContexts) { |
| assertTrue(inputContext.getUserPayload() == null); |
| assertTrue(inputContext.getObjectRegistry() == null); |
| } |
| |
| for (OutputContext outputContext : outputContexts) { |
| assertTrue(outputContext.getUserPayload() == null); |
| assertTrue(outputContext.getObjectRegistry() == null); |
| } |
| boolean localMode = lio.tezConf.getBoolean(TezConfiguration.TEZ_LOCAL_MODE, |
| TezConfiguration.TEZ_LOCAL_MODE_DEFAULT); |
| if (localMode) { |
| assertEquals(1, lio.inputSpecs.size()); |
| assertEquals(1, lio.outputSpecs.size()); |
| assertTrue(lio.groupInputSpecs == null || lio.groupInputSpecs.size() == 0); |
| } else { |
| assertEquals(0, lio.inputSpecs.size()); |
| assertEquals(0, lio.outputSpecs.size()); |
| assertTrue(lio.groupInputSpecs == null || lio.groupInputSpecs.size() == 0); |
| } |
| |
| assertEquals(0, lio.inputsMap.size()); |
| assertEquals(0, lio.inputContextMap.size()); |
| assertEquals(0, lio.outputsMap.size()); |
| assertEquals(0, lio.outputContextMap.size()); |
| assertNull(lio.groupInputsMap); |
| assertNull(lio.processor); |
| assertNull(lio.processorContext); |
| assertEquals(0, lio.runInputMap.size()); |
| assertEquals(0, lio.runOutputMap.size()); |
| assertEquals(0, lio.eventsToBeProcessed.size()); |
| assertNull(lio.eventRouterThread); |
| } |
| |
| private TaskSpec createTaskSpec(TezTaskAttemptID taskAttemptID, |
| String dagName, String vertexName, int parallelism, |
| String processorClassname, String outputClassName) { |
| ProcessorDescriptor processorDesc = createProcessorDescriptor(processorClassname); |
| TaskSpec taskSpec = new TaskSpec(taskAttemptID, |
| dagName, vertexName, parallelism, processorDesc, |
| createInputSpecList(), createOutputSpecList(outputClassName), null, null); |
| return taskSpec; |
| } |
| |
| private List<InputSpec> createInputSpecList() { |
| InputDescriptor inputDesc = InputDescriptor.create(TestInput.class.getName()); |
| InputSpec inputSpec = new InputSpec("inedge", inputDesc, 1); |
| return Lists.newArrayList(inputSpec); |
| } |
| |
| private List<OutputSpec> createOutputSpecList(String outputClassName) { |
| OutputDescriptor outputtDesc = OutputDescriptor.create(outputClassName); |
| OutputSpec outputSpec = new OutputSpec("outedge", outputtDesc, 1); |
| return Lists.newArrayList(outputSpec); |
| } |
| |
| private ProcessorDescriptor createProcessorDescriptor(String className) { |
| ProcessorDescriptor desc = ProcessorDescriptor.create(className); |
| return desc; |
| } |
| |
| private TezTaskAttemptID createTaskAttemptID(TezVertexID vertexId, int taskIndex) { |
| TezTaskID taskId = TezTaskID.getInstance(vertexId, taskIndex); |
| TezTaskAttemptID taskAttemptId = TezTaskAttemptID.getInstance(taskId, taskIndex); |
| return taskAttemptId; |
| } |
| |
| private TezVertexID createTezVertexId(TezDAGID dagId) { |
| return TezVertexID.getInstance(dagId, 1); |
| } |
| |
| private TezDAGID createTezDagId() { |
| return TezDAGID.getInstance("2000", 100, 1); |
| } |
| |
| public static class TestProcessor extends AbstractLogicalIOProcessor { |
| |
| public static volatile int runCount = 0; |
| |
| public TestProcessor(ProcessorContext context) { |
| super(context); |
| } |
| |
| @Override |
| public void initialize() throws Exception { |
| } |
| |
| @Override |
| public void run(Map<String, LogicalInput> inputs, Map<String, LogicalOutput> outputs) |
| throws Exception { |
| runCount++; |
| getContext().notifyProgress(); |
| } |
| |
| @Override |
| public void handleEvents(List<Event> processorEvents) { |
| } |
| |
| @Override |
| public void close() throws Exception { |
| } |
| |
| } |
| |
| public static class FaultyTestProcessor extends TestProcessor { |
| public FaultyTestProcessor(ProcessorContext context) { |
| super(context); |
| } |
| |
| @Override |
| public void close() throws Exception { |
| throw new RuntimeException(); |
| } |
| |
| } |
| |
| public static class TestInput extends AbstractLogicalInput { |
| |
| public static volatile int startCount = 0; |
| public static volatile int vertexParallelism; |
| |
| public TestInput(InputContext inputContext, int numPhysicalInputs) { |
| super(inputContext, numPhysicalInputs); |
| } |
| |
| @Override |
| public List<Event> initialize() throws Exception { |
| getContext().requestInitialMemory(0, null); |
| getContext().inputIsReady(); |
| return null; |
| } |
| |
| @Override |
| public void start() throws Exception { |
| startCount++; |
| vertexParallelism = getContext().getVertexParallelism(); |
| getContext().notifyProgress(); |
| } |
| |
| @Override |
| public Reader getReader() throws Exception { |
| return null; |
| } |
| |
| @Override |
| public void handleEvents(List<Event> inputEvents) throws Exception { |
| } |
| |
| @Override |
| public List<Event> close() throws Exception { |
| return null; |
| } |
| |
| } |
| |
| public static class TestOutput extends AbstractLogicalOutput { |
| |
| public static volatile int startCount = 0; |
| public static volatile int vertexParallelism; |
| |
| public TestOutput(OutputContext outputContext, int numPhysicalOutputs) { |
| super(outputContext, numPhysicalOutputs); |
| } |
| |
| |
| @Override |
| public List<Event> initialize() throws Exception { |
| getContext().requestInitialMemory(0, null); |
| return null; |
| } |
| |
| @Override |
| public void start() throws Exception { |
| System.err.println("Out started"); |
| startCount++; |
| vertexParallelism = getContext().getVertexParallelism(); |
| getContext().notifyProgress(); |
| } |
| |
| @Override |
| public Writer getWriter() throws Exception { |
| return null; |
| } |
| |
| @Override |
| public void handleEvents(List<Event> outputEvents) { |
| } |
| |
| @Override |
| public List<Event> close() throws Exception { |
| return null; |
| } |
| } |
| |
| public static class TestOutputWithEvents extends TestOutput { |
| |
| public static volatile int startCount = 0; |
| public static volatile int vertexParallelism; |
| |
| public TestOutputWithEvents(OutputContext outputContext, int numPhysicalOutputs) { |
| super(outputContext, numPhysicalOutputs); |
| } |
| |
| @Override |
| public List<Event> close() throws Exception { |
| return Arrays.asList( |
| CompositeDataMovementEvent.create(0, |
| 0, null)); |
| } |
| } |
| } |