SAMZA-2610: Handle Metadata changes for AM HA orchestration (#1450)

Description:
AM performs planning and job model generation for every incarnation. With AM-HA, the new job model or configuration may invalidate the containers from the previous attempt. In order to ensure correctness, we handle this by detecting these changes and restart all the containers in case of any changes to metadata (job model or configuration).

Changes:

Detect changes in metadata by reading older metadata from coordinator stream and signal the CPM
As part of resource request & orchestration, ignore the containers that are already running from the previous attempt and proceed to release them if metadata changed.
Releasing the container will signal RM through AMRM client and RM will orchestrate killing the processing container. It is different from the normal StopStreamProcessor flow as the NMClient isn't the source of truth and doesn't have context about the containers spun in the previous attempts
diff --git a/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java b/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
index b98c727..08bcfda 100644
--- a/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
+++ b/samza-core/src/main/java/org/apache/samza/clustermanager/ClusterBasedJobCoordinator.java
@@ -38,6 +38,7 @@
 import org.apache.samza.container.LocalityManager;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.coordinator.InputStreamsDiscoveredException;
+import org.apache.samza.coordinator.JobCoordinatorMetadataManager;
 import org.apache.samza.coordinator.JobModelManager;
 import org.apache.samza.coordinator.MetadataResourceUtil;
 import org.apache.samza.coordinator.PartitionChangeException;
@@ -47,6 +48,8 @@
 import org.apache.samza.coordinator.stream.messages.SetChangelogMapping;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
 import org.apache.samza.coordinator.stream.messages.SetExecutionEnvContainerIdMapping;
+import org.apache.samza.coordinator.stream.messages.SetJobCoordinatorMetadataMessage;
+import org.apache.samza.job.JobCoordinatorMetadata;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.job.model.JobModelUtil;
@@ -64,7 +67,6 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-
 /**
  * Implements a JobCoordinator that is completely independent of the underlying cluster
  * manager system. This {@link ClusterBasedJobCoordinator} handles functionality common
@@ -170,6 +172,11 @@
    */
   private JmxServer jmxServer;
 
+  /*
+   * Denotes if the metadata changed across application attempts. Used only if job coordinator high availability is enabled
+   */
+  private boolean metadataChangedAcrossAttempts = false;
+
   /**
    * Variable to keep the callback exception
    */
@@ -208,11 +215,11 @@
     this.localityManager =
         new LocalityManager(new NamespaceAwareCoordinatorStreamStore(metadataStore, SetContainerHostMapping.TYPE));
 
-    if (new JobConfig(config).getApplicationMasterHighAvailabilityEnabled()) {
+    if (isApplicationMasterHighAvailabilityEnabled()) {
       ExecutionContainerIdManager executionContainerIdManager = new ExecutionContainerIdManager(
           new NamespaceAwareCoordinatorStreamStore(metadataStore, SetExecutionEnvContainerIdMapping.TYPE));
-
       state.processorToExecutionId.putAll(executionContainerIdManager.readExecutionEnvironmentContainerIdMapping());
+      generateAndUpdateJobCoordinatorMetadata(jobModelManager.jobModel());
     }
     // build metastore for container placement messages
     containerPlacementMetadataStore = new ContainerPlacementMetadataStore(metadataStore);
@@ -260,8 +267,12 @@
       MetadataResourceUtil metadataResourceUtil = new MetadataResourceUtil(jobModel, this.metrics, config);
       metadataResourceUtil.createResources();
 
-      // fan out the startpoints if startpoints is enabled
-      if (new JobConfig(config).getStartpointEnabled()) {
+      /*
+       * We fanout startpoint if and only if
+       *  1. Startpoint is enabled in configuration
+       *  2. If AM HA is enabled, fanout only if startpoint enabled and job coordinator metadata changed
+       */
+      if (shouldFanoutStartpoint()) {
         StartpointManager startpointManager = createStartpointManager();
         startpointManager.start();
         try {
@@ -332,6 +343,24 @@
   }
 
   /**
+   * Generate the job coordinator metadata for current application attempt and checks for changes in the
+   * metadata from the previous attempt and writes the updates metadata to coordinator stream.
+   *
+   * @param jobModel job model used to generate the job coordinator metadata
+   */
+  @VisibleForTesting
+  void generateAndUpdateJobCoordinatorMetadata(JobModel jobModel) {
+    JobCoordinatorMetadataManager jobCoordinatorMetadataManager = createJobCoordinatorMetadataManager();
+
+    JobCoordinatorMetadata previousMetadata = jobCoordinatorMetadataManager.readJobCoordinatorMetadata();
+    JobCoordinatorMetadata newMetadata = jobCoordinatorMetadataManager.generateJobCoordinatorMetadata(jobModel, config);
+    if (jobCoordinatorMetadataManager.checkForMetadataChanges(newMetadata, previousMetadata)) {
+      jobCoordinatorMetadataManager.writeJobCoordinatorMetadata(newMetadata);
+      metadataChangedAcrossAttempts = true;
+    }
+  }
+
+  /**
    * Stops all components of the JobCoordinator.
    */
   private void onShutDown() {
@@ -456,6 +485,39 @@
 
   @VisibleForTesting
   ContainerProcessManager createContainerProcessManager() {
-    return new ContainerProcessManager(config, state, metrics, containerPlacementMetadataStore, localityManager);
+    return new ContainerProcessManager(config, state, metrics, containerPlacementMetadataStore, localityManager,
+        metadataChangedAcrossAttempts);
+  }
+
+  @VisibleForTesting
+  JobCoordinatorMetadataManager createJobCoordinatorMetadataManager() {
+    return new JobCoordinatorMetadataManager(new NamespaceAwareCoordinatorStreamStore(metadataStore,
+        SetJobCoordinatorMetadataMessage.TYPE), JobCoordinatorMetadataManager.ClusterType.YARN, metrics);
+  }
+
+  @VisibleForTesting
+  boolean isApplicationMasterHighAvailabilityEnabled() {
+    return new JobConfig(config).getApplicationMasterHighAvailabilityEnabled();
+  }
+
+  @VisibleForTesting
+  boolean isMetadataChangedAcrossAttempts() {
+    return metadataChangedAcrossAttempts;
+  }
+
+  /**
+   * We only fanout startpoint if and only if
+   *  1. Startpoint is enabled
+   *  2. If AM HA is enabled, fanout only if startpoint enabled and job coordinator metadata changed
+   *
+   * @return true if it satisfies above conditions, false otherwise
+   */
+  @VisibleForTesting
+  boolean shouldFanoutStartpoint() {
+    JobConfig jobConfig = new JobConfig(config);
+    boolean startpointEnabled = jobConfig.getStartpointEnabled();
+
+    return isApplicationMasterHighAvailabilityEnabled() ?
+        startpointEnabled && isMetadataChangedAcrossAttempts() : startpointEnabled;
   }
 }
diff --git a/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java b/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java
index ec52d4b..995cf7d 100644
--- a/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java
+++ b/samza-core/src/main/java/org/apache/samza/clustermanager/ContainerProcessManager.java
@@ -131,12 +131,14 @@
    */
   private final Map<String, ProcessorFailure> processorFailures = new HashMap<>();
 
+  private final boolean restartContainers;
+
   private ContainerProcessManagerMetrics containerProcessManagerMetrics;
   private JvmMetrics jvmMetrics;
   private Map<String, MetricsReporter> metricsReporters;
 
   public ContainerProcessManager(Config config, SamzaApplicationState state, MetricsRegistryMap registry,
-      ContainerPlacementMetadataStore metadataStore, LocalityManager localityManager) {
+      ContainerPlacementMetadataStore metadataStore, LocalityManager localityManager, boolean restartContainers) {
     Preconditions.checkNotNull(localityManager, "Locality manager cannot be null");
     this.state = state;
     this.clusterManagerConfig = new ClusterManagerConfig(config);
@@ -175,6 +177,7 @@
 
     this.containerAllocator = new ContainerAllocator(this.clusterResourceManager, config, state, hostAffinityEnabled, this.containerManager);
     this.allocatorThread = new Thread(this.containerAllocator, "Container Allocator Thread");
+    this.restartContainers = restartContainers;
     LOG.info("Finished container process manager initialization.");
   }
 
@@ -185,7 +188,8 @@
       ClusterResourceManager resourceManager,
       Optional<ContainerAllocator> allocator,
       ContainerManager containerManager,
-      LocalityManager localityManager) {
+      LocalityManager localityManager,
+      boolean restartContainers) {
     this.state = state;
     this.clusterManagerConfig = clusterManagerConfig;
     this.jobConfig = new JobConfig(clusterManagerConfig);
@@ -200,6 +204,7 @@
       () -> new ContainerAllocator(this.clusterResourceManager, clusterManagerConfig, state,
           hostAffinityEnabled, this.containerManager));
     this.allocatorThread = new Thread(this.containerAllocator, "Container Allocator Thread");
+    this.restartContainers = restartContainers;
     LOG.info("Finished container process manager initialization");
   }
 
@@ -248,16 +253,23 @@
     // Request initial set of containers
     LocalityModel localityModel = localityManager.readLocality();
     Map<String, String> processorToHost = new HashMap<>();
-    state.jobModelManager.jobModel().getContainers().keySet().forEach((containerId) -> {
-      String host = Optional.ofNullable(localityModel.getProcessorLocality(containerId))
+    state.jobModelManager.jobModel().getContainers().keySet().forEach((processorId) -> {
+      String host = Optional.ofNullable(localityModel.getProcessorLocality(processorId))
           .map(ProcessorLocality::host)
           .filter(StringUtils::isNotBlank)
           .orElse(null);
-      processorToHost.put(containerId, host);
+      processorToHost.put(processorId, host);
     });
     if (jobConfig.getApplicationMasterHighAvailabilityEnabled()) {
       // don't request resource for container that is already running
-      state.runningProcessors.keySet().forEach(processorToHost::remove);
+      state.runningProcessors.forEach((processorId, samzaResource) -> {
+        LOG.info("Not requesting container for processorId: {} since its already running as containerId: {}",
+            processorId, samzaResource.getContainerId());
+        processorToHost.remove(processorId);
+        if (restartContainers) {
+          clusterResourceManager.stopStreamProcessor(samzaResource);
+        }
+      });
     }
     containerAllocator.requestResources(processorToHost);
 
diff --git a/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorMetadataManager.java b/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorMetadataManager.java
index c5d72f5..c4540a6 100644
--- a/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorMetadataManager.java
+++ b/samza-core/src/main/java/org/apache/samza/coordinator/JobCoordinatorMetadataManager.java
@@ -54,12 +54,18 @@
   private static final String JOB_COORDINATOR_MANAGER_METRICS = "job-coordinator-manager";
   private static final String JOB_MODEL_CHANGED = "jobModelChanged";
   private static final String CONFIG_CHANGED = "configChanged";
+  private static final String METADATA_GENERATION_FAILED_COUNT = "metadataGenerationFailedCount";
+  private static final String METADATA_READ_FAILED_COUNT = "metadataReadFailedCount";
+  private static final String METADATA_WRITE_FAILED_COUNT = "metadataWriteFailedCount";
   private static final String NEW_DEPLOYMENT = "newDeployment";
 
   static final String CONTAINER_ID_PROPERTY = "CONTAINER_ID";
   static final String CONTAINER_ID_DELIMITER = "_";
 
   private final Counter applicationAttemptCount;
+  private final Counter metadataGenerationFailedCount;
+  private final Counter metadataReadFailedCount;
+  private final Counter metadataWriteFailedCount;
   private final Gauge<Integer> jobModelChangedAcrossApplicationAttempt;
   private final Gauge<Integer> configChangedAcrossApplicationAttempt;
   private final Gauge<Integer> newDeployment;
@@ -68,17 +74,30 @@
   private final Serde<String> valueSerde;
   private final ClusterType clusterType;
 
-  public JobCoordinatorMetadataManager(MetadataStore metadataStore, ClusterType clusterType, MetricsRegistry metricsRegistry) {
+  public JobCoordinatorMetadataManager(MetadataStore metadataStore, ClusterType clusterType,
+      MetricsRegistry metricsRegistry) {
+    this(metadataStore, clusterType, metricsRegistry,
+        new CoordinatorStreamValueSerde(SetJobCoordinatorMetadataMessage.TYPE));
+  }
+
+  @VisibleForTesting
+  JobCoordinatorMetadataManager(MetadataStore metadataStore, ClusterType clusterType, MetricsRegistry metricsRegistry,
+      Serde<String> valueSerde) {
     Preconditions.checkNotNull(clusterType, "Cluster type cannot be null");
+
     this.clusterType = clusterType;
     this.metadataStore = metadataStore;
-    this.valueSerde = new CoordinatorStreamValueSerde(SetJobCoordinatorMetadataMessage.TYPE);
+    this.valueSerde = valueSerde;
 
     applicationAttemptCount = metricsRegistry.newCounter(JOB_COORDINATOR_MANAGER_METRICS, APPLICATION_ATTEMPT_COUNT);
     configChangedAcrossApplicationAttempt =
         metricsRegistry.newGauge(JOB_COORDINATOR_MANAGER_METRICS, CONFIG_CHANGED, 0);
     jobModelChangedAcrossApplicationAttempt =
         metricsRegistry.newGauge(JOB_COORDINATOR_MANAGER_METRICS, JOB_MODEL_CHANGED, 0);
+    metadataGenerationFailedCount = metricsRegistry.newCounter(JOB_COORDINATOR_MANAGER_METRICS,
+        METADATA_GENERATION_FAILED_COUNT);
+    metadataReadFailedCount = metricsRegistry.newCounter(JOB_COORDINATOR_MANAGER_METRICS, METADATA_READ_FAILED_COUNT);
+    metadataWriteFailedCount = metricsRegistry.newCounter(JOB_COORDINATOR_MANAGER_METRICS, METADATA_WRITE_FAILED_COUNT);
     newDeployment = metricsRegistry.newGauge(JOB_COORDINATOR_MANAGER_METRICS, NEW_DEPLOYMENT, 0);
   }
 
@@ -121,8 +140,9 @@
       return new JobCoordinatorMetadata(fetchEpochIdForJobCoordinator(), String.valueOf(configId),
           String.valueOf(jobModelId));
     } catch (Exception e) {
+      metadataGenerationFailedCount.inc();
       LOG.error("Failed to generate metadata for the current attempt due to ", e);
-      throw new RuntimeException("Failed to generate the metadata for the current attempt due to ", e);
+      throw new SamzaException("Failed to generate the metadata for the current attempt due to ", e);
     }
   }
 
@@ -179,6 +199,7 @@
           metadata = metadataMapper.readValue(metadataString, JobCoordinatorMetadata.class);
           break;
         } catch (Exception e) {
+          metadataReadFailedCount.inc();
           LOG.error("Failed to read job coordinator metadata due to ", e);
         }
       }
@@ -204,36 +225,12 @@
       metadataStore.put(clusterType.name(), valueSerde.toBytes(metadataValueString));
       LOG.info("Successfully written job coordinator metadata: {} for cluster {}.", metadata, clusterType);
     } catch (Exception e) {
+      metadataWriteFailedCount.inc();
       LOG.error("Failed to write the job coordinator metadata to metadata store due to ", e);
       throw new SamzaException("Failed to write the job coordinator metadata.", e);
     }
   }
 
