TEZ-3880: Do not count rejected tasks as killed in vertex progress (Sergey Shelukhin, reviewed by Gunther Hagleitner)

Signed-off-by: Gopal V <gopalv@apache.org>
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/client/Progress.java b/tez-api/src/main/java/org/apache/tez/dag/api/client/Progress.java
index 110ac90..656838d 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/client/Progress.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/client/Progress.java
@@ -63,6 +63,10 @@
     return proxy.getKilledTaskAttemptCount();
   }
 
+  public int getRejectedTaskAttemptCount() {
+    return proxy.getRejectedTaskAttemptCount();
+  }
+
   @Override
   public boolean equals(Object obj) {
     if (obj instanceof Progress){
@@ -73,7 +77,8 @@
           && getFailedTaskCount() == other.getFailedTaskCount()
           && getKilledTaskCount() == other.getKilledTaskCount()
           && getFailedTaskAttemptCount() == other.getFailedTaskAttemptCount()
-          && getKilledTaskAttemptCount() == other.getKilledTaskAttemptCount();
+          && getKilledTaskAttemptCount() == other.getKilledTaskAttemptCount()
+          && getRejectedTaskAttemptCount() == other.getRejectedTaskAttemptCount();
     }
     return false;
   }
@@ -94,6 +99,8 @@
         getFailedTaskAttemptCount();
     result = prime * result +
         getKilledTaskAttemptCount();
+    result = prime * result +
+        getRejectedTaskAttemptCount();
 
     return result;
   }
@@ -119,6 +126,10 @@
       sb.append(" KilledTaskAttempts: ");
       sb.append(getKilledTaskAttemptCount());
     }
+    if (getRejectedTaskAttemptCount() > 0) {
+      sb.append(" RejectedTaskAttempts: ");
+      sb.append(getRejectedTaskAttemptCount());
+    }
     return sb.toString();
   }
 
diff --git a/tez-api/src/main/proto/DAGApiRecords.proto b/tez-api/src/main/proto/DAGApiRecords.proto
index c84094b..34c369d 100644
--- a/tez-api/src/main/proto/DAGApiRecords.proto
+++ b/tez-api/src/main/proto/DAGApiRecords.proto
@@ -227,6 +227,7 @@
   optional int32 killedTaskCount = 5;
   optional int32 failedTaskAttemptCount = 6;
   optional int32 killedTaskAttemptCount = 7;
+  optional int32 rejectedTaskAttemptCount = 8;
 }
 
 enum VertexStatusStateProto {
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/client/ProgressBuilder.java b/tez-dag/src/main/java/org/apache/tez/dag/api/client/ProgressBuilder.java
index 5381518..9dc1354 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/api/client/ProgressBuilder.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/api/client/ProgressBuilder.java
@@ -59,6 +59,10 @@
     getBuilder().setKilledTaskAttemptCount(count);
   }
 
+  public void setRejectedTaskAttemptCount(int count) {
+    getBuilder().setRejectedTaskAttemptCount(count);
+  }
+
   private ProgressProto.Builder getBuilder() {
     return (Builder) this.proxy;
   }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
index ba7624c..0e54e9f 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
@@ -214,4 +214,16 @@
     boolean getTaskRescheduleHigherPriority();
     boolean getTaskRescheduleRelaxedLocality();
   }
+
+  /**
+   * Records one more rejected task attempt for this vertex. In this change the
+   * task calls it in place of {@code incrementKilledTaskAttemptCount()} when an
+   * attempt was killed with the SERVICE_BUSY termination cause.
+   */
+  void incrementRejectedTaskAttemptCount();
+
+  /**
+   * @return the number of task attempts recorded as rejected for this vertex
+   */
+  int getRejectedTaskAttemptCount();
 }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
index 481353b..6c67e68 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
@@ -867,6 +867,7 @@
     int totalKilledTaskCount = 0;
     int totalFailedTaskAttemptCount = 0;
     int totalKilledTaskAttemptCount = 0;
