/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tez.dag.app.dag;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.collect.Lists;

import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.api.event.VertexStateUpdateParallelismUpdated;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezVertexID;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests for {@link StateChangeNotifier}: delivery of vertex state updates to registered
 * listeners, replay of an already-known state on registration, filtering by requested
 * states, duplicate registration and unregistration.
 */
public class TestStateChangeNotifier {

  /**
   * A {@link StateChangeNotifier} that still goes through the thread-based notification code
   * path, but makes {@link #addedEventToQueue()} block until the notifier's event-handling
   * thread has drained every pending event. Tests can then verify listener invocations
   * immediately after calling {@code stateChanged} / {@code registerForVertexUpdates}.
   */
  public static class StateChangeNotifierForTest extends StateChangeNotifier {
    private static final Logger LOG = LoggerFactory.getLogger(StateChangeNotifierForTest.class);

    /** Number of queued-but-unprocessed events; also the monitor used for wait/notify. */
    final AtomicInteger count = new AtomicInteger(0);
    /** Number of events added to the queue since construction or the last {@link #reset()}. */
    final AtomicInteger totalCount = new AtomicInteger(0);

    public StateChangeNotifierForTest(DAG dag) {
      super(dag);
    }

    /** Resets both the pending and the total event counters to zero. */
    public void reset() {
      count.set(0);
      totalCount.set(0);
    }

    /**
     * Hook that runs on the notifier's event-handling thread for each event taken off the queue.
     * Decrements the pending count and wakes the blocked producer once it reaches zero.
     */
    @Override
    protected void processedEventFromQueue() {
      // addedEventToQueue runs on the dispatcher thread while this method runs on the state
      // change notifier's event-handling thread. Nothing guarantees that addedEventToQueue has
      // already registered the event, so poll until it has.
      while (count.get() <= 0) {
        try {
          Thread.sleep(10);
          LOG.info("sleep to wait for available events");
        } catch (InterruptedException e) {
          // Restore the interrupt so the event-handling loop can observe it and stop,
          // rather than spinning here indefinitely.
          LOG.warn("Interrupted while waiting for a queued event to be registered", e);
          Thread.currentThread().interrupt();
          return;
        }
      }
      synchronized (count) {
        if (count.decrementAndGet() == 0) {
          count.notifyAll();
        }
      }
    }

    /**
     * Hook that runs on the dispatcher (caller) thread after an event is queued. Blocks until
     * the event-handling thread has processed every pending event.
     */
    @Override
    protected void addedEventToQueue() {
      totalCount.incrementAndGet();
      synchronized (count) {
        // The consumer cannot decrement until this thread releases the monitor inside wait(),
        // so the notifyAll() it issues cannot be missed.
        count.incrementAndGet();
        // Loop to tolerate spurious wakeups: only continue once nothing is pending.
        while (count.get() > 0) {
          try {
            count.wait();
          } catch (InterruptedException e) {
            LOG.warn("Interrupted while waiting for queued events to be processed", e);
            Thread.currentThread().interrupt();
            return;
          }
        }
      }
    }
  }