-  @VisibleForTesting
-  Counter getApplicationAttemptCount() {
-    return applicationAttemptCount;
-  }
-
-  @VisibleForTesting
-  Gauge<Integer> getJobModelChangedAcrossApplicationAttempt() {
-    return jobModelChangedAcrossApplicationAttempt;
-  }
-
-  @VisibleForTesting
-  Gauge<Integer> getConfigChangedAcrossApplicationAttempt() {
-    return configChangedAcrossApplicationAttempt;
-  }
-
-  @VisibleForTesting
-  Gauge<Integer> getNewDeployment() {
-    return newDeployment;
-  }
-
-  @VisibleForTesting
-  String getEnvProperty(String propertyName) {
-    return System.getenv(propertyName);
-  }
-
   /**
    * Generate the epoch id using the execution container id that is passed through system environment. This isn't ideal
    * way of generating this ID and we will need some contract between the underlying cluster manager and samza engine
@@ -254,11 +251,52 @@
    *
    * @return an identifier associated with the job coordinator satisfying the above properties
    */
-  private String fetchEpochIdForJobCoordinator() {
+  @VisibleForTesting
+  String fetchEpochIdForJobCoordinator() {
     String[] containerIdParts = getEnvProperty(CONTAINER_ID_PROPERTY).split(CONTAINER_ID_DELIMITER);
     return containerIdParts[1] + CONTAINER_ID_DELIMITER + containerIdParts[2];
   }
 
+  @VisibleForTesting
+  Counter getApplicationAttemptCount() {
+    return applicationAttemptCount;
+  }
+
+  @VisibleForTesting
+  Counter getMetadataGenerationFailedCount() {
+    return metadataGenerationFailedCount;
+  }
+
+  @VisibleForTesting
+  Counter getMetadataReadFailedCount() {
+    return metadataReadFailedCount;
+  }
+
+  @VisibleForTesting
+  Counter getMetadataWriteFailedCount() {
+    return metadataWriteFailedCount;
+  }
+
+  @VisibleForTesting
+  Gauge<Integer> getJobModelChangedAcrossApplicationAttempt() {
+    return jobModelChangedAcrossApplicationAttempt;
+  }
+
+  @VisibleForTesting
+  Gauge<Integer> getConfigChangedAcrossApplicationAttempt() {
+    return configChangedAcrossApplicationAttempt;
+  }
+
+  @VisibleForTesting
+  Gauge<Integer> getNewDeployment() {
+    return newDeployment;
+  }
+
+  @VisibleForTesting
+  String getEnvProperty(String propertyName) {
+    return System.getenv(propertyName);
+  }
+
   /**
    * A helper class to generate hash for the {@link Config} based on with a subset of configuration.
    * The subset of configuration used are configurations that prefix match the allowed prefixes.
diff --git a/samza-core/src/test/java/org/apache/samza/clustermanager/TestClusterBasedJobCoordinator.java b/samza-core/src/test/java/org/apache/samza/clustermanager/TestClusterBasedJobCoordinator.java
index caa1ffe..50a1ee1 100644
--- a/samza-core/src/test/java/org/apache/samza/clustermanager/TestClusterBasedJobCoordinator.java
+++ b/samza-core/src/test/java/org/apache/samza/clustermanager/TestClusterBasedJobCoordinator.java
@@ -32,11 +32,14 @@
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
+import org.apache.samza.coordinator.JobCoordinatorMetadataManager;
 import org.apache.samza.coordinator.StreamPartitionCountMonitor;
 import org.apache.samza.coordinator.metadatastore.CoordinatorStreamStore;
 import org.apache.samza.coordinator.stream.CoordinatorStreamSystemProducer;
 import org.apache.samza.coordinator.stream.MockCoordinatorStreamSystemFactory;
 import org.apache.samza.execution.RemoteJobPlanner;
+import org.apache.samza.job.JobCoordinatorMetadata;
+import org.apache.samza.job.model.JobModel;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.startpoint.StartpointManager;
 import org.apache.samza.system.MockSystemFactory;
@@ -47,7 +50,6 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
-import org.mockito.Mockito;
 import org.mockito.exceptions.base.MockitoException;
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
@@ -58,6 +60,8 @@
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.powermock.api.mockito.PowerMockito.mock;
@@ -150,7 +154,7 @@
     Config config = new MapConfig(configMap);
     MockitoException stopException = new MockitoException("Stop");
 
-    ClusterBasedJobCoordinator clusterCoordinator = Mockito.spy(ClusterBasedJobCoordinatorRunner.createFromMetadataStore(config));
+    ClusterBasedJobCoordinator clusterCoordinator = spy(ClusterBasedJobCoordinatorRunner.createFromMetadataStore(config));
     ContainerProcessManager mockContainerProcessManager = mock(ContainerProcessManager.class);
     doReturn(true).when(mockContainerProcessManager).shouldShutdown();
     StartpointManager mockStartpointManager = mock(StartpointManager.class);
@@ -174,6 +178,43 @@
   }
 
   @Test
+  public void testVerifyShouldFanoutStartpointWithoutAMHA() {
+    Config jobConfig = new MapConfig(configMap);
+
+    when(CoordinatorStreamUtil.readConfigFromCoordinatorStream(anyObject())).thenReturn(jobConfig);
+    ClusterBasedJobCoordinator clusterBasedJobCoordinator =
+        spy(ClusterBasedJobCoordinatorRunner.createFromMetadataStore(jobConfig));
+
+    when(clusterBasedJobCoordinator.isMetadataChangedAcrossAttempts()).thenReturn(true);
+    assertTrue("Startpoint should fanout even if metadata changed",
+        clusterBasedJobCoordinator.shouldFanoutStartpoint());
+
+    when(clusterBasedJobCoordinator.isMetadataChangedAcrossAttempts()).thenReturn(false);
+    assertTrue("Startpoint should fanout even if metadata remains unchanged",
+        clusterBasedJobCoordinator.shouldFanoutStartpoint());
+  }
+
+  @Test
+  public void testVerifyShouldFanoutStartpointWithAMHA() {
+    Config jobConfig = new MapConfig(configMap);
+
+    when(CoordinatorStreamUtil.readConfigFromCoordinatorStream(anyObject())).thenReturn(jobConfig);
+    ClusterBasedJobCoordinator clusterBasedJobCoordinator =
+        spy(ClusterBasedJobCoordinatorRunner.createFromMetadataStore(jobConfig));
+
+    when(clusterBasedJobCoordinator.isApplicationMasterHighAvailabilityEnabled()).thenReturn(true);
+
+    when(clusterBasedJobCoordinator.isMetadataChangedAcrossAttempts()).thenReturn(true);
+    assertTrue("Startpoint should fanout with change in metadata",
+        clusterBasedJobCoordinator.shouldFanoutStartpoint());
+
+    when(clusterBasedJobCoordinator.isMetadataChangedAcrossAttempts()).thenReturn(false);
+    assertFalse("Startpoint fan out shouldn't happen when metadata is unchanged",
+        clusterBasedJobCoordinator.shouldFanoutStartpoint());
+
+  }
+
+  @Test
   public void testToArgs() {
     ApplicationConfig appConfig = new ApplicationConfig(new MapConfig(ImmutableMap.of(
         JobConfig.JOB_NAME, "test1",
@@ -192,4 +233,39 @@
     assertEquals(expected.size(), actual.size());
     assertTrue(actual.containsAll(expected));
   }
+
+  @Test
+  public void testGenerateAndUpdateJobCoordinatorMetadata() {
+    Config jobConfig = new MapConfig(configMap);
+    when(CoordinatorStreamUtil.readConfigFromCoordinatorStream(anyObject())).thenReturn(jobConfig);
+    ClusterBasedJobCoordinator clusterBasedJobCoordinator =
+        spy(ClusterBasedJobCoordinatorRunner.createFromMetadataStore(jobConfig));
+
+    JobCoordinatorMetadata previousMetadata = mock(JobCoordinatorMetadata.class);
+    JobCoordinatorMetadata newMetadata = mock(JobCoordinatorMetadata.class);
+    JobCoordinatorMetadataManager jobCoordinatorMetadataManager = mock(JobCoordinatorMetadataManager.class);
+    JobModel mockJobModel = mock(JobModel.class);
+
+    when(jobCoordinatorMetadataManager.readJobCoordinatorMetadata()).thenReturn(previousMetadata);
+    when(jobCoordinatorMetadataManager.generateJobCoordinatorMetadata(any(), any())).thenReturn(newMetadata);
+    when(jobCoordinatorMetadataManager.checkForMetadataChanges(newMetadata, previousMetadata)).thenReturn(false);
+    when(clusterBasedJobCoordinator.createJobCoordinatorMetadataManager()).thenReturn(jobCoordinatorMetadataManager);
+
+    /*
+     * Verify if there are no changes to metadata, the metadata changed flag remains false and no interactions
+     * with job coordinator metadata manager
+     */
+    clusterBasedJobCoordinator.generateAndUpdateJobCoordinatorMetadata(mockJobModel);
+    assertFalse("JC metadata changed should remain unchanged",
+        clusterBasedJobCoordinator.isMetadataChangedAcrossAttempts());
+    verify(jobCoordinatorMetadataManager, times(0)).writeJobCoordinatorMetadata(any());
+
+    /*
+     * Verify if there are changes to metadata, we persist the new metadata & update the metadata changed flag
+     */
+    when(jobCoordinatorMetadataManager.checkForMetadataChanges(newMetadata, previousMetadata)).thenReturn(true);
+    clusterBasedJobCoordinator.generateAndUpdateJobCoordinatorMetadata(mockJobModel);
+    assertTrue("JC metadata changed should be true", clusterBasedJobCoordinator.isMetadataChangedAcrossAttempts());
+    verify(jobCoordinatorMetadataManager, times(1)).writeJobCoordinatorMetadata(newMetadata);
+  }
 }
