| /* |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.tez.dag.app.dag; |
| |
| import static org.junit.Assert.assertEquals; |
| import static org.mockito.Matchers.any; |
| import static org.mockito.Mockito.RETURNS_DEEP_STUBS; |
| import static org.mockito.Mockito.doReturn; |
| import static org.mockito.Mockito.mock; |
| import static org.mockito.Mockito.never; |
| import static org.mockito.Mockito.reset; |
| import static org.mockito.Mockito.times; |
| import static org.mockito.Mockito.verify; |
| import static org.mockito.Mockito.when; |
| |
| import java.io.IOException; |
| import java.util.Collections; |
| import java.util.List; |
| |
| import com.google.common.collect.Lists; |
| import org.apache.hadoop.security.UserGroupInformation; |
| import org.apache.hadoop.yarn.api.records.ApplicationId; |
| import org.apache.hadoop.yarn.event.EventHandler; |
| import org.apache.tez.dag.api.InputDescriptor; |
| import org.apache.tez.dag.api.InputInitializerDescriptor; |
| import org.apache.tez.dag.api.RootInputLeafOutput; |
| import org.apache.tez.dag.api.TezException; |
| import org.apache.tez.dag.api.oldrecords.TaskState; |
| import org.apache.tez.dag.app.AppContext; |
| import org.apache.tez.dag.records.TezDAGID; |
| import org.apache.tez.dag.records.TezTaskAttemptID; |
| import org.apache.tez.dag.records.TezTaskID; |
| import org.apache.tez.dag.records.TezVertexID; |
| import org.apache.tez.runtime.api.Event; |
| import org.apache.tez.runtime.api.InputInitializer; |
| import org.apache.tez.runtime.api.InputInitializerContext; |
| import org.apache.tez.runtime.api.events.InputInitializerEvent; |
| import org.apache.tez.runtime.api.impl.EventMetaData; |
| import org.apache.tez.runtime.api.impl.TezEvent; |
| import org.junit.Test; |
| import org.mockito.ArgumentCaptor; |
| |
| public class TestRootInputInitializerManager { |
| |
| // Simple testing. No events if task doesn't succeed. |
| // Also exercises path where two attempts are reported as successful via the stateChangeNotifier. |
| // Primarily a failure scenario, when a Task moves back to running from success |
| // Order event1, success1, event2, success2 |
| @SuppressWarnings("unchecked") |
| @Test(timeout = 5000) |
| public void testEventBeforeSuccess() throws Exception { |
| InputDescriptor id = mock(InputDescriptor.class); |
| InputInitializerDescriptor iid = mock(InputInitializerDescriptor.class); |
| RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor> rootInput = |
| new RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor>("InputName", id, iid); |
| |
| InputInitializer initializer = mock(InputInitializer.class); |
| InputInitializerContext initializerContext = mock(InputInitializerContext.class); |
| Vertex vertex = mock(Vertex.class); |
| StateChangeNotifier stateChangeNotifier = mock(StateChangeNotifier.class); |
| AppContext appContext = mock(AppContext.class, RETURNS_DEEP_STUBS); |
| |
| RootInputInitializerManager.InitializerWrapper initializerWrapper = |
| new RootInputInitializerManager.InitializerWrapper(rootInput, initializer, |
| initializerContext, vertex, stateChangeNotifier, appContext); |
| |
| ApplicationId appId = ApplicationId.newInstance(1000, 1); |
| TezDAGID dagId = TezDAGID.getInstance(appId, 1); |
| TezVertexID srcVertexId = TezVertexID.getInstance(dagId, 2); |
| TezTaskID srcTaskId1 = TezTaskID.getInstance(srcVertexId, 3); |
| Vertex srcVertex = mock(Vertex.class); |
| Task srcTask1 = mock(Task.class); |
| doReturn(TaskState.RUNNING).when(srcTask1).getState(); |
| doReturn(srcTask1).when(srcVertex).getTask(srcTaskId1.getId()); |
| when(appContext.getCurrentDAG().getVertex(any(String.class))).thenReturn(srcVertex); |
| |
| String srcVertexName = "srcVertexName"; |
| List<TezEvent> eventList = Lists.newLinkedList(); |
| |
| |
| // First Attempt send event |
| TezTaskAttemptID srcTaskAttemptId11 = TezTaskAttemptID.getInstance(srcTaskId1, 1); |
| EventMetaData sourceInfo11 = |
| new EventMetaData(EventMetaData.EventProducerConsumerType.PROCESSOR, srcVertexName, null, |
| srcTaskAttemptId11); |
| InputInitializerEvent e1 = InputInitializerEvent.create("fakeVertex", "fakeInput", null); |
| TezEvent te1 = new TezEvent(e1, sourceInfo11); |
| eventList.add(te1); |
| initializerWrapper.handleInputInitializerEvents(eventList); |
| |
| verify(initializer, never()).handleInputInitializerEvent(any(List.class)); |
| eventList.clear(); |
| |
| // First attempt, Task success notification |
| initializerWrapper.onTaskSucceeded(srcVertexName, srcTaskId1, srcTaskAttemptId11.getId()); |
| ArgumentCaptor<List> argumentCaptor = ArgumentCaptor.forClass(List.class); |
| verify(initializer, times(1)).handleInputInitializerEvent(argumentCaptor.capture()); |
| List<InputInitializerEvent> invokedEvents = argumentCaptor.getValue(); |
| assertEquals(1, invokedEvents.size()); |
| |
| reset(initializer); |
| |
| // 2nd attempt send event |
| TezTaskAttemptID srcTaskAttemptId12 = TezTaskAttemptID.getInstance(srcTaskId1, 2); |
| EventMetaData sourceInfo12 = |
| new EventMetaData(EventMetaData.EventProducerConsumerType.PROCESSOR, srcVertexName, null, |
| srcTaskAttemptId12); |
| InputInitializerEvent e2 = InputInitializerEvent.create("fakeVertex", "fakeInput", null); |
| TezEvent te2 = new TezEvent(e2, sourceInfo12); |
| eventList.add(te2); |
| initializerWrapper.handleInputInitializerEvents(eventList); |
| |
| verify(initializer, never()).handleInputInitializerEvent(any(List.class)); |
| eventList.clear(); |
| reset(initializer); |
| |
| // 2nd attempt succeeded |
| initializerWrapper.onTaskSucceeded(srcVertexName, srcTaskId1, srcTaskAttemptId12.getId()); |
| verify(initializer, never()).handleInputInitializerEvent(argumentCaptor.capture()); |
| } |
| |
| // Order event1 success1, success2, event2 |
| // Primarily a failure scenario, when a Task moves back to running from success |
| @SuppressWarnings("unchecked") |
| @Test(timeout = 5000) |
| public void testSuccessBeforeEvent() throws Exception { |
| InputDescriptor id = mock(InputDescriptor.class); |
| InputInitializerDescriptor iid = mock(InputInitializerDescriptor.class); |
| RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor> rootInput = |
| new RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor>("InputName", id, iid); |
| |
| InputInitializer initializer = mock(InputInitializer.class); |
| InputInitializerContext initializerContext = mock(InputInitializerContext.class); |
| Vertex vertex = mock(Vertex.class); |
| StateChangeNotifier stateChangeNotifier = mock(StateChangeNotifier.class); |
| AppContext appContext = mock(AppContext.class, RETURNS_DEEP_STUBS); |
| |
| RootInputInitializerManager.InitializerWrapper initializerWrapper = |
| new RootInputInitializerManager.InitializerWrapper(rootInput, initializer, |
| initializerContext, vertex, stateChangeNotifier, appContext); |
| |
| ApplicationId appId = ApplicationId.newInstance(1000, 1); |
| TezDAGID dagId = TezDAGID.getInstance(appId, 1); |
| TezVertexID srcVertexId = TezVertexID.getInstance(dagId, 2); |
| TezTaskID srcTaskId1 = TezTaskID.getInstance(srcVertexId, 3); |
| Vertex srcVertex = mock(Vertex.class); |
| Task srcTask1 = mock(Task.class); |
| doReturn(TaskState.RUNNING).when(srcTask1).getState(); |
| doReturn(srcTask1).when(srcVertex).getTask(srcTaskId1.getId()); |
| when(appContext.getCurrentDAG().getVertex(any(String.class))).thenReturn(srcVertex); |
| |
| String srcVertexName = "srcVertexName"; |
| List<TezEvent> eventList = Lists.newLinkedList(); |
| |
| |
| // First Attempt send event |
| TezTaskAttemptID srcTaskAttemptId11 = TezTaskAttemptID.getInstance(srcTaskId1, 1); |
| EventMetaData sourceInfo11 = |
| new EventMetaData(EventMetaData.EventProducerConsumerType.PROCESSOR, srcVertexName, null, |
| srcTaskAttemptId11); |
| InputInitializerEvent e1 = InputInitializerEvent.create("fakeVertex", "fakeInput", null); |
| TezEvent te1 = new TezEvent(e1, sourceInfo11); |
| eventList.add(te1); |
| initializerWrapper.handleInputInitializerEvents(eventList); |
| |
| verify(initializer, never()).handleInputInitializerEvent(any(List.class)); |
| eventList.clear(); |
| |
| // First attempt, Task success notification |
| initializerWrapper.onTaskSucceeded(srcVertexName, srcTaskId1, srcTaskAttemptId11.getId()); |
| ArgumentCaptor<List> argumentCaptor = ArgumentCaptor.forClass(List.class); |
| verify(initializer, times(1)).handleInputInitializerEvent(argumentCaptor.capture()); |
| List<InputInitializerEvent> invokedEvents = argumentCaptor.getValue(); |
| assertEquals(1, invokedEvents.size()); |
| |
| reset(initializer); |
| |
| |
| TezTaskAttemptID srcTaskAttemptId12 = TezTaskAttemptID.getInstance(srcTaskId1, 2); |
| // 2nd attempt succeeded |
| initializerWrapper.onTaskSucceeded(srcVertexName, srcTaskId1, srcTaskAttemptId12.getId()); |
| verify(initializer, never()).handleInputInitializerEvent(any(List.class)); |
| |
| // 2nd attempt send event |
| EventMetaData sourceInfo12 = |
| new EventMetaData(EventMetaData.EventProducerConsumerType.PROCESSOR, srcVertexName, null, |
| srcTaskAttemptId12); |
| InputInitializerEvent e2 = InputInitializerEvent.create("fakeVertex", "fakeInput", null); |
| TezEvent te2 = new TezEvent(e2, sourceInfo12); |
| eventList.add(te2); |
| initializerWrapper.handleInputInitializerEvents(eventList); |
| |
| verify(initializer, never()).handleInputInitializerEvent(any(List.class)); |
| } |
| |
| |
| @Test (timeout = 5000) |
| public void testCorrectUgiUsage() throws TezException, InterruptedException { |
| Vertex vertex = mock(Vertex.class); |
| doReturn(mock(TezVertexID.class)).when(vertex).getVertexId(); |
| AppContext appContext = mock(AppContext.class); |
| doReturn(mock(EventHandler.class)).when(appContext).getEventHandler(); |
| UserGroupInformation dagUgi = UserGroupInformation.createRemoteUser("fakeuser"); |
| StateChangeNotifier stateChangeNotifier = mock(StateChangeNotifier.class); |
| RootInputInitializerManager rootInputInitializerManager = new RootInputInitializerManager(vertex, appContext, dagUgi, stateChangeNotifier); |
| |
| InputDescriptor id = mock(InputDescriptor.class); |
| InputInitializerDescriptor iid = InputInitializerDescriptor.create(InputInitializerForUgiTest.class.getName()); |
| RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor> rootInput = |
| new RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor>("InputName", id, iid); |
| rootInputInitializerManager.runInputInitializers(Collections.singletonList(rootInput)); |
| |
| InputInitializerForUgiTest.awaitInitialize(); |
| |
| assertEquals(dagUgi, InputInitializerForUgiTest.ctorUgi); |
| assertEquals(dagUgi, InputInitializerForUgiTest.initializeUgi); |
| } |
| |
| public static class InputInitializerForUgiTest extends InputInitializer { |
| |
| static volatile UserGroupInformation ctorUgi; |
| static volatile UserGroupInformation initializeUgi; |
| |
| static boolean initialized = false; |
| static final Object initializeSync = new Object(); |
| |
| public InputInitializerForUgiTest(InputInitializerContext initializerContext) { |
| super(initializerContext); |
| try { |
| ctorUgi = UserGroupInformation.getCurrentUser(); |
| } catch (IOException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| |
| @Override |
| public List<Event> initialize() throws Exception { |
| initializeUgi = UserGroupInformation.getCurrentUser(); |
| synchronized (initializeSync) { |
| initialized = true; |
| initializeSync.notify(); |
| } |
| return null; |
| } |
| |
| @Override |
| public void handleInputInitializerEvent(List<InputInitializerEvent> events) throws Exception { |
| } |
| |
| static void awaitInitialize() throws InterruptedException { |
| synchronized (initializeSync) { |
| while (!initialized) { |
| initializeSync.wait(); |
| } |
| } |
| } |
| } |
| } |