Calculate requested container count based on adding allocated count and outstanding ContainerRequests in Yarn (#3524)

diff --git a/gobblin-yarn/src/main/java/org/apache/gobblin/yarn/YarnService.java b/gobblin-yarn/src/main/java/org/apache/gobblin/yarn/YarnService.java
index a81960f..2f8a06c 100644
--- a/gobblin-yarn/src/main/java/org/apache/gobblin/yarn/YarnService.java
+++ b/gobblin-yarn/src/main/java/org/apache/gobblin/yarn/YarnService.java
@@ -191,8 +191,6 @@
   // instance names get picked up when replacement containers get allocated.
   private final Set<String> unusedHelixInstanceNames = ConcurrentHashMap.newKeySet();
 
-  // The map from helix tag to requested container count
-  private final Map<String, Integer> requestedContainerCountMap = Maps.newConcurrentMap();
   // The map from helix tag to allocated container count
   private final Map<String, Integer> allocatedContainerCountMap = Maps.newConcurrentMap();
 
@@ -444,11 +442,12 @@
     for (Map.Entry<String, Integer> entry : yarnContainerRequestBundle.getHelixTagContainerCountMap().entrySet()) {
       String currentHelixTag = entry.getKey();
       int desiredContainerCount = entry.getValue();
-      int requestedContainerCount = requestedContainerCountMap.getOrDefault(currentHelixTag, 0);
+      // Calculate requested container count based on adding allocated count and outstanding ContainerRequests in Yarn
+      int requestedContainerCount = allocatedContainerCountMap.getOrDefault(currentHelixTag, 0)
+          + getMatchingRequestsCount(yarnContainerRequestBundle.getHelixTagResourceMap().get(currentHelixTag));
       for(; requestedContainerCount < desiredContainerCount; requestedContainerCount++) {
         requestContainer(Optional.absent(), yarnContainerRequestBundle.getHelixTagResourceMap().get(currentHelixTag));
       }
-      requestedContainerCountMap.put(currentHelixTag, requestedContainerCount);
     }
 
     // If the total desired is lower than the currently allocated amount then release free containers.
@@ -466,8 +465,6 @@
         ContainerInfo containerInfo = entry.getValue();
         if (!inUseInstances.contains(containerInfo.getHelixParticipantId())) {
           containersToRelease.add(containerInfo.getContainer());
-          requestedContainerCountMap.put(containerInfo.getHelixTag(),
-              requestedContainerCountMap.get(containerInfo.getHelixTag()) - 1);
         }
 
         if (containersToRelease.size() == numToShutdown) {
@@ -480,8 +477,8 @@
       this.eventBus.post(new ContainerReleaseRequest(containersToRelease));
     }
     this.yarnContainerRequest = yarnContainerRequestBundle;
-    LOGGER.info("Current tag-container being requested:{}, tag-container allocated: {}",
-        this.requestedContainerCountMap, this.allocatedContainerCountMap);
+    LOGGER.info("Current tag-container desired count:{}, tag-container allocated: {}",
+        yarnContainerRequestBundle.getHelixTagContainerCountMap(), this.allocatedContainerCountMap);
   }
 
   // Request initial containers with default resource and helix tag
@@ -663,8 +660,8 @@
     String helixTag = completedContainerInfo == null ? helixInstanceTags : completedContainerInfo.getHelixTag();
     allocatedContainerCountMap.put(helixTag, allocatedContainerCountMap.get(helixTag) - 1);
 
-    LOGGER.info(String.format("Container %s running Helix instance %s has completed with exit status %d",
-        containerStatus.getContainerId(), completedInstanceName, containerStatus.getExitStatus()));
+    LOGGER.info(String.format("Container %s running Helix instance %s with tag %s has completed with exit status %d",
+        containerStatus.getContainerId(), completedInstanceName, helixTag, containerStatus.getExitStatus()));
 
     if (!Strings.isNullOrEmpty(containerStatus.getDiagnostics())) {
       LOGGER.info(String.format("Received the following diagnostics information for container %s: %s",
@@ -759,6 +756,15 @@
   }
 
   /**
+   * Get the number of matching container requests for the specified resource memory and cores.
+   */
+  private int getMatchingRequestsCount(Resource resource) {
+    int priorityNum = resourcePriorityMap.getOrDefault(resource.toString(), 0);
+    Priority priority = Priority.newInstance(priorityNum);
+    return getAmrmClientAsync().getMatchingRequests(priority, ResourceRequest.ANY, resource).size();
+  }
+
+  /**
    * A custom implementation of {@link AMRMClientAsync.CallbackHandler}.
    */
   private class AMRMClientCallbackHandler implements AMRMClientAsync.CallbackHandler {