diff --git a/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerPlacementActions.java b/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerPlacementActions.java
index 53bd5b0..c781f4d 100644
--- a/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerPlacementActions.java
+++ b/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerPlacementActions.java
@@ -150,7 +150,7 @@
     containerManager = spy(new ContainerManager(containerPlacementMetadataStore, state, clusterResourceManager, true, false, localityManager));
     allocatorWithHostAffinity = new MockContainerAllocatorWithHostAffinity(clusterResourceManager, config, state, containerManager);
     cpm = new ContainerProcessManager(clusterManagerConfig, state, new MetricsRegistryMap(),
-            clusterResourceManager, Optional.of(allocatorWithHostAffinity), containerManager, localityManager);
+            clusterResourceManager, Optional.of(allocatorWithHostAffinity), containerManager, localityManager, false);
   }
 
   @After
@@ -176,7 +176,7 @@
     containerManager = spy(new ContainerManager(containerPlacementMetadataStore, state, clusterResourceManager, true, true, mockLocalityManager));
     allocatorWithHostAffinity = new MockContainerAllocatorWithHostAffinity(clusterResourceManager, config, state, containerManager);
     cpm = new ContainerProcessManager(clusterManagerConfig, state, new MetricsRegistryMap(),
-        clusterResourceManager, Optional.of(allocatorWithHostAffinity), containerManager, mockLocalityManager);
+        clusterResourceManager, Optional.of(allocatorWithHostAffinity), containerManager, mockLocalityManager, false);
   }
 
   @Test(timeout = 10000)
