YARN-10468. Fix TestNodeStatusUpdater timeouts and broken conditions (#2461)

diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/TestNodeStatusUpdater.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/TestNodeStatusUpdater.java
index 48ce704..b3c4014 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/TestNodeStatusUpdater.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/TestNodeStatusUpdater.java
@@ -20,7 +20,6 @@
 
 import static org.apache.hadoop.yarn.server.utils.YarnServerBuilderUtils.newNodeHeartbeatResponse;
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.junit.Assert.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -63,6 +62,7 @@
 import org.apache.hadoop.service.Service.STATE;
 import org.apache.hadoop.test.GenericTestUtils;
 import org.apache.hadoop.service.ServiceOperations;
+import org.apache.hadoop.test.LambdaTestUtils;
 import org.apache.hadoop.util.concurrent.HadoopExecutors;
 import org.apache.hadoop.yarn.api.protocolrecords.SignalContainerRequest;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
@@ -131,8 +131,8 @@
   /** Bytes in a GigaByte. */
   private static final long GB = 1024L * 1024L * 1024L;
 
-  volatile int heartBeatID = 0;
-  volatile Throwable nmStartError = null;
+  private volatile Throwable nmStartError = null;
+  private AtomicInteger heartBeatID = new AtomicInteger(0);
   private final List<NodeId> registeredNodes = new ArrayList<NodeId>();
   private boolean triggered = false;
   private NodeManager nm;
@@ -147,8 +147,12 @@
   @After
   public void tearDown() {
     this.registeredNodes.clear();
-    heartBeatID = 0;
-    ServiceOperations.stop(nm);
+    heartBeatID.set(0);
+    if (nm != null) {
+      ServiceOperations.stop(nm);
+      nm.waitForServiceToStop(10000);
+    }
+
     assertionFailedInThread.set(false);
     DefaultMetricsSystem.shutdown();
   }
@@ -220,7 +224,7 @@
       EventHandler<Event> mockEventHandler = mock(EventHandler.class);
       when(mockDispatcher.getEventHandler()).thenReturn(mockEventHandler);
       NMStateStoreService stateStore = new NMNullStateStoreService();
-      nodeStatus.setResponseId(heartBeatID++);
+      nodeStatus.setResponseId(heartBeatID.getAndIncrement());
       Map<ApplicationId, List<ContainerStatus>> appToContainers =
           getAppToContainerStatusMap(nodeStatus.getContainersStatuses());
       List<SignalContainerRequest> containersToSignal = null;
@@ -229,14 +233,14 @@
       ApplicationId appId2 = ApplicationId.newInstance(0, 2);
 
       ContainerId firstContainerID = null;
-      if (heartBeatID == 1) {
+      if (heartBeatID.get() == 1) {
         Assert.assertEquals(0, nodeStatus.getContainersStatuses().size());
 
         // Give a container to the NM.
         ApplicationAttemptId appAttemptID =
             ApplicationAttemptId.newInstance(appId1, 0);
         firstContainerID =
-            ContainerId.newContainerId(appAttemptID, heartBeatID);
+            ContainerId.newContainerId(appAttemptID, heartBeatID.get());
         ContainerLaunchContext launchContext = recordFactory
             .newRecordInstance(ContainerLaunchContext.class);
         Resource resource = BuilderUtils.newResource(2, 1);
@@ -252,7 +256,7 @@
         Container container = new ContainerImpl(conf, mockDispatcher,
             launchContext, null, mockMetrics, containerToken, context);
         this.context.getContainers().put(firstContainerID, container);
-      } else if (heartBeatID == 2) {
+      } else if (heartBeatID.get() == 2) {
         // Checks on the RM end
         Assert.assertEquals("Number of applications should only be one!", 1,
             nodeStatus.getContainersStatuses().size());
@@ -277,7 +281,7 @@
         ApplicationAttemptId appAttemptID =
             ApplicationAttemptId.newInstance(appId2, 0);
         ContainerId secondContainerID =
-            ContainerId.newContainerId(appAttemptID, heartBeatID);
+            ContainerId.newContainerId(appAttemptID, heartBeatID.get());
         ContainerLaunchContext launchContext = recordFactory
             .newRecordInstance(ContainerLaunchContext.class);
         long currentTime = System.currentTimeMillis();
@@ -293,7 +297,7 @@
         Container container = new ContainerImpl(conf, mockDispatcher,
             launchContext, null, mockMetrics, containerToken, context);
         this.context.getContainers().put(secondContainerID, container);
-      } else if (heartBeatID == 3) {
+      } else if (heartBeatID.get() == 3) {
         // Checks on the RM end
         Assert.assertEquals("Number of applications should have two!", 2,
             appToContainers.size());
@@ -309,8 +313,8 @@
       }
 
       NodeHeartbeatResponse nhResponse = YarnServerBuilderUtils.
-          newNodeHeartbeatResponse(heartBeatID, null, null, null, null, null,
-            1000L);
+          newNodeHeartbeatResponse(heartBeatID.get(), null, null, null, null,
+              null, 1000L);
       if (containersToSignal != null) {
         nhResponse.addAllContainersToSignal(containersToSignal);
       }
@@ -576,10 +580,10 @@
     public NodeHeartbeatResponse nodeHeartbeat(NodeHeartbeatRequest request)
         throws YarnException, IOException {
       NodeStatus nodeStatus = request.getNodeStatus();
-      nodeStatus.setResponseId(heartBeatID++);
+      nodeStatus.setResponseId(heartBeatID.getAndIncrement());
 
       NodeHeartbeatResponse nhResponse = YarnServerBuilderUtils.
-          newNodeHeartbeatResponse(heartBeatID, heartBeatNodeAction, null,
+          newNodeHeartbeatResponse(heartBeatID.get(), heartBeatNodeAction, null,
               null, null, null, 1000L);
       nhResponse.setDiagnosticsMessage(shutDownMessage);
       return nhResponse;
@@ -623,9 +627,9 @@
         throws YarnException, IOException {
       LOG.info("Got heartBeatId: [" + heartBeatID +"]");
       NodeStatus nodeStatus = request.getNodeStatus();
-      nodeStatus.setResponseId(heartBeatID++);
+      nodeStatus.setResponseId(heartBeatID.getAndIncrement());
       NodeHeartbeatResponse nhResponse = YarnServerBuilderUtils.
-          newNodeHeartbeatResponse(heartBeatID, heartBeatNodeAction, null,
+          newNodeHeartbeatResponse(heartBeatID.get(), heartBeatNodeAction, null,
               null, null, null, 1000L);
 
       if (nodeStatus.getKeepAliveApplications() != null
@@ -639,7 +643,7 @@
           list.add(System.currentTimeMillis());
         }
       }
-      if (heartBeatID == 2) {
+      if (heartBeatID.get() == 2) {
         LOG.info("Sending FINISH_APP for application: [" + appId + "]");
         this.context.getApplications().put(appId, mock(Application.class));
         nhResponse.addAllApplicationsToCleanup(Collections.singletonList(appId));
@@ -698,11 +702,11 @@
       List<ContainerId> finishedContainersPulledByAM = new ArrayList
           <ContainerId>();
       try {
-        if (heartBeatID == 0) {
+        if (heartBeatID.get() == 0) {
           Assert.assertEquals(0, request.getNodeStatus().getContainersStatuses()
             .size());
           Assert.assertEquals(0, context.getContainers().size());
-        } else if (heartBeatID == 1) {
+        } else if (heartBeatID.get() == 1) {
           List<ContainerStatus> statuses =
               request.getNodeStatus().getContainersStatuses();
           Assert.assertEquals(2, statuses.size());
@@ -712,14 +716,14 @@
           for (ContainerStatus status : statuses) {
             if (status.getContainerId().equals(
               containerStatus2.getContainerId())) {
-              Assert.assertTrue(status.getState().equals(
-                containerStatus2.getState()));
+              Assert.assertEquals(containerStatus2.getState(),
+                  status.getState());
               container2Exist = true;
             }
             if (status.getContainerId().equals(
               containerStatus3.getContainerId())) {
-              Assert.assertTrue(status.getState().equals(
-                containerStatus3.getState()));
+              Assert.assertEquals(containerStatus3.getState(),
+                  status.getState());
               container3Exist = true;
             }
           }
@@ -729,7 +733,7 @@
           // nodeStatusUpdaterRunnable, otherwise nm just shuts down and the
           // test passes.
           throw new YarnRuntimeException("Lost the heartbeat response");
-        } else if (heartBeatID == 2 || heartBeatID == 3) {
+        } else if (heartBeatID.get() == 2 || heartBeatID.get() == 3) {
           List<ContainerStatus> statuses =
               request.getNodeStatus().getContainersStatuses();
           // NM should send completed containers on heartbeat 2,
@@ -744,36 +748,36 @@
           for (ContainerStatus status : statuses) {
             if (status.getContainerId().equals(
               containerStatus2.getContainerId())) {
-              Assert.assertTrue(status.getState().equals(
-                containerStatus2.getState()));
+              Assert.assertEquals(containerStatus2.getState(),
+                  status.getState());
               container2Exist = true;
             }
             if (status.getContainerId().equals(
               containerStatus3.getContainerId())) {
-              Assert.assertTrue(status.getState().equals(
-                containerStatus3.getState()));
+              Assert.assertEquals(containerStatus3.getState(),
+                  status.getState());
               container3Exist = true;
             }
             if (status.getContainerId().equals(
               containerStatus4.getContainerId())) {
-              Assert.assertTrue(status.getState().equals(
-                containerStatus4.getState()));
+              Assert.assertEquals(containerStatus4.getState(),
+                  status.getState());
               container4Exist = true;
             }
             if (status.getContainerId().equals(
               containerStatus5.getContainerId())) {
-              Assert.assertTrue(status.getState().equals(
-                containerStatus5.getState()));
+              Assert.assertEquals(containerStatus5.getState(),
+                  status.getState());
               container5Exist = true;
             }
           }
           Assert.assertTrue(container2Exist && container3Exist
               && container4Exist && container5Exist);
 
-          if (heartBeatID == 3) {
+          if (heartBeatID.get() == 3) {
             finishedContainersPulledByAM.add(containerStatus3.getContainerId());
           }
-        } else if (heartBeatID == 4) {
+        } else if (heartBeatID.get() == 4) {
           List<ContainerStatus> statuses =
               request.getNodeStatus().getContainersStatuses();
           Assert.assertEquals(2, statuses.size());
@@ -793,12 +797,12 @@
         error.printStackTrace();
         assertionFailedInThread.set(true);
       } finally {
-        heartBeatID++;
+        heartBeatID.incrementAndGet();
       }
       NodeStatus nodeStatus = request.getNodeStatus();
-      nodeStatus.setResponseId(heartBeatID);
+      nodeStatus.setResponseId(heartBeatID.get());
       NodeHeartbeatResponse nhResponse =
-          YarnServerBuilderUtils.newNodeHeartbeatResponse(heartBeatID,
+          YarnServerBuilderUtils.newNodeHeartbeatResponse(heartBeatID.get(),
             heartBeatNodeAction, null, null, null, null, 1000L);
       nhResponse.addContainersToBeRemovedFromNM(finishedContainersPulledByAM);
       Map<ApplicationId, ByteBuffer> appCredentials =
@@ -839,8 +843,7 @@
     @Override
     public NodeHeartbeatResponse nodeHeartbeat(NodeHeartbeatRequest request)
         throws YarnException, IOException {
-      heartBeatID++;
-      if(heartBeatID == 1) {
+      if (heartBeatID.incrementAndGet() == 1) {
         // EOFException should be retried as well.
         throw new EOFException("NodeHeartbeat exception");
       }
@@ -909,10 +912,10 @@
     public NodeHeartbeatResponse nodeHeartbeat(NodeHeartbeatRequest request)
         throws YarnException, IOException {
       NodeStatus nodeStatus = request.getNodeStatus();
-      nodeStatus.setResponseId(heartBeatID++);
+      nodeStatus.setResponseId(heartBeatID.getAndIncrement());
 
       NodeHeartbeatResponse nhResponse = YarnServerBuilderUtils.
-          newNodeHeartbeatResponse(heartBeatID, NodeAction.NORMAL, null,
+          newNodeHeartbeatResponse(heartBeatID.get(), NodeAction.NORMAL, null,
               null, null, null, 1000L);
       return nhResponse;
     }
@@ -1141,7 +1144,7 @@
   }
 
   @Test
-  public void testNMRegistration() throws InterruptedException, IOException {
+  public void testNMRegistration() throws Exception {
     nm = new NodeManager() {
       @Override
       protected NodeStatusUpdater createNodeStatusUpdater(Context context,
@@ -1161,43 +1164,32 @@
     Assert.assertTrue("last service is NOT the node status updater",
         lastService instanceof NodeStatusUpdater);
 
-    new Thread() {
-      public void run() {
-        try {
-          nm.start();
-        } catch (Throwable e) {
-          TestNodeStatusUpdater.this.nmStartError = e;
-          throw new YarnRuntimeException(e);
-        }
+    Thread starterThread = new Thread(() -> {
+      try {
+        nm.start();
+      } catch (Throwable e) {
+        TestNodeStatusUpdater.this.nmStartError = e;
+        throw new YarnRuntimeException(e);
       }
-    }.start();
+    });
+    starterThread.start();
 
-    System.out.println(" ----- thread already started.."
-        + nm.getServiceState());
+    LOG.info(" ----- thread already started..{}", nm.getServiceState());
 
-    int waitCount = 0;
-    while (nm.getServiceState() == STATE.INITED && waitCount++ != 50) {
-      LOG.info("Waiting for NM to start..");
-      if (nmStartError != null) {
-        LOG.error("Error during startup. ", nmStartError);
-        Assert.fail(nmStartError.getCause().getMessage());
-      }
-      Thread.sleep(2000);
-    }
-    if (nm.getServiceState() != STATE.STARTED) {
-      // NM could have failed.
-      Assert.fail("NodeManager failed to start");
+    starterThread.join(100000);
+
+    if (nmStartError != null) {
+      LOG.error("Error during startup. ", nmStartError);
+      Assert.fail(nmStartError.getCause().getMessage());
     }
 
-    waitCount = 0;
-    while (heartBeatID <= 3 && waitCount++ != 200) {
-      Thread.sleep(1000);
-    }
-    Assert.assertFalse(heartBeatID <= 3);
-    Assert.assertEquals("Number of registered NMs is wrong!!", 1,
-        this.registeredNodes.size());
+    GenericTestUtils.waitFor(
+        () -> nm.getServiceState() != STATE.STARTED || heartBeatID.get() > 3,
+        50, 20000);
 
-    nm.stop();
+    Assert.assertTrue(heartBeatID.get() > 3);
+    Assert.assertEquals("Number of registered NMs is wrong!!",
+        1, this.registeredNodes.size());
   }
 
   @Test
@@ -1236,31 +1228,23 @@
     YarnConfiguration conf = createNMConfig();
     nm.init(conf);
     nm.start();
-
-    int waitCount = 0;
-    while (heartBeatID < 1 && waitCount++ != 200) {
-      Thread.sleep(500);
-    }
-    Assert.assertFalse(heartBeatID < 1);
+    GenericTestUtils.waitFor(() -> nm.getServiceState() == STATE.STARTED,
+        20, 10000);
+    GenericTestUtils.waitFor(
+        () -> nm.getServiceState() != STATE.STARTED || heartBeatID.get() >= 1,
+        50, 20000);
+    Assert.assertTrue(heartBeatID.get() >= 1);
 
     // Meanwhile call stop directly as the shutdown hook would
     nm.stop();
 
     // NM takes a while to reach the STOPPED state.
-    waitCount = 0;
-    while (nm.getServiceState() != STATE.STOPPED && waitCount++ != 20) {
-      LOG.info("Waiting for NM to stop..");
-      Thread.sleep(1000);
-    }
+    nm.waitForServiceToStop(20000);
 
     Assert.assertEquals(STATE.STOPPED, nm.getServiceState());
 
     // It further takes a while after NM reached the STOPPED state.
-    waitCount = 0;
-    while (numCleanups.get() == 0 && waitCount++ != 20) {
-      LOG.info("Waiting for NM shutdown..");
-      Thread.sleep(1000);
-    }
+    GenericTestUtils.waitFor(() -> numCleanups.get() > 0, 20, 20000);
     Assert.assertEquals(1, numCleanups.get());
   }
 
@@ -1271,20 +1255,22 @@
     nm.init(conf);
     Assert.assertEquals(STATE.INITED, nm.getServiceState());
     nm.start();
-
-    int waitCount = 0;
-    while (heartBeatID < 1 && waitCount++ != 200) {
-      Thread.sleep(500);
-    }
-    Assert.assertFalse(heartBeatID < 1);
+    GenericTestUtils.waitFor(() -> nm.getServiceState() == STATE.STARTED,
+        20, 10000);
+    GenericTestUtils.waitFor(
+        () -> {
+          if (nm.getServiceState() == STATE.STARTED) {
+            return (heartBeatID.get() >= 1
+                && nm.getNMContext().getDecommissioned());
+          }
+          return true;
+        },
+        50, 200000);
+    Assert.assertTrue(heartBeatID.get() >= 1);
     Assert.assertTrue(nm.getNMContext().getDecommissioned());
 
     // NM takes a while to reach the STOPPED state.
-    waitCount = 0;
-    while (nm.getServiceState() != STATE.STOPPED && waitCount++ != 20) {
-      LOG.info("Waiting for NM to stop..");
-      Thread.sleep(1000);
-    }
+    nm.waitForServiceToStop(20000);
 
     Assert.assertEquals(STATE.STOPPED, nm.getServiceState());
   }
@@ -1529,9 +1515,14 @@
       nm.init(conf);
       nm.start();
       // HB 2 -> app cancelled by RM.
-      while (heartBeatID < 12) {
-        Thread.sleep(1000l);
-      }
+      GenericTestUtils.waitFor(() -> nm.getServiceState() == STATE.STARTED, 20,
+          10000);
+      GenericTestUtils.waitFor(
+          () -> nm.getServiceState() != STATE.STARTED
+              || heartBeatID.get() >= 12,
+          100L, 60000000);
+
+      Assert.assertTrue(heartBeatID.get() >= 12);
       MyResourceTracker3 rt =
           (MyResourceTracker3) nm.getNodeStatusUpdater().getRMClient();
       rt.context.getApplications().remove(rt.appId);
@@ -1539,14 +1530,18 @@
       int numKeepAliveRequests = rt.keepAliveRequests.get(rt.appId).size();
       LOG.info("Number of Keep Alive Requests: [" + numKeepAliveRequests + "]");
       Assert.assertTrue(numKeepAliveRequests == 2 || numKeepAliveRequests == 3);
-      while (heartBeatID < 20) {
-        Thread.sleep(1000l);
-      }
+      GenericTestUtils.waitFor(
+          () -> nm.getServiceState() != STATE.STARTED
+              || heartBeatID.get() >= 20,
+          100L, 60000000);
+      Assert.assertTrue(heartBeatID.get() >= 20);
       int numKeepAliveRequests2 = rt.keepAliveRequests.get(rt.appId).size();
       Assert.assertEquals(numKeepAliveRequests, numKeepAliveRequests2);
     } finally {
-      if (nm.getServiceState() == STATE.STARTED)
+      if (nm != null) {
         nm.stop();
+        nm.waitForServiceToStop(10000);
+      }
     }
   }
 
@@ -1581,20 +1576,19 @@
     nm.init(conf);
     nm.start();
 
-    int waitCount = 0;
-    while (heartBeatID <= 4 && waitCount++ != 20) {
-      Thread.sleep(500);
-    }
-    if (heartBeatID <= 4) {
-      Assert.fail("Failed to get all heartbeats in time, " +
-          "heartbeatID:" + heartBeatID);
-    }
-    if(assertionFailedInThread.get()) {
-      Assert.fail("ContainerStatus Backup failed");
-    }
+    GenericTestUtils.waitFor(() -> nm.getServiceState() == STATE.STARTED,
+        20, 10000);
+
+    GenericTestUtils.waitFor(
+        () -> nm.getServiceState() != STATE.STARTED || heartBeatID.get() > 4,
+        50, 20000);
+    int hbID = heartBeatID.get();
+    Assert.assertFalse("Failed to get all heartbeats in time, "
+        + "heartbeatID:" + hbID, hbID <= 4);
+    Assert.assertFalse("ContainerStatus Backup failed",
+        assertionFailedInThread.get());
     Assert.assertNotNull(nm.getNMContext().getSystemCredentialsForApps()
       .get(ApplicationId.newInstance(1234, 1)).getToken(new Text("token1")));
-    nm.stop();
   }
 
   @Test(timeout = 200000)
@@ -1631,13 +1625,12 @@
     Assert.assertFalse("Containers not cleaned up when NM stopped",
       assertionFailedInThread.get());
     Assert.assertTrue(((MyNodeManager2) nm).isStopped);
-    Assert.assertTrue("calculate heartBeatCount based on" +
-        " connectionWaitSecs and RetryIntervalSecs", heartBeatID == 2);
+    Assert.assertEquals("calculate heartBeatCount based on" +
+        " connectionWaitSecs and RetryIntervalSecs", 2, heartBeatID.get());
   }
 
   @Test
-  public void testRMVersionLessThanMinimum() throws InterruptedException,
-      IOException {
+  public void testRMVersionLessThanMinimum() throws Exception {
     final AtomicInteger numCleanups = new AtomicInteger(0);
     YarnConfiguration conf = createNMConfig();
     conf.set(YarnConfiguration.NM_RESOURCEMANAGER_MINIMUM_VERSION, "3.0.0");
@@ -1674,15 +1667,9 @@
 
     nm.init(conf);
     nm.start();
-
     // NM takes a while to reach the STARTED state.
-    int waitCount = 0;
-    while (nm.getServiceState() != STATE.STARTED && waitCount++ != 20) {
-      LOG.info("Waiting for NM to stop..");
-      Thread.sleep(1000);
-    }
-    Assert.assertTrue(nm.getServiceState() == STATE.STARTED);
-    nm.stop();
+    GenericTestUtils.waitFor(() -> nm.getServiceState() == STATE.STARTED,
+        20, 200000);
   }
 
 
@@ -1712,37 +1699,20 @@
     YarnConfiguration conf = createNMConfig();
     nm.init(conf);
     nm.start();
+    GenericTestUtils.waitFor(() -> nm.getServiceState() == STATE.STARTED,
+        20, 20000);
 
-    System.out.println(" ----- thread already started.."
-        + nm.getServiceState());
-
-    int waitCount = 0;
-    while (nm.getServiceState() == STATE.INITED && waitCount++ != 20) {
-      LOG.info("Waiting for NM to start..");
-      if (nmStartError != null) {
-        LOG.error("Error during startup. ", nmStartError);
-        Assert.fail(nmStartError.getCause().getMessage());
-      }
-      Thread.sleep(1000);
-    }
-    if (nm.getServiceState() != STATE.STARTED) {
-      // NM could have failed.
-      Assert.fail("NodeManager failed to start");
-    }
-
-    waitCount = 0;
-    while (heartBeatID <= 3 && waitCount++ != 20) {
-      Thread.sleep(500);
-    }
-    Assert.assertFalse(heartBeatID <= 3);
+    GenericTestUtils.waitFor(
+        () -> nm.getServiceState() != STATE.STARTED
+            || heartBeatID.get() > 3,
+        50, 20000);
+    Assert.assertTrue(heartBeatID.get() > 3);
     Assert.assertEquals("Number of registered NMs is wrong!!", 1,
         this.registeredNodes.size());
 
     MyContainerManager containerManager =
         (MyContainerManager)nm.getContainerManager();
     Assert.assertTrue(containerManager.signaled);
-
-    nm.stop();
   }
 
   @Test
@@ -1823,38 +1793,48 @@
     LOG.info("Start the Node Manager");
     NodeManager nodeManager = new NodeManager();
     YarnConfiguration nmConf = new YarnConfiguration();
-    nmConf.setSocketAddr(YarnConfiguration.RM_RESOURCE_TRACKER_ADDRESS,
-        resourceTracker.getListenerAddress());
-    nmConf.set(YarnConfiguration.NM_LOCALIZER_ADDRESS, "0.0.0.0:0");
-    nodeManager.init(nmConf);
-    nodeManager.start();
+    try {
+      nmConf.setSocketAddr(YarnConfiguration.RM_RESOURCE_TRACKER_ADDRESS,
+          resourceTracker.getListenerAddress());
+      nmConf.set(YarnConfiguration.NM_LOCALIZER_ADDRESS, "0.0.0.0:0");
+      nodeManager.init(nmConf);
+      nodeManager.start();
 
-    LOG.info("Initially the Node Manager should have the default resources");
-    ContainerManager containerManager = nodeManager.getContainerManager();
-    ContainersMonitor containerMonitor =
-        containerManager.getContainersMonitor();
-    assertEquals(8, containerMonitor.getVCoresAllocatedForContainers());
-    assertEquals(8 * GB, containerMonitor.getPmemAllocatedForContainers());
+      LOG.info("Initially the Node Manager should have the default resources");
+      ContainerManager containerManager = nodeManager.getContainerManager();
+      ContainersMonitor containerMonitor =
+          containerManager.getContainersMonitor();
+      Assert.assertEquals(8,
+          containerMonitor.getVCoresAllocatedForContainers());
+      Assert.assertEquals(8 * GB,
+          containerMonitor.getPmemAllocatedForContainers());
 
-    LOG.info("The first heartbeat should trigger a resource change to {}",
-        resource);
-    GenericTestUtils.waitFor(
-        () -> containerMonitor.getVCoresAllocatedForContainers() == 1,
-        100, 2 * 1000);
-    assertEquals(8 * GB, containerMonitor.getPmemAllocatedForContainers());
+      LOG.info("The first heartbeat should trigger a resource change to {}",
+          resource);
+      GenericTestUtils.waitFor(
+          () -> containerMonitor.getVCoresAllocatedForContainers() == 1,
+          100, 2 * 1000);
+      Assert.assertEquals(8 * GB,
+          containerMonitor.getPmemAllocatedForContainers());
 
-    resource.setVirtualCores(5);
-    resource.setMemorySize(4 * 1024);
-    LOG.info("Change the resources to {}", resource);
-    GenericTestUtils.waitFor(
-        () -> containerMonitor.getVCoresAllocatedForContainers() == 5,
-        100, 2 * 1000);
-    assertEquals(4 * GB, containerMonitor.getPmemAllocatedForContainers());
-
-    LOG.info("Cleanup");
-    nodeManager.stop();
-    nodeManager.close();
-    resourceTracker.stop();
+      resource.setVirtualCores(5);
+      resource.setMemorySize(4 * 1024);
+      LOG.info("Change the resources to {}", resource);
+      GenericTestUtils.waitFor(
+          () -> containerMonitor.getVCoresAllocatedForContainers() == 5,
+          100, 2 * 1000);
+      Assert.assertEquals(4 * GB,
+          containerMonitor.getPmemAllocatedForContainers());
+    } finally {
+      LOG.info("Cleanup");
+      nodeManager.stop();
+      try {
+        nodeManager.close();
+      } catch (IOException ex) {
+        LOG.error("Could not close the node manager", ex);
+      }
+      resourceTracker.stop();
+    }
   }
 
   /**
@@ -1908,9 +1888,9 @@
 
     @Override
     public ConcurrentMap<ContainerId, Container> getContainers() {
-      if (heartBeatID == 0) {
+      if (heartBeatID.get() == 0) {
         return containers;
-      } else if (heartBeatID == 1) {
+      } else if (heartBeatID.get() == 1) {
         ContainerStatus containerStatus2 =
             createContainerStatus(2, ContainerState.RUNNING);
         putMockContainer(containerStatus2);
@@ -1919,7 +1899,7 @@
             createContainerStatus(3, ContainerState.COMPLETE);
         putMockContainer(containerStatus3);
         return containers;
-      } else if (heartBeatID == 2) {
+      } else if (heartBeatID.get() == 2) {
         ContainerStatus containerStatus4 =
             createContainerStatus(4, ContainerState.RUNNING);
         putMockContainer(containerStatus4);
@@ -1928,7 +1908,7 @@
             createContainerStatus(5, ContainerState.COMPLETE);
         putMockContainer(containerStatus5);
         return containers;
-      } else if (heartBeatID == 3 || heartBeatID == 4) {
+      } else if (heartBeatID.get() == 3 || heartBeatID.get() == 4) {
         return containers;
       } else {
         containers.clear();
@@ -1978,22 +1958,16 @@
     Assert.assertNotNull("nm is null", nm);
     YarnConfiguration conf = createNMConfig();
     nm.init(conf);
-    try {
-      nm.start();
-      Assert.fail("NM should have failed to start. Didn't get exception!!");
-    } catch (Exception e) {
-      //the version in trunk looked in the cause for equality
-      // and assumed failures were nested.
-      //this version assumes that error strings propagate to the base and
-      //use a contains() test only. It should be less brittle
-      if(!e.getMessage().contains(errMessage)) {
-        throw e;
-      }
-    }
+
+    //the version in trunk looked in the cause for equality
+    // and assumed failures were nested.
+    //this version assumes that error strings propagate to the base and
+    //use a contains() test only. It should be less brittle
+    LambdaTestUtils.intercept(Exception.class, errMessage, () -> nm.start());
 
     // the service should be stopped
-    Assert.assertEquals("NM state is wrong!", STATE.STOPPED, nm
-        .getServiceState());
+    Assert.assertEquals("NM state is wrong!", STATE.STOPPED,
+        nm.getServiceState());
 
     Assert.assertEquals("Number of registered nodes is wrong!", 0,
         this.registeredNodes.size());