  /**
   * Registering after a vertex has already reported a state replays that state to listeners
   * whose filter accepts it, and only to those listeners.
   */
  @Test(timeout = 5000)
  public void testEventsOnRegistration() {
    TezDAGID dagId = TezDAGID.getInstance("1", 1, 1);
    Vertex v1 = createMockVertex(dagId, 1);
    Vertex v2 = createMockVertex(dagId, 2);
    Vertex v3 = createMockVertex(dagId, 3);
    DAG dag = createMockDag(dagId, v1, v2, v3);

    StateChangeNotifierForTest tracker = new StateChangeNotifierForTest(dag);

    // Vertex has sent one event
    notifyTracker(tracker, v1, VertexState.RUNNING);
    VertexStateUpdateListener mockListener11 = mock(VertexStateUpdateListener.class);
    VertexStateUpdateListener mockListener12 = mock(VertexStateUpdateListener.class);
    VertexStateUpdateListener mockListener13 = mock(VertexStateUpdateListener.class);
    VertexStateUpdateListener mockListener14 = mock(VertexStateUpdateListener.class);
    // Register for all states (null filter)
    tracker.registerForVertexUpdates(v1.getName(), null, mockListener11);
    // Register for all states (explicit filter)
    tracker.registerForVertexUpdates(v1.getName(), EnumSet.allOf(
        VertexState.class), mockListener12);
    // Register for a specific state that has already been reported
    tracker.registerForVertexUpdates(v1.getName(), EnumSet.of(
        VertexState.RUNNING), mockListener13);
    // Register for a specific state that has not been reported
    tracker.registerForVertexUpdates(v1.getName(), EnumSet.of(
        VertexState.SUCCEEDED), mockListener14);
    ArgumentCaptor<VertexStateUpdate> argumentCaptor =
        ArgumentCaptor.forClass(VertexStateUpdate.class);
    verify(mockListener11, times(1)).onStateUpdated(argumentCaptor.capture());
    assertEquals(VertexState.RUNNING,
        argumentCaptor.getValue().getVertexState());
    verify(mockListener12, times(1)).onStateUpdated(argumentCaptor.capture());
    assertEquals(VertexState.RUNNING,
        argumentCaptor.getValue().getVertexState());
    verify(mockListener13, times(1)).onStateUpdated(argumentCaptor.capture());
    assertEquals(VertexState.RUNNING,
        argumentCaptor.getValue().getVertexState());
    verify(mockListener14, never()).onStateUpdated(any(VertexStateUpdate.class));

    // Vertex has not notified of any state
    tracker.reset();
    VertexStateUpdateListener mockListener2 = mock(VertexStateUpdateListener.class);
    tracker.registerForVertexUpdates(v2.getName(), null, mockListener2);
    Assert.assertEquals("no event should be queued for a vertex without a known state",
        0, tracker.totalCount.get());
    verify(mockListener2, never()).onStateUpdated(any(VertexStateUpdate.class));

    // Vertex has notified about a parallelism update only
    tracker.stateChanged(v3.getVertexId(), new VertexStateUpdateParallelismUpdated(v3.getName(), 23, -1));
    VertexStateUpdateListener mockListener3 = mock(VertexStateUpdateListener.class);
    tracker.registerForVertexUpdates(v3.getName(), null, mockListener3);
    verify(mockListener3, times(1)).onStateUpdated(argumentCaptor.capture());
    assertEquals(VertexState.PARALLELISM_UPDATED,
        argumentCaptor.getValue().getVertexState());
  }

  /** A listener registered for all states receives every update, in order. */
  @Test(timeout = 5000)
  public void testSimpleStateUpdates() {
    TezDAGID dagId = TezDAGID.getInstance("1", 1, 1);
    Vertex v1 = createMockVertex(dagId, 1);
    DAG dag = createMockDag(dagId, v1);

    StateChangeNotifierForTest tracker = new StateChangeNotifierForTest(dag);

    VertexStateUpdateListener mockListener = mock(VertexStateUpdateListener.class);
    tracker.registerForVertexUpdates(v1.getName(), null, mockListener);

    List<VertexState> expectedStates = Lists.newArrayList(
        VertexState.RUNNING,
        VertexState.SUCCEEDED,
        VertexState.FAILED,
        VertexState.KILLED,
        VertexState.RUNNING,
        VertexState.SUCCEEDED);

    for (VertexState state : expectedStates) {
      notifyTracker(tracker, v1, state);
    }

    ArgumentCaptor<VertexStateUpdate> argumentCaptor =
        ArgumentCaptor.forClass(VertexStateUpdate.class);
    verify(mockListener, times(expectedStates.size())).onStateUpdated(argumentCaptor.capture());
    assertEquals(expectedStates, toStates(argumentCaptor.getAllValues()));
  }

  /** Registering the same listener twice for the same vertex is rejected. */
  @Test(timeout = 5000)
  public void testDuplicateRegistration() {
    TezDAGID dagId = TezDAGID.getInstance("1", 1, 1);
    Vertex v1 = createMockVertex(dagId, 1);
    DAG dag = createMockDag(dagId, v1);

    StateChangeNotifierForTest tracker = new StateChangeNotifierForTest(dag);
    VertexStateUpdateListener mockListener = mock(VertexStateUpdateListener.class);

    tracker.registerForVertexUpdates(v1.getName(), null, mockListener);
    try {
      tracker.registerForVertexUpdates(v1.getName(), null, mockListener);
      fail("Expecting an error from duplicate registrations of the same listener");
    } catch (TezUncheckedException expected) {
      // Expected, ignore
    }
  }