@@ -556,7 +556,7 @@
     containerManager = spy(new ContainerManager(containerPlacementMetadataStore, state, clusterResourceManager, true, false, localityManager));
     allocatorWithHostAffinity = new MockContainerAllocatorWithHostAffinity(clusterResourceManager, config, state, containerManager);
     cpm = new ContainerProcessManager(clusterManagerConfig, state, new MetricsRegistryMap(),
-        clusterResourceManager, Optional.of(allocatorWithHostAffinity), containerManager, localityManager);
+        clusterResourceManager, Optional.of(allocatorWithHostAffinity), containerManager, localityManager, false);
 
     doAnswer(new Answer<Void>() {
       public Void answer(InvocationOnMock invocation) {
@@ -674,7 +674,7 @@
 
     ContainerProcessManager cpm = new ContainerProcessManager(
         new ClusterManagerConfig(new MapConfig(getConfig(), getConfigWithHostAffinityAndRetries(false, 1, true))), state,
-        new MetricsRegistryMap(), clusterResourceManager, Optional.of(allocatorWithoutHostAffinity), containerManager, localityManager);
+        new MetricsRegistryMap(), clusterResourceManager, Optional.of(allocatorWithoutHostAffinity), containerManager, localityManager, false);
 
     // Mimic Cluster Manager returning any request
     doAnswer(new Answer<Void>() {
@@ -807,7 +807,7 @@
         new MockContainerAllocatorWithHostAffinity(clusterResourceManager, config, state, containerManager);
     ContainerProcessManager cpm = new ContainerProcessManager(
         new ClusterManagerConfig(new MapConfig(getConfig(), getConfigWithHostAffinityAndRetries(true, 1, true))), state,
-        new MetricsRegistryMap(), clusterResourceManager, Optional.of(allocatorWithHostAffinity), containerManager, localityManager);
+        new MetricsRegistryMap(), clusterResourceManager, Optional.of(allocatorWithHostAffinity), containerManager, localityManager, false);
 
     doAnswer(new Answer<Void>() {
       public Void answer(InvocationOnMock invocation) {
diff --git a/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerProcessManager.java b/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerProcessManager.java
index d285f9e..bcbe53f 100644
--- a/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerProcessManager.java
+++ b/samza-core/src/test/java/org/apache/samza/clustermanager/TestContainerProcessManager.java
@@ -171,7 +171,8 @@
         clusterResourceManager,
         Optional.empty(),
         containerManager,
-        mockLocalityManager
+        mockLocalityManager,
+        false
     );
 
     allocator =
@@ -272,11 +273,63 @@
 
     // Verify only 1 was requested with allocator
     assertEquals(1, allocator.requestedContainers);
+    assertTrue("Ensure no processors were forcefully restarted", callback.resourceStatuses.isEmpty());
 
     cpm.stop();
   }
 
   @Test
+  public void testOnInitToForceRestartAMHighAvailability() throws Exception {
+    Map<String, String> configMap = new HashMap<>(configVals);
+    configMap.put(JobConfig.YARN_AM_HIGH_AVAILABILITY_ENABLED, "true");
+    Config conf = new MapConfig(configMap);
+    SamzaResource samzaResource = new SamzaResource(1, 1024, "host", "0");
+
+    SamzaApplicationState state = new SamzaApplicationState(getJobModelManager(2));
+    state.runningProcessors.put("0", samzaResource);
+
+    MockClusterResourceManagerCallback callback = new MockClusterResourceManagerCallback();
+    ClusterResourceManager clusterResourceManager = spy(new MockClusterResourceManager(callback, state));
+    ClusterManagerConfig clusterManagerConfig = spy(new ClusterManagerConfig(conf));
+    ContainerManager containerManager =
+        buildContainerManager(containerPlacementMetadataStore, state, clusterResourceManager,
+            clusterManagerConfig.getHostAffinityEnabled(), false);
+
+    ContainerProcessManager cpm =
+        buildContainerProcessManager(clusterManagerConfig, state, clusterResourceManager, Optional.empty(), true);
+
+    MockContainerAllocatorWithoutHostAffinity allocator = new MockContainerAllocatorWithoutHostAffinity(
+        clusterResourceManager,
+        conf,
+        state,
+        containerManager);
+
+    getPrivateFieldFromCpm("containerAllocator", cpm).set(cpm, allocator);
+    CountDownLatch latch = new CountDownLatch(1);
+    getPrivateFieldFromCpm("allocatorThread", cpm).set(cpm, new Thread() {
+      public void run() {
+        isRunning = true;
+        latch.countDown();
+      }
+    });
+
+    cpm.start();
+
+    if (!latch.await(2, TimeUnit.SECONDS)) {
+      Assert.fail("timed out waiting for the latch to expire");
+    }
+
+    verify(clusterResourceManager, times(1)).stopStreamProcessor(samzaResource);
+    assertEquals("CPM should stop the running container", 1, callback.resourceStatuses.size());
+
+    SamzaResourceStatus actualResourceStatus = callback.resourceStatuses.get(0);
+    assertEquals("Container 0 should be stopped", "0", actualResourceStatus.getContainerId());
+    assertEquals("Container 0 should have exited with preempted status", SamzaResourceStatus.PREEMPTED,
+        actualResourceStatus.getExitCode());
+    cpm.stop();
+  }
+
+  @Test
   public void testOnShutdown() throws Exception {
     Config conf = getConfig();
     SamzaApplicationState state = new SamzaApplicationState(getJobModelManager(1));
@@ -560,7 +613,8 @@
         containerManager);
 
     ContainerProcessManager cpm =
-        buildContainerProcessManager(clusterManagerConfig, state, clusterResourceManager, Optional.of(allocator), mockLocalityManager);
+        buildContainerProcessManager(clusterManagerConfig, state, clusterResourceManager, Optional.of(allocator),
+            mockLocalityManager, false);
 
     // start triggers a request
     cpm.start();
@@ -715,7 +769,7 @@
 
     ContainerProcessManager manager =
         new ContainerProcessManager(new ClusterManagerConfig(config), state, new MetricsRegistryMap(), clusterResourceManager,
-            Optional.of(allocator), containerManager, mockLocalityManager);
+            Optional.of(allocator), containerManager, mockLocalityManager, false);
 
     manager.start();
     SamzaResource resource = new SamzaResource(1, 1024, "host1", "resource-1");
@@ -751,7 +805,8 @@
         containerManager);
 
     ContainerProcessManager cpm =
-        spy(buildContainerProcessManager(new ClusterManagerConfig(cfg), state, clusterResourceManager, Optional.of(allocator), mockLocalityManager));
+        spy(buildContainerProcessManager(new ClusterManagerConfig(cfg), state, clusterResourceManager,
+            Optional.of(allocator), mockLocalityManager, false));
 
     cpm.start();
     assertFalse(cpm.shouldShutdown());
@@ -989,15 +1044,22 @@
   }
   private ContainerProcessManager buildContainerProcessManager(ClusterManagerConfig clusterManagerConfig, SamzaApplicationState state,
       ClusterResourceManager clusterResourceManager, Optional<ContainerAllocator> allocator) {
-    LocalityManager mockLocalityManager = mock(LocalityManager.class);
-    when(mockLocalityManager.readLocality()).thenReturn(new LocalityModel(new HashMap<>()));
-    return buildContainerProcessManager(clusterManagerConfig, state, clusterResourceManager, allocator, mockLocalityManager);
+    return buildContainerProcessManager(clusterManagerConfig, state, clusterResourceManager, allocator, false);
   }
 
   private ContainerProcessManager buildContainerProcessManager(ClusterManagerConfig clusterManagerConfig, SamzaApplicationState state,
-      ClusterResourceManager clusterResourceManager, Optional<ContainerAllocator> allocator, LocalityManager localityManager) {
+      ClusterResourceManager clusterResourceManager, Optional<ContainerAllocator> allocator, boolean restartContainer) {
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
+    when(mockLocalityManager.readLocality()).thenReturn(new LocalityModel(new HashMap<>()));
+    return buildContainerProcessManager(clusterManagerConfig, state, clusterResourceManager, allocator,
+        mockLocalityManager, restartContainer);
+  }
+
+  private ContainerProcessManager buildContainerProcessManager(ClusterManagerConfig clusterManagerConfig, SamzaApplicationState state,
+      ClusterResourceManager clusterResourceManager, Optional<ContainerAllocator> allocator, LocalityManager localityManager,
+      boolean restartContainers) {
     return new ContainerProcessManager(clusterManagerConfig, state, new MetricsRegistryMap(), clusterResourceManager,
         allocator, buildContainerManager(containerPlacementMetadataStore, state, clusterResourceManager,
-        clusterManagerConfig.getHostAffinityEnabled(), false, localityManager), localityManager);
+        clusterManagerConfig.getHostAffinityEnabled(), false, localityManager), localityManager, restartContainers);
   }
 }
