SAMZA-2606: container orchestration for AM HA (#1448)

Feature:
Main feature is YARN AM high availability. The feature ensures that the new AM can establish connection with already running containers to avoid restarting all running containers when AM dies. This PR enables the new AM to accept the list fo already running container provided by the resource manager and launch only those containers that are part of the job model but not in the running container list.

Changes:
[1] ClientHelper - job submit to RM indicates to keep containers alive across attempts
SamzaYarnAppMasterLifecycle: new AM uses the yarnid-samza id mapping and accepts the list of running containers given by RM and builds its internal state (SamzaApplicationState and YarnAppState) correctly
[2] ContainerProcessManager - removes running containers from the needed processor list prior to placing resource requests
diff --git a/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java b/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java
index a2ad540..ec52d4b 100644
--- a/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java
+++ b/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java
@@ -236,11 +236,15 @@
       diagnosticsManager.get().start();
     }
 
+    // In AM-HA, clusterResourceManager receives already running containers
+    // and invokes onStreamProcessorLaunchSuccess which inturn updates state
+    // hence state has to be set prior to starting clusterResourceManager.
+    state.processorCount.set(state.jobModelManager.jobModel().getContainers().size());
+    state.neededProcessors.set(state.jobModelManager.jobModel().getContainers().size());
+
     LOG.info("Starting the cluster resource manager");
     clusterResourceManager.start();
 
-    state.processorCount.set(state.jobModelManager.jobModel().getContainers().size());
-    state.neededProcessors.set(state.jobModelManager.jobModel().getContainers().size());
     // Request initial set of containers
     LocalityModel localityModel = localityManager.readLocality();
     Map<String, String> processorToHost = new HashMap<>();
@@ -251,6 +255,10 @@
           .orElse(null);
       processorToHost.put(containerId, host);
     });
+    if (jobConfig.getApplicationMasterHighAvailabilityEnabled()) {
+      // don't request resource for container that is already running
+      state.runningProcessors.keySet().forEach(processorToHost::remove);
+    }
     containerAllocator.requestResources(processorToHost);
 
     // Start container allocator thread
@@ -666,4 +674,4 @@
   private void handleContainerStop(String processorId, String containerId, String preferredHost, int exitStatus, Duration preferredHostRetryDelay) {
     containerManager.handleContainerStop(processorId, containerId, preferredHost, exitStatus, preferredHostRetryDelay, containerAllocator);
   }
-}
+}
\ No newline at end of file
diff --git a/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerProcessManager.java b/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerProcessManager.java
index 5e550cf..d285f9e 100644
--- a/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerProcessManager.java
+++ b/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerProcessManager.java
@@ -31,6 +31,7 @@
 import org.apache.samza.clustermanager.container.placement.ContainerPlacementMetadataStore;
 import org.apache.samza.config.ClusterManagerConfig;
 import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.LocalityManager;
 import org.apache.samza.coordinator.JobModelManager;
@@ -227,6 +228,55 @@
   }
 
   @Test