+    int totalRejectedTaskAttemptCount = 0;
     readLock.lock();
     try {
       for(Map.Entry<String, Vertex> entry : vertexMap.entrySet()) {
@@ -879,6 +880,7 @@
         totalKilledTaskCount += progress.getKilledTaskCount();
         totalFailedTaskAttemptCount += progress.getFailedTaskAttemptCount();
         totalKilledTaskAttemptCount += progress.getKilledTaskAttemptCount();
+        totalRejectedTaskAttemptCount += progress.getRejectedTaskAttemptCount();
       }
       ProgressBuilder dagProgress = new ProgressBuilder();
       dagProgress.setTotalTaskCount(totalTaskCount);
@@ -888,6 +890,7 @@
       dagProgress.setKilledTaskCount(totalKilledTaskCount);
       dagProgress.setFailedTaskAttemptCount(totalFailedTaskAttemptCount);
       dagProgress.setKilledTaskAttemptCount(totalKilledTaskAttemptCount);
+      dagProgress.setRejectedTaskAttemptCount(totalRejectedTaskAttemptCount);
       status.setState(getState());
       status.setDiagnostics(diagnostics);
       status.setDAGProgress(dagProgress);
@@ -942,6 +945,7 @@
     int totalKilledTaskCount = 0;
     int totalFailedTaskAttemptCount = 0;
     int totalKilledTaskAttemptCount = 0;
+    int totalRejectedTaskAttemptCount = 0;
     readLock.lock();
     try {
       for(Map.Entry<String, Vertex> entry : vertexMap.entrySet()) {
@@ -953,6 +957,7 @@
         totalKilledTaskCount += progress.getKilledTaskCount();
         totalFailedTaskAttemptCount += progress.getFailedTaskAttemptCount();
         totalKilledTaskAttemptCount += progress.getKilledTaskAttemptCount();
+        totalRejectedTaskAttemptCount += progress.getRejectedTaskAttemptCount();
       }
       ProgressBuilder dagProgress = new ProgressBuilder();
       dagProgress.setTotalTaskCount(totalTaskCount);
@@ -962,6 +967,7 @@
       dagProgress.setKilledTaskCount(totalKilledTaskCount);
       dagProgress.setFailedTaskAttemptCount(totalFailedTaskAttemptCount);
       dagProgress.setKilledTaskAttemptCount(totalKilledTaskAttemptCount);
+      dagProgress.setRejectedTaskAttemptCount(totalRejectedTaskAttemptCount);
       return dagProgress;
     } finally {
       readLock.unlock();
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskAttemptImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskAttemptImpl.java
index 1218543..c43bd98 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskAttemptImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskAttemptImpl.java
@@ -1413,7 +1413,7 @@
       ta.sendEvent(createDAGCounterUpdateEventTAFinished(ta,
           helper.getTaskAttemptState()));
       // Send out events to the Task - indicating TaskAttemptTermination(F/K)
-      ta.sendEvent(helper.getTaskEvent(ta.attemptId,  event));
+      ta.sendEvent(helper.getTaskEvent(ta.attemptId, event));
     }
   }
 
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskImpl.java
index 99cb2e0..9e1d85f 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TaskImpl.java
@@ -75,8 +75,10 @@
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventAttemptKilled;
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventKillRequest;
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventOutputFailed;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEventTerminationCauseEvent;
 import org.apache.tez.dag.app.dag.event.TaskEvent;
 import org.apache.tez.dag.app.dag.event.TaskEventScheduleTask;
+import org.apache.tez.dag.app.dag.event.TaskEventTAKilled;
 import org.apache.tez.dag.app.dag.event.TaskEventTAUpdate;
 import org.apache.tez.dag.app.dag.event.TaskEventTermination;
 import org.apache.tez.dag.app.dag.event.TaskEventType;
@@ -1145,7 +1147,21 @@
           TaskAttemptStateInternal.KILLED);
       // we KillWaitAttemptCompletedTransitionready have a spare
       task.taskAttemptStatus.put(castEvent.getTaskAttemptID().getId(), true);
-      task.getVertex().incrementKilledTaskAttemptCount();
+
+      boolean isRejection = false;
+      if (event instanceof TaskEventTAKilled) {
+        TaskEventTAKilled killEvent = (TaskEventTAKilled) event;
+        if (killEvent.getCausalEvent() instanceof TaskAttemptEventTerminationCauseEvent) {
+          TaskAttemptEventTerminationCauseEvent cause =
+              (TaskAttemptEventTerminationCauseEvent)killEvent.getCausalEvent();
+          isRejection = cause.getTerminationCause() == TaskAttemptTerminationCause.SERVICE_BUSY;
+        }
+      }
+      if (isRejection) { // TODO: remove as part of TEZ-3881.
+        task.getVertex().incrementRejectedTaskAttemptCount();
+      } else {
+        task.getVertex().incrementKilledTaskAttemptCount();
+      }
       if (task.shouldScheduleNewAttempt()) {
         task.addAndScheduleAttempt(castEvent.getTaskAttemptID());
       }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
index 13cfb8f..d727e39 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
@@ -676,6 +676,7 @@
   AtomicInteger failedTaskAttemptCount = new AtomicInteger(0);
   @VisibleForTesting
   AtomicInteger killedTaskAttemptCount = new AtomicInteger(0);
+  AtomicInteger rejectedTaskAttemptCount = new AtomicInteger(0);
 
   @VisibleForTesting
   long initTimeRequested; // Time at which INIT request was received.
@@ -1429,6 +1430,7 @@
       progress.setKilledTaskCount(killedTaskCount);
       progress.setFailedTaskAttemptCount(failedTaskAttemptCount.get());
       progress.setKilledTaskAttemptCount(killedTaskAttemptCount.get());