diff --git a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobCoordinatorMetadataManager.java b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobCoordinatorMetadataManager.java
index bd177cc..70e65a3 100644
--- a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobCoordinatorMetadataManager.java
+++ b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobCoordinatorMetadataManager.java
@@ -29,6 +29,7 @@
 import org.apache.samza.container.TaskName;
 import org.apache.samza.coordinator.metadatastore.CoordinatorStreamStoreTestUtil;
 import org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
+import org.apache.samza.coordinator.stream.CoordinatorStreamValueSerde;
 import org.apache.samza.coordinator.stream.messages.SetJobCoordinatorMetadataMessage;
 import org.apache.samza.job.JobCoordinatorMetadata;
 import org.apache.samza.job.model.ContainerModel;
@@ -36,6 +37,7 @@
 import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metadatastore.MetadataStore;
 import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.serializers.Serde;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -44,7 +46,9 @@
 import static org.apache.samza.coordinator.JobCoordinatorMetadataManager.CONTAINER_ID_PROPERTY;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.doThrow;
@@ -144,6 +148,21 @@
   }
 
   @Test
+  public void testGenerateJobCoordinatorMetadataFailed() {
+    doThrow(new RuntimeException("Failed to generate epoch id"))
+        .when(jobCoordinatorMetadataManager).fetchEpochIdForJobCoordinator();
+
+    try {
+      jobCoordinatorMetadataManager.generateJobCoordinatorMetadata(new JobModel(OLD_CONFIG, containerModelMap), OLD_CONFIG);
+      fail("Expected generate job coordinator metadata to throw exception");
+    } catch (Exception e) {
+      assertTrue("Expecting SamzaException to be thrown", e instanceof SamzaException);
+      assertEquals("Metadata generation failed count should be 1", 1,
+          jobCoordinatorMetadataManager.getMetadataGenerationFailedCount().getCount());
+    }
+  }
+
+  @Test
   public void testGenerateJobCoordinatorMetadataForRepeatability() {
     when(jobCoordinatorMetadataManager.getEnvProperty(CONTAINER_ID_PROPERTY))
         .thenReturn(OLD_CONTAINER_ID);
@@ -175,6 +194,24 @@
   }
 
   @Test