+  public void testOnInitAMHighAvailability() throws Exception {
+    Map<String, String> configMap = new HashMap<>(configVals);
+    configMap.put(JobConfig.YARN_AM_HIGH_AVAILABILITY_ENABLED, "true");
+    Config conf = new MapConfig(configMap);
+
+    SamzaApplicationState state = new SamzaApplicationState(getJobModelManager(2));
+    state.runningProcessors.put("0", new SamzaResource(1, 1024, "host", "0"));
+
+    MockClusterResourceManagerCallback callback = new MockClusterResourceManagerCallback();
+    ClusterResourceManager clusterResourceManager = new MockClusterResourceManager(callback, state);
+    ClusterManagerConfig clusterManagerConfig = spy(new ClusterManagerConfig(conf));
+    ContainerManager containerManager =
+        buildContainerManager(containerPlacementMetadataStore, state, clusterResourceManager,
+            clusterManagerConfig.getHostAffinityEnabled(), false);
+
+    ContainerProcessManager cpm =
+        buildContainerProcessManager(clusterManagerConfig, state, clusterResourceManager, Optional.empty());
+
+    MockContainerAllocatorWithoutHostAffinity allocator = new MockContainerAllocatorWithoutHostAffinity(
+        clusterResourceManager,
+        conf,
+        state,
+        containerManager);
+
+    getPrivateFieldFromCpm("containerAllocator", cpm).set(cpm, allocator);
+    CountDownLatch latch = new CountDownLatch(1);
+    getPrivateFieldFromCpm("allocatorThread", cpm).set(cpm, new Thread() {
+      public void run() {
+        isRunning = true;
+        latch.countDown();
+      }
+    });
+
+    cpm.start();
+
+    if (!latch.await(2, TimeUnit.SECONDS)) {
+      Assert.fail("timed out waiting for the latch to expire");
+    }
+
+    // Verify Allocator thread has started running
+    assertTrue(isRunning);
+
+    // Verify only 1 was requested with allocator
+    assertEquals(1, allocator.requestedContainers);
+
+    cpm.stop();
+  }
+
+  @Test
   public void testOnShutdown() throws Exception {
     Config conf = getConfig();
     SamzaApplicationState state = new SamzaApplicationState(getJobModelManager(1));
diff --git a/samza-yarn/src/main/java/org/apache/samza/job/yarn/YarnClusterResourceManager.java b/samza-yarn/src/main/java/org/apache/samza/job/yarn/YarnClusterResourceManager.java
index 7a971a9..d1c5437 100644
--- a/samza-yarn/src/main/java/org/apache/samza/job/yarn/YarnClusterResourceManager.java
+++ b/samza-yarn/src/main/java/org/apache/samza/job/yarn/YarnClusterResourceManager.java
@@ -20,6 +20,7 @@
 package org.apache.samza.job.yarn;
 
 import java.time.Duration;
+import java.util.Set;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
@@ -43,6 +44,7 @@
 import org.apache.samza.clustermanager.ProcessorLaunchException;
 import org.apache.samza.config.ClusterManagerConfig;
 import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.ShellCommandConfig;
 import org.apache.samza.config.YarnConfig;
 import org.apache.samza.coordinator.JobModelManager;
@@ -192,7 +194,8 @@
         clusterManagerConfig.getNumCores(),
         samzaAppState,
         state,
-        amClient
+        amClient,
+        new JobConfig(config).getApplicationMasterHighAvailabilityEnabled()
     );
     this.nmClientAsync = NMClientAsync.createNMClientAsync(this);
 
@@ -215,7 +218,12 @@
     amClient.start();
     nmClientAsync.init(yarnConfiguration);
     nmClientAsync.start();
-    lifecycle.onInit();
+    Set<ContainerId> previousAttemptsContainers = lifecycle.onInit();
+
+    if (new JobConfig(config).getApplicationMasterHighAvailabilityEnabled()) {
+      log.info("Received running containers from previous attempt. Invoking launch success for them.");
+      previousAttemptsContainers.forEach(this::handleOnContainerStarted);
+    }
 
     if (lifecycle.shouldShutdown()) {
       clusterManagerCallback.onError(new SamzaException("Invalid resource request."));
@@ -322,9 +330,11 @@
 
   public void stopStreamProcessor(SamzaResource resource) {
     synchronized (lock) {
-      log.info("Stopping Container ID: {} on host: {}", resource.getContainerId(), resource.getHost());
-      this.nmClientAsync.stopContainerAsync(allocatedResources.get(resource).getId(),
-          allocatedResources.get(resource).getNodeId());
+      Container container = allocatedResources.get(resource);
+      if (container != null) {
+        log.info("Stopping Container ID: {} on host: {}", resource.getContainerId(), resource.getHost());
+        this.nmClientAsync.stopContainerAsync(container.getId(), container.getNodeId());
+      }
     }
   }
 
@@ -513,22 +523,7 @@
 
   @Override
   public void onContainerStarted(ContainerId containerId, Map<String, ByteBuffer> allServiceResponse) {
-    String processorId = getPendingProcessorId(containerId);
-    if (processorId != null) {
-      log.info("Got start notification for Container ID: {} for Processor ID: {}", containerId, processorId);
-      // 1. Move the processor from pending to running state
-      final YarnContainer container = state.pendingProcessors.remove(processorId);
-
-      state.runningProcessors.put(processorId, container);
-
-      // 2. Invoke the success callback.
-      SamzaResource resource = new SamzaResource(container.resource().getVirtualCores(),
-          container.resource().getMemory(), container.nodeId().getHost(), containerId.toString());
-      clusterManagerCallback.onStreamProcessorLaunchSuccess(resource);
-    } else {
-      log.warn("Did not find the Processor ID for the start notification for Container ID: {}. " +
-          "Ignoring notification.", containerId);
-    }
+    handleOnContainerStarted(containerId);
   }
 
   @Override
@@ -726,4 +721,29 @@
     }
     return null;
   }
