TEZ-4027. DagAwareYarnTaskScheduler can miscompute blocked vertices and cause a hang

Signed-off-by: Jason Lowe <jlowe@apache.org>
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/DagAwareYarnTaskScheduler.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/DagAwareYarnTaskScheduler.java
index 167d879..1cdc217 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/DagAwareYarnTaskScheduler.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/DagAwareYarnTaskScheduler.java
@@ -1899,12 +1899,12 @@
     // scheduled due to outstanding requests from higher priority predecessor vertices.
     @GuardedBy("DagAwareYarnTaskScheduler.this")
     BitSet createVertexBlockedSet() {
-      BitSet blocked = new BitSet();
+      BitSet blocked = new BitSet(vertexDescendants.size());
       Entry<Priority, RequestPriorityStats> entry = priorityStats.lastEntry();
       if (entry != null) {
         RequestPriorityStats stats = entry.getValue();
         blocked.or(stats.allowedVertices);
-        blocked.flip(0, blocked.length());
+        blocked.flip(0, blocked.size());
         blocked.or(stats.descendants);
       }
       return blocked;
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestDagAwareYarnTaskScheduler.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestDagAwareYarnTaskScheduler.java
index 0910ed2..911f4b1 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestDagAwareYarnTaskScheduler.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestDagAwareYarnTaskScheduler.java
@@ -1228,6 +1228,115 @@
     verify(mockRMClient).stop();
   }
 