+  public void testReadJobCoordinatorMetadataFailed() {
+    JobCoordinatorMetadata jobCoordinatorMetadata =
+        new JobCoordinatorMetadata(NEW_EPOCH_ID, NEW_CONFIG_ID, NEW_JOB_MODEL_ID);
+    Serde<String> mockSerde = spy(new CoordinatorStreamValueSerde(SetJobCoordinatorMetadataMessage.TYPE));
+    doThrow(new RuntimeException("Failed to read coordinator stream"))
+        .when(mockSerde).fromBytes(any());
+
+    jobCoordinatorMetadataManager = spy(new JobCoordinatorMetadataManager(metadataStore,
+        ClusterType.YARN, new MetricsRegistryMap(), mockSerde));
+    jobCoordinatorMetadataManager.writeJobCoordinatorMetadata(jobCoordinatorMetadata);
+
+    JobCoordinatorMetadata actualMetadata = jobCoordinatorMetadataManager.readJobCoordinatorMetadata();
+    assertNull("Read failed should return null", actualMetadata);
+    assertEquals("Metadata read failed count should be 1", 1,
+        jobCoordinatorMetadataManager.getMetadataReadFailedCount().getCount());
+  }
+
+  @Test
   public void testReadWriteJobCoordinatorMetadata() {
     JobCoordinatorMetadata jobCoordinatorMetadata =
         new JobCoordinatorMetadata(NEW_EPOCH_ID, NEW_CONFIG_ID, NEW_JOB_MODEL_ID);
@@ -190,10 +227,17 @@
     jobCoordinatorMetadataManager.writeJobCoordinatorMetadata(null);
   }
 