+
+  /**
+   * Handles container started call back for a yarn container.
+   * updates the YarnAppState's pendingProcessors and runningProcessors
+   * and also invokes clusterManagerCallback.s stream processor launch success
+   * @param containerId yarn container id which has started
+   */
+  private void handleOnContainerStarted(ContainerId containerId) {
+    String processorId = getPendingProcessorId(containerId);
+    if (processorId != null) {
+      log.info("Got start notification for Container ID: {} for Processor ID: {}", containerId, processorId);
+      // 1. Move the processor from pending to running state
+      final YarnContainer container = state.pendingProcessors.remove(processorId);
+
+      state.runningProcessors.put(processorId, container);
+
+      // 2. Invoke the success callback.
+      SamzaResource resource = new SamzaResource(container.resource().getVirtualCores(),
+          container.resource().getMemory(), container.nodeId().getHost(), containerId.toString());
+      clusterManagerCallback.onStreamProcessorLaunchSuccess(resource);
+    } else {
+      log.warn("Did not find the Processor ID for the start notification for Container ID: {}. " +
+          "Ignoring notification.", containerId);
+    }
+  }
 }
diff --git a/samza-yarn/src/main/scala/org/apache/samza/job/yarn/ClientHelper.scala b/samza-yarn/src/main/scala/org/apache/samza/job/yarn/ClientHelper.scala
index 4c3c93e..8e0c3d1 100644
--- a/samza-yarn/src/main/scala/org/apache/samza/job/yarn/ClientHelper.scala
+++ b/samza-yarn/src/main/scala/org/apache/samza/job/yarn/ClientHelper.scala
@@ -175,6 +175,10 @@
     appCtx.setApplicationId(appId.get)
     info("set app ID to %s" format appId.get)
 
+    if (new JobConfig(config).getApplicationMasterHighAvailabilityEnabled) {
+      appCtx.setKeepContainersAcrossApplicationAttempts(true)
+      info("keep containers alive across application attempts for AM High availability")
+    }
     val localResources: HashMap[String, LocalResource] = HashMap[String, LocalResource]()
     localResources += "__package" -> packageResource
 
diff --git a/samza-yarn/src/main/scala/org/apache/samza/job/yarn/SamzaYarnAppMasterLifecycle.scala b/samza-yarn/src/main/scala/org/apache/samza/job/yarn/SamzaYarnAppMasterLifecycle.scala
index 5c0dfac..27e0b1f 100644
--- a/samza-yarn/src/main/scala/org/apache/samza/job/yarn/SamzaYarnAppMasterLifecycle.scala
+++ b/samza-yarn/src/main/scala/org/apache/samza/job/yarn/SamzaYarnAppMasterLifecycle.scala
@@ -20,27 +20,32 @@
 package org.apache.samza.job.yarn
 
 import java.io.IOException
+import java.util
+import java.util.HashMap
 
-import org.apache.hadoop.yarn.api.records.FinalApplicationStatus
+import org.apache.hadoop.yarn.api.records.{Container, ContainerId, FinalApplicationStatus}
 import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
 import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync
 import org.apache.hadoop.yarn.exceptions.{InvalidApplicationMasterRequestException, YarnException}
 import org.apache.samza.SamzaException
-import org.apache.samza.clustermanager.SamzaApplicationState
+import org.apache.samza.clustermanager.{SamzaApplicationState, SamzaResource}
 import SamzaApplicationState.SamzaAppStatus
 import org.apache.samza.util.Logging
 
+import scala.collection.JavaConverters._
+
 /**
  * Responsible for managing the lifecycle of the Yarn application master. Mostly,
  * this means registering and unregistering with the RM, and shutting down
  * when the RM tells us to Reboot.
  */
 //This class is used in the refactored code path as called by run-jc.sh
