TEZ-1742. Improve response time of internal preemption (bikas)
diff --git a/CHANGES.txt b/CHANGES.txt
index bac35ff..f4742f9 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -10,6 +10,7 @@
   TEZ-1738. Tez tfile parser for log parsing
   TEZ-1627. Remove OUTPUT_CONSUMABLE and related Event in TaskAttemptImpl
   TEZ-1749. Increase test timeout for TestLocalMode.testMultipleClientsWithSession
+  TEZ-1742. Improve response time of internal preemption
 
 Release 0.5.3: Unreleased
 
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
index 7cc0aa5..d9003b3 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
@@ -570,6 +570,28 @@
   public static final String TEZ_AM_SESSION_MIN_HELD_CONTAINERS = 
       TEZ_AM_PREFIX + "session.min.held-containers";
   public static final int TEZ_AM_SESSION_MIN_HELD_CONTAINERS_DEFAULT = 0;
+  
+  /**
+   * Int value. Specifies the percentage of tasks eligible to be preempted that
+   * will actually be preempted in a given round of Tez internal preemption.
+   * This slows down preemption and gives more time for free resources to be
+   * allocated by the cluster (if any) and gives more time for preemptable tasks
+   * to finish. Valid values are 0-100. Higher values will preempt quickly at
+   * the cost of losing work. Setting to 0 turns off preemption. Expert level
+   * setting.
+   */
+  public static final String TEZ_AM_PREEMPTION_PERCENTAGE = 
+      TEZ_AM_PREFIX + "preemption.percentage";
+  public static final int TEZ_AM_PREEMPTION_PERCENTAGE_DEFAULT = 10;
+  
+  /**
+   * Int value. The number of RM heartbeats to wait after preempting running tasks before preempting
+   * more running tasks. After preempting a task, we need to wait at least 1 heartbeat so that the 
+   * RM can act on the released resources and assign new ones to us. Expert level setting.
+   */
+  public static final String TEZ_AM_PREEMPTION_HEARTBEATS_BETWEEN_PREEMPTIONS = 
+      TEZ_AM_PREFIX + "preemption.heartbeats-between-preemptions";
+  public static final int TEZ_AM_PREEMPTION_HEARTBEATS_BETWEEN_PREEMPTIONS_DEFAULT = 3;
 
   /**
    * String value to a file path.
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/YarnTaskSchedulerService.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/YarnTaskSchedulerService.java
index 75d62f1..5941a45 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/YarnTaskSchedulerService.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/YarnTaskSchedulerService.java
@@ -123,6 +123,9 @@
   
   Resource totalResources = Resource.newInstance(0, 0);
   Resource allocatedResources = Resource.newInstance(0, 0);
+  long numHeartbeats = 0;
+  long heartbeatAtLastPreemption = 0;
+  int numHeartbeatsBetweenPreemptions = 0;
   
   final String appHostName;
   final int appHostPort;
@@ -141,6 +144,8 @@
   long idleContainerTimeoutMin;
   long idleContainerTimeoutMax = 0;
   int sessionNumMinHeldContainers = 0;
+  int preemptionPercentage = 0; 
+  
   Set<ContainerId> sessionMinHeldContainers = Sets.newHashSet();
   
   RandomDataGenerator random = new RandomDataGenerator();
@@ -328,6 +333,17 @@
         TezConfiguration.TEZ_AM_SESSION_MIN_HELD_CONTAINERS_DEFAULT);
     Preconditions.checkArgument(sessionNumMinHeldContainers >= 0, 
         "Session minimum held containers should be >=0");
+    
+    preemptionPercentage = conf.getInt(TezConfiguration.TEZ_AM_PREEMPTION_PERCENTAGE, 
+        TezConfiguration.TEZ_AM_PREEMPTION_PERCENTAGE_DEFAULT);
+    Preconditions.checkArgument(preemptionPercentage >= 0 && preemptionPercentage <= 100,
+        "Preemption percentage should be between 0-100");
+    
+    numHeartbeatsBetweenPreemptions = conf.getInt(
+        TezConfiguration.TEZ_AM_PREEMPTION_HEARTBEATS_BETWEEN_PREEMPTIONS,
+        TezConfiguration.TEZ_AM_PREEMPTION_HEARTBEATS_BETWEEN_PREEMPTIONS_DEFAULT);
+    Preconditions.checkArgument(numHeartbeatsBetweenPreemptions >= 1, 
+        "Heartbeats between preemptions should be >=1");
 
     delayedContainerManager = new DelayedContainerManager();
     LOG.info("TaskScheduler initialized with configuration: " +
@@ -336,6 +352,8 @@
             ", reuseRackLocal: " + reuseRackLocal +
             ", reuseNonLocal: " + reuseNonLocal + 
             ", localitySchedulingDelay: " + localitySchedulingDelay +
+            ", preemptionPercentage: " + preemptionPercentage +
+            ", numHeartbeatsBetweenPreemptions" + numHeartbeatsBetweenPreemptions +
             ", idleContainerMinTimeout=" + idleContainerTimeoutMin +
             ", idleContainerMaxTimeout=" + idleContainerTimeoutMax +
             ", sessionMinHeldContainers=" + sessionNumMinHeldContainers);
@@ -847,6 +865,7 @@
                " taskAllocations: " + taskAllocations.size());
     }
 
+    numHeartbeats++;
     preemptIfNeeded();
 
     return appClientDelegate.getProgress();
@@ -1031,34 +1050,72 @@
     }
     return false; 
   }
+
+  private int scaleDownByPreemptionPercentage(int original) {
+    return (int) Math.ceil((original * preemptionPercentage) / 100.f);
+  }
   
   void preemptIfNeeded() {
-    ContainerId preemptedContainer = null;
+    if (preemptionPercentage == 0) {
+      // turned off
+      return;
+    }
+    ContainerId[] preemptedContainers = null;
+    int numPendingRequestsToService = 0;
     synchronized (this) {
-      Resource freeResources = Resources.subtract(totalResources,
-        allocatedResources);
+      Resource freeResources = amRmClient.getAvailableResources();
       if (LOG.isDebugEnabled()) {
         LOG.debug("Allocated resource memory: " + allocatedResources.getMemory() +
           " cpu:" + allocatedResources.getVirtualCores() + 
-          " delayedContainers: " + delayedContainerManager.delayedContainers.size());
+          " delayedContainers: " + delayedContainerManager.delayedContainers.size() +
+          " heartbeats: " + numHeartbeats + " lastPreemptionHeartbeat: " + heartbeatAtLastPreemption);
       }
       assert freeResources.getMemory() >= 0;
   
       CookieContainerRequest highestPriRequest = null;
+      int numHighestPriRequests = 0;
       for(CookieContainerRequest request : taskRequests.values()) {
         if(highestPriRequest == null) {
           highestPriRequest = request;
+          numHighestPriRequests = 1;
         } else if(isHigherPriority(request.getPriority(),
                                      highestPriRequest.getPriority())){
           highestPriRequest = request;
+          numHighestPriRequests = 1;
+        } else if (request.getPriority().equals(highestPriRequest.getPriority())) {
+          numHighestPriRequests++;
         }
       }
-      if(highestPriRequest != null &&
-         !fitsIn(highestPriRequest.getCapability(), freeResources)) {
-        // highest priority request will not fit in existing free resources
-        // free up some more
-        // TODO this is subject to error wrt RM resource normalization
-        
+      
+      if (highestPriRequest == null) {
+        // nothing pending
+        return;
+      }
+      
+      if(fitsIn(highestPriRequest.getCapability(), freeResources)) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Highest pri request: " + highestPriRequest + " fits in available resources "
+              + freeResources);
+        }
+        return;
+      }
+      // highest priority request will not fit in existing free resources
+      // free up some more
+      // TODO this is subject to error wrt RM resource normalization
+      
+      numPendingRequestsToService = scaleDownByPreemptionPercentage(numHighestPriRequests);
+      
+      if (numPendingRequestsToService < 1) {
+        return;
+      }
+
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Trying to service " + numPendingRequestsToService + " out of total "
+            + numHighestPriRequests + " pending requests at pri: "
+            + highestPriRequest.getPriority());
+      }
+      
+      for (int i=0; i<numPendingRequestsToService; ++i) {
         // This request must have been considered for matching with all existing 
         // containers when request was made.
         Container lowestPriNewContainer = null;
@@ -1093,6 +1150,7 @@
               " with priority: " + lowestPriNewContainer.getPriority() + 
               " to free resource for request: " + highestPriRequest +
               " . Current free resources: " + freeResources);
+          numPendingRequestsToService--;
           releaseUnassignedContainers(Collections.singletonList(lowestPriNewContainer));
           // We are returning an unused resource back the RM. The RM thinks it 
           // has serviced our initial request and will not re-allocate this back
@@ -1116,54 +1174,91 @@
               break;
             }
           }
-          
+          // come back and free more new containers if needed
+          continue;
+        }
+      }
+      
+      if (numPendingRequestsToService < 1) {
+        return;
+      }
+
+      // there are no reused or new containers to release. try to preempt running containers
+      // this assert will be a no-op in production but can help identify 
+      // invalid assumptions during testing
+      assert delayedContainerManager.delayedContainers.isEmpty();
+      
+      if ((numHeartbeats - heartbeatAtLastPreemption) <= numHeartbeatsBetweenPreemptions) {
+        return;
+      }
+        
+      Priority preemptedTaskPriority = null;
+      int numEntriesAtPreemptedPriority = 0;
+      for(Map.Entry<Object, Container> entry : taskAllocations.entrySet()) {
+        HeldContainer heldContainer = heldContainers.get(entry.getValue().getId());
+        CookieContainerRequest lastTaskInfo = heldContainer.getLastTaskInfo();
+        Priority taskPriority = lastTaskInfo.getPriority();
+        Object signature = lastTaskInfo.getCookie().getContainerSignature();
+        if(!isHigherPriority(highestPriRequest.getPriority(), taskPriority)) {
+          // higher or same priority
+          continue;
+        }
+        if (containerSignatureMatcher.isExactMatch(
+            highestPriRequest.getCookie().getContainerSignature(),
+            signature)) {
+          // exact match with different priorities
+          continue;
+        }
+        if (preemptedTaskPriority == null ||
+            !isHigherPriority(taskPriority, preemptedTaskPriority)) {
+          // keep the lower priority
+          if (taskPriority.equals(preemptedTaskPriority)) {
+            numEntriesAtPreemptedPriority++;
+          } else {
+            // this is at a lower priority than existing
+            numEntriesAtPreemptedPriority = 1;
+          }
+          preemptedTaskPriority = taskPriority;
+        }
+      }
+      if(preemptedTaskPriority != null) {
+        int newNumPendingRequestsToService = scaleDownByPreemptionPercentage(Math.min(
+            numEntriesAtPreemptedPriority, numHighestPriRequests));
+        numPendingRequestsToService = Math.min(newNumPendingRequestsToService,
+            numPendingRequestsToService);
+        if (numPendingRequestsToService < 1) {
           return;
         }
-        
-        // this assert will be a no-op in production but can help identify 
-        // invalid assumptions during testing
-        assert delayedContainerManager.delayedContainers.isEmpty();
-        
-        // there are no reused or new containers to release
-        // try to preempt running containers
-        Map.Entry<Object, Container> preemptedEntry = null;
-        for(Map.Entry<Object, Container> entry : taskAllocations.entrySet()) {
-          HeldContainer heldContainer = heldContainers.get(entry.getValue().getId());
-          CookieContainerRequest lastTaskInfo = heldContainer.getLastTaskInfo();
-          Priority taskPriority = lastTaskInfo.getPriority();
-          Object signature = lastTaskInfo.getCookie().getContainerSignature();
-          if(!isHigherPriority(highestPriRequest.getPriority(), taskPriority)) {
-            // higher or same priority
-            continue;
-          }
-          if (containerSignatureMatcher.isExactMatch(
-              highestPriRequest.getCookie().getContainerSignature(),
-              signature)) {
-            // exact match with different priorities
-            continue;
-          }
-          if(preemptedEntry == null ||
-             !isHigherPriority(taskPriority, 
-                 preemptedEntry.getValue().getPriority())) {
-            // keep the lower priority or the one added later
-            preemptedEntry = entry;
+        LOG.info("Trying to service " + numPendingRequestsToService + " out of total "
+            + numHighestPriRequests + " pending requests at pri: "
+            + highestPriRequest.getPriority() + " by preempting from "
+            + numEntriesAtPreemptedPriority + " running tasks at priority: " + preemptedTaskPriority);
+        // found something to preempt. get others of the same priority
+        preemptedContainers = new ContainerId[numPendingRequestsToService];
+        int currIndex = 0;
+        for (Map.Entry<Object, Container> entry : taskAllocations.entrySet()) {
+          Container container = entry.getValue();
+          if (preemptedTaskPriority.equals(container.getPriority())) {
+            // taskAllocations map will iterate from oldest to newest assigned containers
+            // keep the N newest containersIds with the matching priority
+            preemptedContainers[currIndex++ % numPendingRequestsToService] = container.getId();
           }
         }
-        if(preemptedEntry != null) {
-          // found something to preempt
-          LOG.info("Preempting task: " + preemptedEntry.getKey() +
-              " to free resource for request: " + highestPriRequest +
-              " . Current free resources: " + freeResources);
-          preemptedContainer = preemptedEntry.getValue().getId();
-          // app client will be notified when after container is killed
-          // and we get its completed container status
-        }
+        // app client will be notified after the container is killed
+        // and we get its completed container status
       }
     }
     
     // upcall outside locks
-    if (preemptedContainer != null) {
-      appClientDelegate.preemptContainer(preemptedContainer);
+    if (preemptedContainers != null) {
+      heartbeatAtLastPreemption = numHeartbeats;
+      for(int i=0; i<numPendingRequestsToService; ++i) {
+        ContainerId cId = preemptedContainers[i];
+        if (cId != null) {
+          LOG.info("Preempting container: " + cId + " currently allocated to a task.");
+          appClientDelegate.preemptContainer(cId);
+        }
+      }
     }
   }
 
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskScheduler.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskScheduler.java
index bd53f87..3afff7c 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskScheduler.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskScheduler.java
@@ -1202,7 +1202,7 @@
   }
 
   @SuppressWarnings({ "unchecked", "rawtypes" })
-  @Test
+  @Test (timeout=5000)
   public void testTaskSchedulerPreemption() throws Exception {
     RackResolver.init(new YarnConfiguration());
     TaskSchedulerAppCallback mockApp = mock(TaskSchedulerAppCallback.class);
@@ -1248,7 +1248,8 @@
     Object mockTask3 = mock(Object.class);
     Object mockTask3Wait = mock(Object.class);
     Object mockTask3Retry = mock(Object.class);
-    Object mockTask3Kill = mock(Object.class);
+    Object mockTask3KillA = mock(Object.class);
+    Object mockTask3KillB = mock(Object.class);
     Object obj3 = new Object();
     Priority pri2 = Priority.newInstance(2);
     Priority pri4 = Priority.newInstance(4);
@@ -1275,12 +1276,19 @@
     addContainerRequest(requestCaptor.capture());
     anyContainers.add(requestCaptor.getValue());
     // later one in the allocation gets killed between the two task3's
-    scheduler.allocateTask(mockTask3Kill, taskAsk, null,
+    scheduler.allocateTask(mockTask3KillA, taskAsk, null,
                            null, pri6, obj3, null);
     drainableAppCallback.drain();
     verify(mockRMClient, times(3)).
     addContainerRequest(requestCaptor.capture());
     anyContainers.add(requestCaptor.getValue());
+    // add a second pri6 task; both Kill tasks get preempted newest-first in separate rounds
+    scheduler.allocateTask(mockTask3KillB, taskAsk, null,
+                           null, pri6, obj3, null);
+    drainableAppCallback.drain();
+    verify(mockRMClient, times(4)).
+    addContainerRequest(requestCaptor.capture());
+    anyContainers.add(requestCaptor.getValue());
 
     Resource freeResource = Resource.newInstance(500, 0);
     when(mockRMClient.getAvailableResources()).thenReturn(freeResource);
@@ -1310,13 +1318,20 @@
     ContainerId mockCId2 = mock(ContainerId.class);
     when(mockContainer2.getId()).thenReturn(mockCId2);
     containers.add(mockContainer2);
-    Container mockContainer3 = mock(Container.class, RETURNS_DEEP_STUBS);
-    when(mockContainer3.getNodeId().getHost()).thenReturn("host1");
-    when(mockContainer3.getResource()).thenReturn(taskAsk);
-    when(mockContainer3.getPriority()).thenReturn(pri6);
-    ContainerId mockCId3 = mock(ContainerId.class);
-    when(mockContainer3.getId()).thenReturn(mockCId3);
-    containers.add(mockContainer3);
+    Container mockContainer3A = mock(Container.class, RETURNS_DEEP_STUBS);
+    when(mockContainer3A.getNodeId().getHost()).thenReturn("host1");
+    when(mockContainer3A.getResource()).thenReturn(taskAsk);
+    when(mockContainer3A.getPriority()).thenReturn(pri6);
+    ContainerId mockCId3A = mock(ContainerId.class);
+    when(mockContainer3A.getId()).thenReturn(mockCId3A);
+    containers.add(mockContainer3A);
+    Container mockContainer3B = mock(Container.class, RETURNS_DEEP_STUBS);
+    when(mockContainer3B.getNodeId().getHost()).thenReturn("host1");
+    when(mockContainer3B.getResource()).thenReturn(taskAsk);
+    when(mockContainer3B.getPriority()).thenReturn(pri6);
+    ContainerId mockCId3B = mock(ContainerId.class);
+    when(mockContainer3B.getId()).thenReturn(mockCId3B);
+    containers.add(mockContainer3B);
     when(
         mockRMClient.getMatchingRequests((Priority) any(), eq("host1"),
             (Resource) any())).thenAnswer(
@@ -1368,14 +1383,16 @@
     
     scheduler.onContainersAllocated(containers);
     drainableAppCallback.drain();
-    Assert.assertEquals(3, scheduler.taskAllocations.size());
-    Assert.assertEquals(3072, scheduler.allocatedResources.getMemory());
+    Assert.assertEquals(4, scheduler.taskAllocations.size());
+    Assert.assertEquals(4096, scheduler.allocatedResources.getMemory());
     Assert.assertEquals(mockCId1,
         scheduler.taskAllocations.get(mockTask1).getId());
     Assert.assertEquals(mockCId2,
         scheduler.taskAllocations.get(mockTask3).getId());
-    Assert.assertEquals(mockCId3,
-        scheduler.taskAllocations.get(mockTask3Kill).getId());
+    Assert.assertEquals(mockCId3A,
+        scheduler.taskAllocations.get(mockTask3KillA).getId());
+    Assert.assertEquals(mockCId3B,
+        scheduler.taskAllocations.get(mockTask3KillB).getId());
 
     // no preemption
     scheduler.getProgress();
@@ -1419,7 +1436,7 @@
     drainableAppCallback.drain();
     verify(mockRMClient, times(1)).releaseAssignedContainer((ContainerId)any());
     verify(mockRMClient, times(1)).releaseAssignedContainer(mockCId4);
-    verify(mockRMClient, times(4)).
+    verify(mockRMClient, times(5)).
     addContainerRequest(requestCaptor.capture());
     CookieContainerRequest reAdded = requestCaptor.getValue();
     Assert.assertEquals(pri6, reAdded.getPriority());
@@ -1436,15 +1453,27 @@
     drainableAppCallback.drain();
     verify(mockRMClient, times(1)).releaseAssignedContainer((ContainerId)any());
 
-    scheduler.allocateTask(mockTask2, taskAsk, null,
-                           null, pri4, null, null);
+    for (int i=0; i<11; ++i) {
+      scheduler.allocateTask(mockTask2, taskAsk, null,
+                             null, pri4, null, null);
+    }
     drainableAppCallback.drain();
 
-    // mockTaskPri3Kill gets preempted
+    // mockTaskPri3KillB gets preempted to clear 10% of outstanding running preemptable tasks
     scheduler.getProgress();
     drainableAppCallback.drain();
     verify(mockRMClient, times(2)).releaseAssignedContainer((ContainerId)any());
-    verify(mockRMClient, times(1)).releaseAssignedContainer(mockCId3);
+    verify(mockRMClient, times(1)).releaseAssignedContainer(mockCId3B);
+    // next 3 heartbeats do nothing, waiting for the RM to act on the last released resources
+    scheduler.getProgress();
+    scheduler.getProgress();
+    scheduler.getProgress();
+    verify(mockRMClient, times(2)).releaseAssignedContainer((ContainerId)any());
+    scheduler.getProgress();
+    drainableAppCallback.drain();
+    // Next oldest mockTaskPri3KillA gets preempted to clear 10% of outstanding running preemptable tasks
+    verify(mockRMClient, times(3)).releaseAssignedContainer((ContainerId)any());
+    verify(mockRMClient, times(1)).releaseAssignedContainer(mockCId3A);
 
     AppFinalStatus finalStatus =
         new AppFinalStatus(FinalApplicationStatus.SUCCEEDED, "", appUrl);