-  @Test (expected = SamzaException.class)
+  @Test
   public void testWriteJobCoordinatorMetadataBubblesException() {
-    doThrow(new RuntimeException("failed to write to coordinator stream"))
+    doThrow(new RuntimeException("Failed to write to coordinator stream"))
         .when(metadataStore).put(anyString(), any());
-    jobCoordinatorMetadataManager.writeJobCoordinatorMetadata(mock(JobCoordinatorMetadata.class));
+    try {
+      jobCoordinatorMetadataManager.writeJobCoordinatorMetadata(mock(JobCoordinatorMetadata.class));
+      fail("Expected write job coordinator metadata to throw exception");
+    } catch (Exception e) {
+      assertTrue("Expecting SamzaException to be thrown", e instanceof SamzaException);
+      assertEquals("Metadata write failed count should be 1", 1,
+          jobCoordinatorMetadataManager.getMetadataWriteFailedCount().getCount());
+    }
   }
 }
diff --git a/samza-yarn/src/main/java/org/apache/samza/job/yarn/YarnClusterResourceManager.java b/samza-yarn/src/main/java/org/apache/samza/job/yarn/YarnClusterResourceManager.java
index d1c5437..fa784e0 100644
--- a/samza-yarn/src/main/java/org/apache/samza/job/yarn/YarnClusterResourceManager.java
+++ b/samza-yarn/src/main/java/org/apache/samza/job/yarn/YarnClusterResourceManager.java
@@ -19,6 +19,7 @@
 
 package org.apache.samza.job.yarn;
 
+import com.google.common.annotations.VisibleForTesting;
 import java.time.Duration;
 import java.util.Set;
 import org.apache.hadoop.fs.FileStatus;
@@ -331,9 +332,26 @@
   public void stopStreamProcessor(SamzaResource resource) {
     synchronized (lock) {
       Container container = allocatedResources.get(resource);
+      String containerId = resource.getContainerId();
+      String containerHost = resource.getHost();
+      /*
+       * 1. Stop the container through NMClient if the container was instantiated as part of NMClient lifecycle.
+       * 2. Stop the container through AMClient by release the assigned container if the container was from the previous
+       *    attempt and managed by the AM due to AM-HA
+       * 3. Ignore the request if the container associated with the resource isn't present in the book keeping.
+       */
       if (container != null) {
-        log.info("Stopping Container ID: {} on host: {}", resource.getContainerId(), resource.getHost());
+        log.info("Stopping Container ID: {} on host: {}", containerId, containerHost);
         this.nmClientAsync.stopContainerAsync(container.getId(), container.getNodeId());
+      } else {
+        YarnContainer yarnContainer = state.runningProcessors.get(getRunningProcessorId(containerId));
+        if (yarnContainer != null) {
+          log.info("Stopping container from previous attempt with Container ID: {} on host: {}",
+              containerId, containerHost);
+          amClient.releaseAssignedContainer(yarnContainer.id());
+        } else {
+          log.info("No container with Container ID: {} exists. Ignoring the stop request", containerId);
+        }
       }
     }
   }
@@ -746,4 +764,9 @@
           "Ignoring notification.", containerId);
     }
   }
+
+  @VisibleForTesting
+  ConcurrentHashMap<SamzaResource, Container> getAllocatedResources() {
+    return allocatedResources;
+  }
 }
diff --git a/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestYarnClusterResourceManager.java b/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestYarnClusterResourceManager.java
index 08b18e8..89929f7 100644
--- a/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestYarnClusterResourceManager.java
+++ b/samza-yarn/src/test/java/org/apache/samza/job/yarn/TestYarnClusterResourceManager.java
@@ -46,6 +46,7 @@
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.MapConfig;
 import org.junit.Assert;