-class SamzaYarnAppMasterLifecycle(containerMem: Int, containerCpu: Int, samzaAppState: SamzaApplicationState, state: YarnAppState, amClient: AMRMClientAsync[ContainerRequest]) extends Logging {
+class SamzaYarnAppMasterLifecycle(containerMem: Int, containerCpu: Int, samzaAppState: SamzaApplicationState, state: YarnAppState, amClient: AMRMClientAsync[ContainerRequest],
+  isApplicationMasterHighAvailabilityEnabled: Boolean) extends Logging {
   var validResourceRequest = true
   var shutdownMessage: String = null
   var webApp: SamzaYarnAppMasterService = null
-  def onInit() {
+  def onInit(): util.Set[ContainerId] = {
     val host = state.nodeHost
     val response = amClient.registerApplicationMaster(host, state.rpcUrl.getPort, "%s:%d" format (host, state.trackingUrl.getPort))
 
@@ -48,6 +53,19 @@
     val maxCapability = response.getMaximumResourceCapability
     val maxMem = maxCapability.getMemory
     val maxCpu = maxCapability.getVirtualCores
+    val previousAttemptContainers = new util.HashSet[ContainerId]()
+    if (isApplicationMasterHighAvailabilityEnabled) {
+      val yarnIdToprocIdMap = new HashMap[String, String]()
+      samzaAppState.processorToExecutionId.asScala foreach { entry => yarnIdToprocIdMap.put(entry._2, entry._1) }
+      response.getContainersFromPreviousAttempts.asScala foreach { (ctr: Container) =>
+        val samzaProcId = yarnIdToprocIdMap.get(ctr.getId.toString)
+        info("Received container from previous attempt with samza processor id %s and yarn container id %s" format(samzaProcId, ctr.getId.toString))
+        samzaAppState.pendingProcessors.put(samzaProcId,
+          new SamzaResource(ctr.getResource.getVirtualCores, ctr.getResource.getMemory, ctr.getNodeId.getHost, ctr.getId.toString))
+        state.pendingProcessors.put(samzaProcId, new YarnContainer(ctr))
+        previousAttemptContainers.add(ctr.getId)
+      }
+    }
     info("Got AM register response. The YARN RM supports container requests with max-mem: %s, max-cpu: %s" format (maxMem, maxCpu))
 
     if (containerMem > maxMem || containerCpu > maxCpu) {
@@ -57,6 +75,7 @@
       samzaAppState.status = SamzaAppStatus.FAILED;
       samzaAppState.jobHealthy.set(false)
     }
+    previousAttemptContainers
   }
 
   def onReboot() {
diff --git a/samza-yarn/src/main/scala/org/apache/samza/webapp/ApplicationMasterWebServlet.scala b/samza-yarn/src/main/scala/org/apache/samza/webapp/ApplicationMasterWebServlet.scala
index d787f9e..2b62b96 100644
--- a/samza-yarn/src/main/scala/org/apache/samza/webapp/ApplicationMasterWebServlet.scala
+++ b/samza-yarn/src/main/scala/org/apache/samza/webapp/ApplicationMasterWebServlet.scala
@@ -36,6 +36,7 @@
     contentType = "text/html"
   }
 
+  // Due to AMHA, the uptime and start time of containers (within state) from previous attempt is reset to the time the new AM becomes alive.
   get("/") {
     layoutTemplate("/WEB-INF/views/index.scaml",
       "config" -> TreeMap(samzaConfig.sanitize.asScala.toMap.toArray: _*),
diff --git a/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestYarnClusterResourceManager.java b/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestYarnClusterResourceManager.java
index 1ed1d09..08b18e8 100644
--- a/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestYarnClusterResourceManager.java
+++ b/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestYarnClusterResourceManager.java
@@ -21,6 +21,10 @@
 
 import java.io.IOException;
 import java.time.Duration;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
 import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.Container;
@@ -39,12 +43,15 @@
 import org.apache.samza.clustermanager.SamzaApplicationState;
 import org.apache.samza.clustermanager.SamzaResource;
 import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
 import org.junit.Assert;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mockito;
 
-import static org.junit.Assert.assertEquals;
-import static org.mockito.Matchers.anyObject;
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.*;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.*;
 
@@ -116,7 +123,7 @@
     Config config = mock(Config.class);
     AMRMClientAsync asyncClient = mock(AMRMClientAsync.class);
     YarnAppState yarnAppState = new YarnAppState(0, mock(ContainerId.class), "host", 8080, 8081);
-    SamzaYarnAppMasterLifecycle lifecycle = Mockito.spy(new SamzaYarnAppMasterLifecycle(512, 2, mock(SamzaApplicationState.class), yarnAppState, asyncClient));
+    SamzaYarnAppMasterLifecycle lifecycle = Mockito.spy(new SamzaYarnAppMasterLifecycle(512, 2, mock(SamzaApplicationState.class), yarnAppState, asyncClient, false));
     SamzaYarnAppMasterService service = mock(SamzaYarnAppMasterService.class);
     NMClientAsync asyncNMClient = mock(NMClientAsync.class);
     ClusterResourceManager.Callback callback = mock(ClusterResourceManager.Callback.class);
@@ -144,7 +151,7 @@
     Config config = mock(Config.class);
     AMRMClientAsync asyncClient = mock(AMRMClientAsync.class);
     YarnAppState yarnAppState = new YarnAppState(0, mock(ContainerId.class), "host", 8080, 8081);
-    SamzaYarnAppMasterLifecycle lifecycle = Mockito.spy(new SamzaYarnAppMasterLifecycle(512, 2, mock(SamzaApplicationState.class), yarnAppState, asyncClient));
+    SamzaYarnAppMasterLifecycle lifecycle = Mockito.spy(new SamzaYarnAppMasterLifecycle(512, 2, mock(SamzaApplicationState.class), yarnAppState, asyncClient, false));
     SamzaYarnAppMasterService service = mock(SamzaYarnAppMasterService.class);
     NMClientAsync asyncNMClient = mock(NMClientAsync.class);
     ClusterResourceManager.Callback callback = mock(ClusterResourceManager.Callback.class);
@@ -165,4 +172,54 @@
     verify(service, times(1)).onShutdown();
     verify(metrics, times(1)).stop();
   }
+
+  @Test
+  public void testAMHACallbackInvokedForPreviousAttemptContainers() throws IOException, YarnException {
+    String previousAttemptContainerId = "0";
+    String previousAttemptYarnContainerId = "container_1607304997422_0008_02_000002";
+    // create mocks
+    YarnConfiguration yarnConfiguration = mock(YarnConfiguration.class);
+    SamzaAppMasterMetrics metrics = mock(SamzaAppMasterMetrics.class);
+    AMRMClientAsync asyncClient = mock(AMRMClientAsync.class);
+    YarnAppState yarnAppState = Mockito.spy(new YarnAppState(0, mock(ContainerId.class), "host", 8080, 8081));
+    SamzaYarnAppMasterLifecycle lifecycle = mock(SamzaYarnAppMasterLifecycle.class);
+    SamzaYarnAppMasterService service = mock(SamzaYarnAppMasterService.class);
+    NMClientAsync asyncNMClient = mock(NMClientAsync.class);
+    ClusterResourceManager.Callback callback = mock(ClusterResourceManager.Callback.class);
+
+    ContainerId containerId = mock(ContainerId.class);
+    when(containerId.toString()).thenReturn(previousAttemptYarnContainerId);
+
+    YarnContainer yarnContainer = mock(YarnContainer.class);
+    Resource resource = mock(Resource.class);
+    when(resource.getMemory()).thenReturn(1024);
+    Mockito.when(resource.getVirtualCores()).thenReturn(1);
+    Mockito.when(yarnContainer.resource()).thenReturn(resource);
+    Mockito.when(yarnContainer.id()).thenReturn(containerId);
+    NodeId nodeId = mock(NodeId.class);
+    when(nodeId.getHost()).thenReturn("host");
+    when(yarnContainer.nodeId()).thenReturn(nodeId);
+
+    yarnAppState.pendingProcessors.put(previousAttemptContainerId, yarnContainer);
+
+    Set<ContainerId> previousAttemptContainers = new HashSet<>();
+    previousAttemptContainers.add(containerId);
+    when(lifecycle.onInit()).thenReturn(previousAttemptContainers);
+
+    Map<String, String> configMap = new HashMap<>();
+    configMap.put(JobConfig.YARN_AM_HIGH_AVAILABILITY_ENABLED, "true");
+    Config config = new MapConfig(configMap);
+
+    // start the cluster manager
+    YarnClusterResourceManager yarnClusterResourceManager =
+        new YarnClusterResourceManager(asyncClient, asyncNMClient, callback, yarnAppState, lifecycle, service, metrics,
+            yarnConfiguration, config);
+
+    yarnClusterResourceManager.start();
+    verify(lifecycle).onInit();
+    ArgumentCaptor<SamzaResource> samzaResourceArgumentCaptor = ArgumentCaptor.forClass(SamzaResource.class);
+    verify(callback).onStreamProcessorLaunchSuccess(samzaResourceArgumentCaptor.capture());
+    SamzaResource samzaResource = samzaResourceArgumentCaptor.getValue();
+    assertEquals(previousAttemptYarnContainerId, samzaResource.getContainerId());
+  }
 }
\ No newline at end of file
diff --git a/samza-yarn/src/test/scala/org/apache/samza/job/yarn/TestSamzaYarnAppMasterLifecycle.scala b/samza-yarn/src/test/scala/org/apache/samza/job/yarn/TestSamzaYarnAppMasterLifecycle.scala
index 2664e41..5f78f78 100644
--- a/samza-yarn/src/test/scala/org/apache/samza/job/yarn/TestSamzaYarnAppMasterLifecycle.scala
+++ b/samza-yarn/src/test/scala/org/apache/samza/job/yarn/TestSamzaYarnAppMasterLifecycle.scala
@@ -36,9 +36,13 @@
 import org.apache.samza.coordinator.JobModelManager
 import org.junit.Assert._
 import org.junit.Test
-import org.mockito.Mockito
+import org.mockito.{ArgumentCaptor, Mockito}
 
 class TestSamzaYarnAppMasterLifecycle {
+  private def YARN_CONTAINER_ID = "container_123_123_123"
+  private def YARN_CONTAINER_HOST = "host"
+  private def YARN_CONTAINER_MEM = 1024
+  private def YARN_CONTAINER_VCORE = 1
   val coordinator = new JobModelManager(null, null)
   val amClient = new AMRMClientAsyncImpl[ContainerRequest](1, Mockito.mock(classOf[CallbackHandler])) {
     var host = ""
@@ -60,7 +64,10 @@
         }
         override def getClientToAMTokenMasterKey = null
         override def setClientToAMTokenMasterKey(buffer: ByteBuffer) {}
-        override def getContainersFromPreviousAttempts(): java.util.List[Container] = java.util.Collections.emptyList[Container]
+        // to test AM high availability - return a running container from previous attempt
+        val prevAttemptCotainers = new java.util.ArrayList[Container]()
+        prevAttemptCotainers.add(getMockContainer)
+        override def getContainersFromPreviousAttempts(): java.util.List[Container] = prevAttemptCotainers
         override def getNMTokensFromPreviousAttempts(): java.util.List[NMToken] = java.util.Collections.emptyList[NMToken]
         override def getQueue(): String = null
         override def setContainersFromPreviousAttempts(containers: java.util.List[Container]): Unit = Unit
@@ -92,7 +99,7 @@
     yarnState.rpcUrl = new URL("http://localhost:1")
     yarnState.trackingUrl = new URL("http://localhost:2")
 
-    val saml = new SamzaYarnAppMasterLifecycle(512, 2, state, yarnState, amClient)
+    val saml = new SamzaYarnAppMasterLifecycle(512, 2, state, yarnState, amClient, false)
     saml.onInit
     assertEquals("testHost", amClient.host)
     assertEquals(1, amClient.port)
@@ -104,7 +111,7 @@
     val state = new SamzaApplicationState(coordinator)
 
     val yarnState =  new YarnAppState(1, ConverterUtils.toContainerId("container_1350670447861_0003_01_000001"), "testHost", 1, 2);
-    new SamzaYarnAppMasterLifecycle(512, 2, state, yarnState, amClient).onShutdown (SamzaAppStatus.SUCCEEDED)
+    new SamzaYarnAppMasterLifecycle(512, 2, state, yarnState, amClient, false).onShutdown (SamzaAppStatus.SUCCEEDED)
     assertEquals(FinalApplicationStatus.SUCCEEDED, amClient.status)
   }
 
@@ -115,7 +122,7 @@
       val state = new SamzaApplicationState(coordinator)
 
       val yarnState =  new YarnAppState(1, ConverterUtils.toContainerId("container_1350670447861_0003_01_000001"), "testHost", 1, 2);
-      new SamzaYarnAppMasterLifecycle(512, 2, state, yarnState, amClient).onReboot()
+      new SamzaYarnAppMasterLifecycle(512, 2, state, yarnState, amClient, false).onReboot()
     } catch {
       // expected
       case e: SamzaException => gotException = true
@@ -132,11 +139,62 @@
     yarnState.trackingUrl = new URL("http://localhost:2")
 
     //Request a higher amount of memory from yarn.
-    List(new SamzaYarnAppMasterLifecycle(768, 1, state, yarnState, amClient),
+    List(new SamzaYarnAppMasterLifecycle(768, 1, state, yarnState, amClient, false),
     //Request a higher number of cores from yarn.
-      new SamzaYarnAppMasterLifecycle(368, 3, state, yarnState, amClient)).map(saml => {
+      new SamzaYarnAppMasterLifecycle(368, 3, state, yarnState, amClient, false)).map(saml => {
         saml.onInit
         assertTrue(saml.shouldShutdown)
       })
   }
+
+  @Test
+  def testAMHighAvailabilityOnInit {
+    val PROCESSOR_ID = "0"
+    val samzaApplicationState = new SamzaApplicationState(coordinator)
+
+    samzaApplicationState.processorToExecutionId.put(PROCESSOR_ID, YARN_CONTAINER_ID);
+
+    val yarnState = new YarnAppState(1, ConverterUtils.toContainerId("container_1350670447861_0003_01_000001"), "testHost", 1, 2);
+    yarnState.rpcUrl = new URL("http://localhost:1")
+    yarnState.trackingUrl = new URL("http://localhost:2")
+
+    val saml = new SamzaYarnAppMasterLifecycle(512, 2, samzaApplicationState, yarnState, amClient, true)
+    saml.onInit
+
+    // verify that the samzaApplicationState is updated to reflect a running container from previous attempt
+    assertEquals(1, samzaApplicationState.pendingProcessors.size())
+    assertTrue(samzaApplicationState.pendingProcessors.containsKey(PROCESSOR_ID))
+    val resource = samzaApplicationState.pendingProcessors.get(PROCESSOR_ID)
+    assertEquals(YARN_CONTAINER_ID, resource.getContainerId)
+    assertEquals(YARN_CONTAINER_HOST, resource.getHost)
+    assertEquals(YARN_CONTAINER_MEM, resource.getMemoryMb)
+    assertEquals(YARN_CONTAINER_VCORE, resource.getNumCores)
+
+    assertEquals(1, yarnState.pendingProcessors.size())
+    assertTrue(yarnState.pendingProcessors.containsKey(PROCESSOR_ID))
+    val yarnCtr = yarnState.pendingProcessors.get(PROCESSOR_ID)
+    assertEquals(YARN_CONTAINER_ID, yarnCtr.id.toString)
+    assertEquals(YARN_CONTAINER_HOST, yarnCtr.nodeId.getHost)
+    assertEquals(YARN_CONTAINER_MEM, yarnCtr.resource.getMemory)
+    assertEquals(YARN_CONTAINER_VCORE, yarnCtr.resource.getVirtualCores)
+  }
+
+  def getMockContainer: Container = {
+    val container = Mockito.mock(classOf[Container])
+
+    val containerId = Mockito.mock(classOf[ContainerId])
+    Mockito.when(containerId.toString).thenReturn(YARN_CONTAINER_ID)
+    Mockito.when(container.getId).thenReturn(containerId)
+
+    val resource = Mockito.mock(classOf[Resource])
+    Mockito.when(resource.getMemory).thenReturn(YARN_CONTAINER_MEM)
+    Mockito.when(resource.getVirtualCores).thenReturn(YARN_CONTAINER_VCORE)
+    Mockito.when(container.getResource).thenReturn(resource)
+
+    val nodeId = Mockito.mock(classOf[NodeId])
+    Mockito.when(nodeId.getHost).thenReturn(YARN_CONTAINER_HOST)
+    Mockito.when(container.getNodeId).thenReturn(nodeId)
+
+    container
+  }
 }