  /** A listener registered for specific states only receives updates for those states. */
  @Test(timeout = 5000)
  public void testSpecificStateUpdates() {
    TezDAGID dagId = TezDAGID.getInstance("1", 1, 1);
    Vertex v1 = createMockVertex(dagId, 1);
    DAG dag = createMockDag(dagId, v1);

    StateChangeNotifierForTest tracker = new StateChangeNotifierForTest(dag);

    VertexStateUpdateListener mockListener = mock(VertexStateUpdateListener.class);
    tracker.registerForVertexUpdates(v1.getName(), EnumSet.of(
        VertexState.RUNNING,
        VertexState.SUCCEEDED), mockListener);

    List<VertexState> states = Lists.newArrayList(
        VertexState.RUNNING,
        VertexState.SUCCEEDED,
        VertexState.FAILED,
        VertexState.KILLED,
        VertexState.RUNNING,
        VertexState.SUCCEEDED);
    List<VertexState> expectedStates = Lists.newArrayList(
        VertexState.RUNNING,
        VertexState.SUCCEEDED,
        VertexState.RUNNING,
        VertexState.SUCCEEDED);

    for (VertexState state : states) {
      notifyTracker(tracker, v1, state);
    }

    ArgumentCaptor<VertexStateUpdate> argumentCaptor =
        ArgumentCaptor.forClass(VertexStateUpdate.class);
    verify(mockListener, times(expectedStates.size())).onStateUpdated(argumentCaptor.capture());
    assertEquals(expectedStates, toStates(argumentCaptor.getAllValues()));
  }

  /** After unregistering, a listener receives no further updates. */
  @Test(timeout = 5000)
  public void testUnregister() {
    TezDAGID dagId = TezDAGID.getInstance("1", 1, 1);
    Vertex v1 = createMockVertex(dagId, 1);
    DAG dag = createMockDag(dagId, v1);

    StateChangeNotifierForTest tracker = new StateChangeNotifierForTest(dag);

    VertexStateUpdateListener mockListener = mock(VertexStateUpdateListener.class);
    tracker.registerForVertexUpdates(v1.getName(), null, mockListener);

    List<VertexState> expectedStates = Lists.newArrayList(
        VertexState.RUNNING,
        VertexState.SUCCEEDED,
        VertexState.FAILED,
        VertexState.KILLED,
        VertexState.RUNNING,
        VertexState.SUCCEEDED);

    int count = 0;
    int numExpectedEvents = 3;
    for (VertexState state : expectedStates) {
      if (count == numExpectedEvents) {
        tracker.unregisterForVertexUpdates(v1.getName(), mockListener);
      }
      notifyTracker(tracker, v1, state);
      count++;
    }

    ArgumentCaptor<VertexStateUpdate> argumentCaptor =
        ArgumentCaptor.forClass(VertexStateUpdate.class);
    verify(mockListener, times(numExpectedEvents)).onStateUpdated(argumentCaptor.capture());
    assertEquals(expectedStates.subList(0, numExpectedEvents),
        toStates(argumentCaptor.getAllValues()));
  }

  /** Returns the vertex states carried by {@code updates}, preserving their order. */
  private static List<VertexState> toStates(List<VertexStateUpdate> updates) {
    List<VertexState> states = Lists.newArrayListWithCapacity(updates.size());
    for (VertexStateUpdate update : updates) {
      states.add(update.getVertexState());
    }
    return states;
  }

  /**
   * Creates a mock DAG with the given id that resolves each supplied vertex both by name and
   * by vertex id.
   */
  private DAG createMockDag(TezDAGID dagId, Vertex... vertices) {
    DAG dag = mock(DAG.class);
    doReturn(dagId).when(dag).getID();
    for (Vertex v : vertices) {
      String vertexName = v.getName();
      TezVertexID vertexId = v.getVertexId();

      doReturn(v).when(dag).getVertex(vertexName);
      doReturn(v).when(dag).getVertex(vertexId);
    }
    return dag;
  }

  /** Creates a mock vertex with id {@code id} in {@code dagId}, named {@code "vertex" + id}. */
  private Vertex createMockVertex(TezDAGID dagId, int id) {
    TezVertexID vertexId = TezVertexID.getInstance(dagId, id);
    String vertexName = "vertex" + id;
    Vertex v = mock(Vertex.class);
    doReturn(vertexId).when(v).getVertexId();
    doReturn(vertexName).when(v).getName();
    return v;
  }

  /** Reports {@code state} for vertex {@code v} to {@code notifier}. */
  private void notifyTracker(StateChangeNotifier notifier, Vertex v,
                             VertexState state) {
    notifier.stateChanged(v.getVertexId(), new VertexStateUpdate(v.getName(), state));
  }
}
