/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.dag.app.dag;

import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Collections;
import java.util.List;

import com.google.common.collect.Lists;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.InputInitializerDescriptor;
import org.apache.tez.dag.api.RootInputLeafOutput;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.oldrecords.TaskState;
import org.apache.tez.dag.app.AppContext;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputInitializer;
import org.apache.tez.runtime.api.InputInitializerContext;
import org.apache.tez.runtime.api.events.InputInitializerEvent;
import org.apache.tez.runtime.api.impl.EventMetaData;
import org.apache.tez.runtime.api.impl.TezEvent;
import org.junit.Test;
import org.mockito.ArgumentCaptor;

public class TestRootInputInitializerManager {

  // Simple testing. No events if task doesn't succeed.
  // Also exercises path where two attempts are reported as successful via the stateChangeNotifier.
  // Primarily a failure scenario, when a Task moves back to running from success
  // Order event1, success1, event2, success2
  @SuppressWarnings("unchecked")
  @Test(timeout = 5000)
  public void testEventBeforeSuccess() throws Exception {
    InputDescriptor id = mock(InputDescriptor.class);
    InputInitializerDescriptor iid = mock(InputInitializerDescriptor.class);
    RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor> rootInput =
        new RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor>("InputName", id, iid);

    InputInitializer initializer = mock(InputInitializer.class);
    InputInitializerContext initializerContext = mock(InputInitializerContext.class);
    Vertex vertex = mock(Vertex.class);
    StateChangeNotifier stateChangeNotifier = mock(StateChangeNotifier.class);
    AppContext appContext = mock(AppContext.class, RETURNS_DEEP_STUBS);

    RootInputInitializerManager.InitializerWrapper initializerWrapper =
        new RootInputInitializerManager.InitializerWrapper(rootInput, initializer,
            initializerContext, vertex, stateChangeNotifier, appContext);

    ApplicationId appId = ApplicationId.newInstance(1000, 1);
    TezDAGID dagId = TezDAGID.getInstance(appId, 1);
    TezVertexID srcVertexId = TezVertexID.getInstance(dagId, 2);
    TezTaskID srcTaskId1 = TezTaskID.getInstance(srcVertexId, 3);
    Vertex srcVertex = mock(Vertex.class);
    Task srcTask1 = mock(Task.class);
    doReturn(TaskState.RUNNING).when(srcTask1).getState();
    doReturn(srcTask1).when(srcVertex).getTask(srcTaskId1.getId());
    when(appContext.getCurrentDAG().getVertex(any(String.class))).thenReturn(srcVertex);

    String srcVertexName = "srcVertexName";
    List<TezEvent> eventList = Lists.newLinkedList();


    // First Attempt send event
    TezTaskAttemptID srcTaskAttemptId11 = TezTaskAttemptID.getInstance(srcTaskId1, 1);
    EventMetaData sourceInfo11 =
        new EventMetaData(EventMetaData.EventProducerConsumerType.PROCESSOR, srcVertexName, null,
            srcTaskAttemptId11);
    InputInitializerEvent e1 = InputInitializerEvent.create("fakeVertex", "fakeInput", null);
    TezEvent te1 = new TezEvent(e1, sourceInfo11);
    eventList.add(te1);
    initializerWrapper.handleInputInitializerEvents(eventList);

    verify(initializer, never()).handleInputInitializerEvent(any(List.class));
    eventList.clear();

    // First attempt, Task success notification
    initializerWrapper.onTaskSucceeded(srcVertexName, srcTaskId1, srcTaskAttemptId11.getId());
    ArgumentCaptor<List> argumentCaptor = ArgumentCaptor.forClass(List.class);
    verify(initializer, times(1)).handleInputInitializerEvent(argumentCaptor.capture());
    List<InputInitializerEvent> invokedEvents = argumentCaptor.getValue();
    assertEquals(1, invokedEvents.size());

    reset(initializer);

    // 2nd attempt send event
    TezTaskAttemptID srcTaskAttemptId12 = TezTaskAttemptID.getInstance(srcTaskId1, 2);
    EventMetaData sourceInfo12 =
        new EventMetaData(EventMetaData.EventProducerConsumerType.PROCESSOR, srcVertexName, null,
            srcTaskAttemptId12);
    InputInitializerEvent e2 = InputInitializerEvent.create("fakeVertex", "fakeInput", null);
    TezEvent te2 = new TezEvent(e2, sourceInfo12);
    eventList.add(te2);
    initializerWrapper.handleInputInitializerEvents(eventList);

    verify(initializer, never()).handleInputInitializerEvent(any(List.class));
    eventList.clear();
    reset(initializer);

    // 2nd attempt succeeded
    initializerWrapper.onTaskSucceeded(srcVertexName, srcTaskId1, srcTaskAttemptId12.getId());
    verify(initializer, never()).handleInputInitializerEvent(argumentCaptor.capture());
  }