+  @Test (timeout = 50000L)
+  public void testPreemptionWhenBlocked() throws Exception {
+    AMRMClientAsyncWrapperForTest mockRMClient = spy(new AMRMClientAsyncWrapperForTest());
+
+    String appHost = "host";
+    int appPort = 0;
+    String appUrl = "url";
+
+    Configuration conf = new Configuration();
+    conf.setBoolean(TezConfiguration.TEZ_AM_CONTAINER_REUSE_ENABLED, true);
+    conf.setInt(TezConfiguration.TEZ_AM_CONTAINER_REUSE_LOCALITY_DELAY_ALLOCATION_MILLIS, 100);
+    conf.setBoolean(TezConfiguration.TEZ_AM_CONTAINER_REUSE_RACK_FALLBACK_ENABLED, true);
+    conf.setBoolean(TezConfiguration.TEZ_AM_CONTAINER_REUSE_NON_LOCAL_FALLBACK_ENABLED, false);
+    conf.setInt(TezConfiguration.TEZ_AM_RM_HEARTBEAT_INTERVAL_MS_MAX, 100);
+    conf.setInt(TezConfiguration.TEZ_AM_PREEMPTION_PERCENTAGE, 10);
+    conf.setInt(TezConfiguration.TEZ_AM_PREEMPTION_HEARTBEATS_BETWEEN_PREEMPTIONS, 3);
+    conf.setInt(TezConfiguration.TEZ_AM_PREEMPTION_MAX_WAIT_TIME_MS, 60 * 1000);
+
+
+    DagInfo mockDagInfo = mock(DagInfo.class);
+    when(mockDagInfo.getTotalVertices()).thenReturn(3);
+    when(mockDagInfo.getVertexDescendants(0)).thenReturn(BitSet.valueOf(new long[] { 0x6 }));
+    when(mockDagInfo.getVertexDescendants(1)).thenReturn(BitSet.valueOf(new long[] { 0x2 }));
+    when(mockDagInfo.getVertexDescendants(2)).thenReturn(new BitSet());
+    TaskSchedulerContext mockApp = setupMockTaskSchedulerContext(appHost, appPort, appUrl, conf);
+    when(mockApp.getCurrentDagInfo()).thenReturn(mockDagInfo);
+    TaskSchedulerContextDrainable drainableAppCallback = createDrainableContext(mockApp);
+
+    MockClock clock = new MockClock(1000);
+    NewTaskSchedulerForTest scheduler = new NewTaskSchedulerForTest(drainableAppCallback,
+        mockRMClient, clock);
+
+    scheduler.initialize();
+    drainableAppCallback.drain();
+
+    scheduler.start();
+    drainableAppCallback.drain();
+    verify(mockRMClient).start();
+    verify(mockRMClient).registerApplicationMaster(appHost, appPort, appUrl);
+    RegisterApplicationMasterResponse regResponse = mockRMClient.getRegistrationResponse();
+    verify(mockApp).setApplicationRegistrationData(regResponse.getMaximumResourceCapability(),
+        regResponse.getApplicationACLs(), regResponse.getClientToAMTokenMasterKey(),
+        regResponse.getQueue());
+
+    assertEquals(scheduler.getClusterNodeCount(), mockRMClient.getClusterNodeCount());
+
+    Priority priorityv0 = Priority.newInstance(1);
+    Priority priorityv2 = Priority.newInstance(3);
+    String[] hostsv0t0 = { "host1", "host2" };
+    MockTaskInfo taskv0t0 = new MockTaskInfo("taskv0t0", priorityv0, hostsv0t0);
+    when(mockApp.getVertexIndexForTask(taskv0t0.task)).thenReturn(0);
+    MockTaskInfo taskv0t1 = new MockTaskInfo("taskv0t1", priorityv0, hostsv0t0);
+    when(mockApp.getVertexIndexForTask(taskv0t1.task)).thenReturn(0);
+    MockTaskInfo taskv2t0 = new MockTaskInfo("taskv2t0", priorityv2, hostsv0t0);
+    when(mockApp.getVertexIndexForTask(taskv2t0.task)).thenReturn(2);
+    MockTaskInfo taskv2t1 = new MockTaskInfo("taskv2t1", priorityv2, hostsv0t0);
+    when(mockApp.getVertexIndexForTask(taskv2t1.task)).thenReturn(2);
+    when(mockApp.getVertexIndexForTask(taskv2t0.task)).thenReturn(2);
+
+    // asks for one task for vertex 2 and start running
+    TaskRequestCaptor taskRequestCaptor = new TaskRequestCaptor(mockRMClient,
+        scheduler, drainableAppCallback);
+    TaskRequest reqv2t0 = taskRequestCaptor.scheduleTask(taskv2t0);
+    NodeId host1 = NodeId.newInstance("host1", 1);
+    ApplicationAttemptId attemptId = ApplicationAttemptId.newInstance(ApplicationId.newInstance(1, 1), 1);
+    ContainerId cid1 = ContainerId.newContainerId(attemptId, 1);
+    Container container1 = Container.newInstance(cid1, host1, null, taskv2t0.capability, priorityv2, null);
+    scheduler.onContainersAllocated(Collections.singletonList(container1));
+    drainableAppCallback.drain();
+    verify(mockApp).taskAllocated(taskv2t0.task, taskv2t0.cookie, container1);
+    verify(mockRMClient).removeContainerRequest(reqv2t0);
+    clock.incrementTime(1000);
+
+    when(mockRMClient.getAvailableResources()).thenReturn(Resources.none());
+    scheduler.getProgress();
+    scheduler.getProgress();
+    scheduler.getProgress();
+    drainableAppCallback.drain();
+    //ask another task for v2
+    TaskRequest reqv2t1 = taskRequestCaptor.scheduleTask(taskv2t1);
+    scheduler.getProgress();
+    scheduler.getProgress();
+    scheduler.getProgress();
+    drainableAppCallback.drain();
+
+    clock.incrementTime(1000);
+    // add a request for vertex 0 but there is no headroom, this should preempt
+    when(mockRMClient.getAvailableResources()).thenReturn(Resources.none());
+    TaskRequest reqv0t0 = taskRequestCaptor.scheduleTask(taskv0t0);
+
+    // should preempt after enough heartbeats to get past preemption interval
+    scheduler.getProgress();
+    scheduler.getProgress();
+    scheduler.getProgress();
+    drainableAppCallback.drain();
+    verify(mockApp, times(1)).preemptContainer(any(ContainerId.class));
+    verify(mockApp).preemptContainer(cid1);
+    String appMsg = "success";
+    AppFinalStatus finalStatus =
+        new AppFinalStatus(FinalApplicationStatus.SUCCEEDED, appMsg, appUrl);
+    when(mockApp.getFinalAppStatus()).thenReturn(finalStatus);
+    scheduler.shutdown();
+    drainableAppCallback.drain();
+    verify(mockRMClient).
+        unregisterApplicationMaster(FinalApplicationStatus.SUCCEEDED,
+            appMsg, appUrl);
+    verify(mockRMClient).stop();
+  }
+
   @Test(timeout=50000)
   public void testContainerAssignmentReleaseNewContainers() throws Exception {
     AMRMClientAsyncWrapperForTest mockRMClient = spy(new AMRMClientAsyncWrapperForTest());