+      progress.setRejectedTaskAttemptCount(rejectedTaskAttemptCount.get());
       return progress;
     } finally {
       this.readLock.unlock();
@@ -1551,6 +1553,11 @@
   }
 
   @Override
+  public void incrementRejectedTaskAttemptCount() {
+    this.rejectedTaskAttemptCount.incrementAndGet();
+  }
+
+  @Override
   public int getFailedTaskAttemptCount() {
     return this.failedTaskAttemptCount.get();
   }
@@ -1560,6 +1567,11 @@
     return this.killedTaskAttemptCount.get();
   }
 
+  @Override
+  public int getRejectedTaskAttemptCount() {
+    return this.rejectedTaskAttemptCount.get();
+  }
+
   private void setTaskLocationHints(VertexLocationHint vertexLocationHint) {
     if (vertexLocationHint != null &&
         vertexLocationHint.getTaskLocationHints() != null &&
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskImpl.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskImpl.java
index d13e654..b142bb9 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskImpl.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskImpl.java
@@ -40,6 +40,7 @@
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventStartedRemotely;
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventSubmitted;
 import org.apache.tez.dag.app.dag.event.DAGEventType;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEventTerminationCauseEvent;
 import org.apache.tez.dag.app.dag.event.TaskEvent;
 import org.apache.tez.dag.app.dag.event.TaskEventTAFailed;
 import org.apache.tez.dag.app.dag.event.TaskEventTAKilled;
@@ -104,7 +105,6 @@
 import org.junit.Test;
 
 public class TestTaskImpl {
-
   private static final Logger LOG = LoggerFactory.getLogger(TestTaskImpl.class);
 
   private int taskCounter = 0;
@@ -185,7 +185,7 @@
     Vertex vertex = mock(Vertex.class);
     doReturn(new VertexImpl.VertexConfigImpl(conf)).when(vertex).getVertexConfig();
     eventHandler = new TestEventHandler();
-    
+
     mockTask = new MockTaskImpl(vertexId, partition,
         eventHandler, conf, taskCommunicatorManagerInterface, clock,
         taskHeartbeatHandler, appContext, leafVertex,
@@ -508,6 +508,23 @@
     Assert.assertEquals(lastTAId, mockTask.getLastAttempt().getSchedulingCausalTA());
   }
 
+  /**
+   * A SERVICE_BUSY kill is counted as a rejected attempt, not a killed one.
+   * {@link TaskState#RUNNING}->{@link TaskState#RUNNING}
+   */
+  @Test(timeout = 5000)
+  public void testKillTaskAttemptServiceBusy() {
+    LOG.info("--- START: testKillTaskAttemptServiceBusy ---");
+    TezTaskID taskId = getNewTaskID();
+    scheduleTaskAttempt(taskId);
+    launchTaskAttempt(mockTask.getLastAttempt().getID());
+    mockTask.handle(createTaskTAKilledEvent(
+        mockTask.getLastAttempt().getID(), new ServiceBusyEvent()));
+    assertTaskRunningState();
+    verify(mockTask.getVertex(), times(0)).incrementKilledTaskAttemptCount();
+    verify(mockTask.getVertex(), times(1)).incrementRejectedTaskAttemptCount();
+  }
+
   /**
    * {@link TaskState#KILLED}->{@link TaskState#KILLED}
    */
@@ -1386,4 +1403,20 @@
     }
   }
 
+  /**
+   * Causal event stub whose termination cause is always
+   * {@link TaskAttemptTerminationCause#SERVICE_BUSY}. Declared static because it
+   * does not use any state of the enclosing test instance.
+   */
+  public static class ServiceBusyEvent extends TezAbstractEvent<TaskAttemptEventType>
+      implements TaskAttemptEventTerminationCauseEvent {
+    public ServiceBusyEvent() {
+      super(TaskAttemptEventType.TA_KILLED);
+    }
+
+    @Override
+    public TaskAttemptTerminationCause getTerminationCause() {
+      return TaskAttemptTerminationCause.SERVICE_BUSY;
+    }
+  }
 }
diff --git a/tez-ext-service-tests/src/test/java/org/apache/tez/tests/TestExternalTezServices.java b/tez-ext-service-tests/src/test/java/org/apache/tez/tests/TestExternalTezServices.java
index 920534a..c135d7a 100644
--- a/tez-ext-service-tests/src/test/java/org/apache/tez/tests/TestExternalTezServices.java
+++ b/tez-ext-service-tests/src/test/java/org/apache/tez/tests/TestExternalTezServices.java
@@ -200,7 +200,7 @@
     DAGStatus dagStatus = dagClient.waitForCompletion();
     assertEquals(DAGStatus.State.SUCCEEDED, dagStatus.getState());
     assertEquals(1, dagStatus.getDAGProgress().getFailedTaskAttemptCount());
-    assertEquals(1, dagStatus.getDAGProgress().getKilledTaskAttemptCount());
+    assertEquals(1, dagStatus.getDAGProgress().getRejectedTaskAttemptCount());
 
   }