  // Order event1 success1, success2, event2
  // Primarily a failure scenario, when a Task moves back to running from success
  @SuppressWarnings("unchecked")
  @Test(timeout = 5000)
  public void testSuccessBeforeEvent() throws Exception {
    InputDescriptor id = mock(InputDescriptor.class);
    InputInitializerDescriptor iid = mock(InputInitializerDescriptor.class);
    RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor> rootInput =
        new RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor>("InputName", id, iid);

    InputInitializer initializer = mock(InputInitializer.class);
    InputInitializerContext initializerContext = mock(InputInitializerContext.class);
    Vertex vertex = mock(Vertex.class);
    StateChangeNotifier stateChangeNotifier = mock(StateChangeNotifier.class);
    AppContext appContext = mock(AppContext.class, RETURNS_DEEP_STUBS);

    RootInputInitializerManager.InitializerWrapper initializerWrapper =
        new RootInputInitializerManager.InitializerWrapper(rootInput, initializer,
            initializerContext, vertex, stateChangeNotifier, appContext);

    ApplicationId appId = ApplicationId.newInstance(1000, 1);
    TezDAGID dagId = TezDAGID.getInstance(appId, 1);
    TezVertexID srcVertexId = TezVertexID.getInstance(dagId, 2);
    TezTaskID srcTaskId1 = TezTaskID.getInstance(srcVertexId, 3);
    Vertex srcVertex = mock(Vertex.class);
    Task srcTask1 = mock(Task.class);
    doReturn(TaskState.RUNNING).when(srcTask1).getState();
    doReturn(srcTask1).when(srcVertex).getTask(srcTaskId1.getId());
    when(appContext.getCurrentDAG().getVertex(any(String.class))).thenReturn(srcVertex);

    String srcVertexName = "srcVertexName";
    List<TezEvent> eventList = Lists.newLinkedList();


    // First Attempt send event
    TezTaskAttemptID srcTaskAttemptId11 = TezTaskAttemptID.getInstance(srcTaskId1, 1);
    EventMetaData sourceInfo11 =
        new EventMetaData(EventMetaData.EventProducerConsumerType.PROCESSOR, srcVertexName, null,
            srcTaskAttemptId11);
    InputInitializerEvent e1 = InputInitializerEvent.create("fakeVertex", "fakeInput", null);
    TezEvent te1 = new TezEvent(e1, sourceInfo11);
    eventList.add(te1);
    initializerWrapper.handleInputInitializerEvents(eventList);

    verify(initializer, never()).handleInputInitializerEvent(any(List.class));
    eventList.clear();

    // First attempt, Task success notification
    initializerWrapper.onTaskSucceeded(srcVertexName, srcTaskId1, srcTaskAttemptId11.getId());
    ArgumentCaptor<List> argumentCaptor = ArgumentCaptor.forClass(List.class);
    verify(initializer, times(1)).handleInputInitializerEvent(argumentCaptor.capture());
    List<InputInitializerEvent> invokedEvents = argumentCaptor.getValue();
    assertEquals(1, invokedEvents.size());

    reset(initializer);


    TezTaskAttemptID srcTaskAttemptId12 = TezTaskAttemptID.getInstance(srcTaskId1, 2);
    // 2nd attempt succeeded
    initializerWrapper.onTaskSucceeded(srcVertexName, srcTaskId1, srcTaskAttemptId12.getId());
    verify(initializer, never()).handleInputInitializerEvent(any(List.class));

    // 2nd attempt send event
    EventMetaData sourceInfo12 =
        new EventMetaData(EventMetaData.EventProducerConsumerType.PROCESSOR, srcVertexName, null,
            srcTaskAttemptId12);
    InputInitializerEvent e2 = InputInitializerEvent.create("fakeVertex", "fakeInput", null);
    TezEvent te2 = new TezEvent(e2, sourceInfo12);
    eventList.add(te2);
    initializerWrapper.handleInputInitializerEvents(eventList);

    verify(initializer, never()).handleInputInitializerEvent(any(List.class));
  }


  @Test (timeout = 5000)
  public void testCorrectUgiUsage() throws TezException, InterruptedException {
    Vertex vertex = mock(Vertex.class);
    doReturn(mock(TezVertexID.class)).when(vertex).getVertexId();
    AppContext appContext = mock(AppContext.class);
    doReturn(mock(EventHandler.class)).when(appContext).getEventHandler();
    UserGroupInformation dagUgi = UserGroupInformation.createRemoteUser("fakeuser");
    StateChangeNotifier stateChangeNotifier = mock(StateChangeNotifier.class);
    RootInputInitializerManager rootInputInitializerManager = new RootInputInitializerManager(vertex, appContext, dagUgi, stateChangeNotifier);

    InputDescriptor id = mock(InputDescriptor.class);
    InputInitializerDescriptor iid = InputInitializerDescriptor.create(InputInitializerForUgiTest.class.getName());
    RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor> rootInput =
        new RootInputLeafOutput<InputDescriptor, InputInitializerDescriptor>("InputName", id, iid);
    rootInputInitializerManager.runInputInitializers(Collections.singletonList(rootInput));

    InputInitializerForUgiTest.awaitInitialize();

    assertEquals(dagUgi, InputInitializerForUgiTest.ctorUgi);
    assertEquals(dagUgi, InputInitializerForUgiTest.initializeUgi);
  }

  public static class InputInitializerForUgiTest extends InputInitializer {

    static volatile UserGroupInformation ctorUgi;
    static volatile UserGroupInformation initializeUgi;

    static boolean initialized = false;
    static final Object initializeSync = new Object();

    public InputInitializerForUgiTest(InputInitializerContext initializerContext) {
      super(initializerContext);
      try {
        ctorUgi = UserGroupInformation.getCurrentUser();
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }

    @Override
    public List<Event> initialize() throws Exception {
      initializeUgi = UserGroupInformation.getCurrentUser();
      synchronized (initializeSync) {
        initialized = true;
        initializeSync.notify();
      }
      return null;
    }

    @Override
    public void handleInputInitializerEvent(List<InputInitializerEvent> events) throws Exception {
    }

    static void awaitInitialize() throws InterruptedException {
      synchronized (initializeSync) {
        while (!initialized) {
          initializeSync.wait();
        }
      }
    }
  }
}