+import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mockito;
@@ -58,19 +59,33 @@
 
 public class TestYarnClusterResourceManager {
 
+  private YarnConfiguration yarnConfiguration;
+  private Config config;
+  private SamzaAppMasterMetrics metrics;
+  private AMRMClientAsync asyncClient;
+  private SamzaYarnAppMasterLifecycle lifecycle;
+  private SamzaYarnAppMasterService service;
+  private NMClientAsync asyncNMClient;
+  private ClusterResourceManager.Callback callback;
+  private YarnAppState yarnAppState;
+
+  @Before
+  public void setup() {
+    yarnConfiguration = mock(YarnConfiguration.class);
+    config = mock(Config.class);
+    metrics = mock(SamzaAppMasterMetrics.class);
+    asyncClient = mock(AMRMClientAsync.class);
+    lifecycle = mock(SamzaYarnAppMasterLifecycle.class);
+    service = mock(SamzaYarnAppMasterService.class);
+    asyncNMClient = mock(NMClientAsync.class);
+    callback = mock(ClusterResourceManager.Callback.class);
+    yarnAppState = new YarnAppState(0, mock(ContainerId.class), "host", 8080, 8081);
+  }
+
   @Test
   public void testErrorInStartContainerShouldUpdateState() {
     // create mocks
     final int samzaContainerId = 1;
-    YarnConfiguration yarnConfiguration = mock(YarnConfiguration.class);
-    SamzaAppMasterMetrics metrics = mock(SamzaAppMasterMetrics.class);
-    Config config = mock(Config.class);
-    AMRMClientAsync asyncClient = mock(AMRMClientAsync.class);
-    YarnAppState yarnAppState = new YarnAppState(0, mock(ContainerId.class), "host", 8080, 8081);
-    SamzaYarnAppMasterLifecycle lifecycle = mock(SamzaYarnAppMasterLifecycle.class);
-    SamzaYarnAppMasterService service = mock(SamzaYarnAppMasterService.class);
-    NMClientAsync asyncNMClient = mock(NMClientAsync.class);
-    ClusterResourceManager.Callback callback = mock(ClusterResourceManager.Callback.class);
 
     // start the cluster manager
     YarnClusterResourceManager yarnClusterResourceManager =
@@ -94,16 +109,6 @@
 
   @Test
   public void testAllocatedResourceExpiryForYarn() {
-    YarnConfiguration yarnConfiguration = mock(YarnConfiguration.class);
-    SamzaAppMasterMetrics metrics = mock(SamzaAppMasterMetrics.class);
-    Config config = mock(Config.class);
-    AMRMClientAsync asyncClient = mock(AMRMClientAsync.class);
-    YarnAppState yarnAppState = new YarnAppState(0, mock(ContainerId.class), "host", 8080, 8081);
-    SamzaYarnAppMasterLifecycle lifecycle = mock(SamzaYarnAppMasterLifecycle.class);
-    SamzaYarnAppMasterService service = mock(SamzaYarnAppMasterService.class);
-    NMClientAsync asyncNMClient = mock(NMClientAsync.class);
-    ClusterResourceManager.Callback callback = mock(ClusterResourceManager.Callback.class);
-
     // start the cluster manager
     YarnClusterResourceManager yarnClusterResourceManager =
         new YarnClusterResourceManager(asyncClient, asyncNMClient, callback, yarnAppState, lifecycle, service, metrics,
@@ -118,15 +123,7 @@
   @Test
   public void testAMShutdownOnRMCallback() throws IOException, YarnException {
     // create mocks
-    YarnConfiguration yarnConfiguration = mock(YarnConfiguration.class);
-    SamzaAppMasterMetrics metrics = mock(SamzaAppMasterMetrics.class);
-    Config config = mock(Config.class);
-    AMRMClientAsync asyncClient = mock(AMRMClientAsync.class);
-    YarnAppState yarnAppState = new YarnAppState(0, mock(ContainerId.class), "host", 8080, 8081);
     SamzaYarnAppMasterLifecycle lifecycle = Mockito.spy(new SamzaYarnAppMasterLifecycle(512, 2, mock(SamzaApplicationState.class), yarnAppState, asyncClient, false));
-    SamzaYarnAppMasterService service = mock(SamzaYarnAppMasterService.class);
-    NMClientAsync asyncNMClient = mock(NMClientAsync.class);
-    ClusterResourceManager.Callback callback = mock(ClusterResourceManager.Callback.class);
 
     // start the cluster manager
     YarnClusterResourceManager yarnClusterResourceManager =
@@ -146,15 +143,7 @@
   @Test
   public void testAMShutdownThrowingExceptionOnRMCallback() throws IOException, YarnException {
     // create mocks
-    YarnConfiguration yarnConfiguration = mock(YarnConfiguration.class);
-    SamzaAppMasterMetrics metrics = mock(SamzaAppMasterMetrics.class);
-    Config config = mock(Config.class);
-    AMRMClientAsync asyncClient = mock(AMRMClientAsync.class);
-    YarnAppState yarnAppState = new YarnAppState(0, mock(ContainerId.class), "host", 8080, 8081);
     SamzaYarnAppMasterLifecycle lifecycle = Mockito.spy(new SamzaYarnAppMasterLifecycle(512, 2, mock(SamzaApplicationState.class), yarnAppState, asyncClient, false));
-    SamzaYarnAppMasterService service = mock(SamzaYarnAppMasterService.class);
-    NMClientAsync asyncNMClient = mock(NMClientAsync.class);
-    ClusterResourceManager.Callback callback = mock(ClusterResourceManager.Callback.class);
 
     doThrow(InvalidApplicationMasterRequestException.class).when(asyncClient).unregisterApplicationMaster(FinalApplicationStatus.FAILED, null, null);
 
@@ -174,18 +163,11 @@
   }
 
   @Test
-  public void testAMHACallbackInvokedForPreviousAttemptContainers() throws IOException, YarnException {
+  public void testAMHACallbackInvokedForPreviousAttemptContainers() {
     String previousAttemptContainerId = "0";
     String previousAttemptYarnContainerId = "container_1607304997422_0008_02_000002";
     // create mocks
-    YarnConfiguration yarnConfiguration = mock(YarnConfiguration.class);
-    SamzaAppMasterMetrics metrics = mock(SamzaAppMasterMetrics.class);
-    AMRMClientAsync asyncClient = mock(AMRMClientAsync.class);
     YarnAppState yarnAppState = Mockito.spy(new YarnAppState(0, mock(ContainerId.class), "host", 8080, 8081));
-    SamzaYarnAppMasterLifecycle lifecycle = mock(SamzaYarnAppMasterLifecycle.class);
-    SamzaYarnAppMasterService service = mock(SamzaYarnAppMasterService.class);
-    NMClientAsync asyncNMClient = mock(NMClientAsync.class);
-    ClusterResourceManager.Callback callback = mock(ClusterResourceManager.Callback.class);
 
     ContainerId containerId = mock(ContainerId.class);
     when(containerId.toString()).thenReturn(previousAttemptYarnContainerId);
@@ -222,4 +204,47 @@
     SamzaResource samzaResource = samzaResourceArgumentCaptor.getValue();
     assertEquals(previousAttemptYarnContainerId, samzaResource.getContainerId());
   }
+
+  @Test
+  public void testStopStreamProcessorForContainerFromPreviousAttempt() {
+    String containerId = "Yarn_Container_id_0";
+    String processorId = "Container_id_0";
+    YarnContainer runningYarnContainer = mock(YarnContainer.class);
+    ContainerId previousRunningContainerId = mock(ContainerId.class);
+    YarnAppState yarnAppState = Mockito.spy(new YarnAppState(0, mock(ContainerId.class), "host", 8080, 8081));
+
+    yarnAppState.runningProcessors.put(processorId, runningYarnContainer);
+    when(runningYarnContainer.id()).thenReturn(previousRunningContainerId);
+    when(previousRunningContainerId.toString()).thenReturn(containerId);
+
+    YarnClusterResourceManager yarnClusterResourceManager =
+        new YarnClusterResourceManager(asyncClient, asyncNMClient, callback, yarnAppState, lifecycle, service, metrics,
+            yarnConfiguration, config);
+
+    SamzaResource containerResourceFromPreviousRun = mock(SamzaResource.class);
+    when(containerResourceFromPreviousRun.getContainerId()).thenReturn(containerId);
+
+    yarnClusterResourceManager.stopStreamProcessor(containerResourceFromPreviousRun);
+    verify(asyncClient, times(1)).releaseAssignedContainer(previousRunningContainerId);
+  }
+
+  @Test
+  public void testStopStreamProcessorForContainerStartedInCurrentLifecycle() {
+    YarnClusterResourceManager yarnClusterResourceManager =
+        new YarnClusterResourceManager(asyncClient, asyncNMClient, callback, yarnAppState, lifecycle, service, metrics,
+            yarnConfiguration, config);
+
+    SamzaResource allocatedContainerResource = mock(SamzaResource.class);
+    Container runningContainer = mock(Container.class);
+    ContainerId runningContainerId = mock(ContainerId.class);
+    NodeId runningNodeId = mock(NodeId.class);
+
+    when(runningContainer.getId()).thenReturn(runningContainerId);
+    when(runningContainer.getNodeId()).thenReturn(runningNodeId);
+
+    yarnClusterResourceManager.getAllocatedResources().put(allocatedContainerResource, runningContainer);
+    yarnClusterResourceManager.stopStreamProcessor(allocatedContainerResource);
+
+    verify(asyncNMClient, times(1)).stopContainerAsync(runningContainerId, runningNodeId);
+  }
 }
\ No newline at end of file