TEZ-4339: Expose real-time memory consumption of AM and task containers via DagClient (#157) (Laszlo Bodor reviewed by Rajesh Balamohan)

diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/DagTypeConverters.java b/tez-api/src/main/java/org/apache/tez/dag/api/DagTypeConverters.java
index 5a2cb64..c563f1f 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/DagTypeConverters.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/DagTypeConverters.java
@@ -628,6 +628,8 @@
     switch (statusGetOpts) {
     case GET_COUNTERS:
       return DAGProtos.StatusGetOptsProto.GET_COUNTERS;
+    case GET_MEMORY_USAGE:
+      return DAGProtos.StatusGetOptsProto.GET_MEMORY_USAGE;
     }
     throw new TezUncheckedException("Could not convert StatusGetOpts to" + " proto");
   }
@@ -636,6 +638,8 @@
     switch (proto) {
     case GET_COUNTERS:
       return StatusGetOpts.GET_COUNTERS;
+    case GET_MEMORY_USAGE:
+      return StatusGetOpts.GET_MEMORY_USAGE;
     }
     throw new TezUncheckedException("Could not convert to StatusGetOpts from" + " proto");
   }
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGStatus.java b/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGStatus.java
index cbf641e..1f8db62 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGStatus.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGStatus.java
@@ -152,6 +152,28 @@
     return dagCounters;
   }
 
+  /**
+   * Returns the memory usage of the Application Master carried by this status.
+   * The DAG AM fills this in only when {@link StatusGetOpts#GET_MEMORY_USAGE} is requested
+   * (used heap of the AM JVM, in bytes); otherwise the proto default of 0 is returned.
+   *
+   * @return the memoryUsedByAM value of the underlying status proto
+   */
+  public long getMemoryUsedByAM() {
+    return proxy.getMemoryUsedByAM();
+  }
+
+  /**
+   * Returns the aggregated memory usage of the task containers carried by this status.
+   * The DAG AM fills this in only when {@link StatusGetOpts#GET_MEMORY_USAGE} is requested
+   * (sum of what its task communicators report); otherwise the proto default of 0 is returned.
+   *
+   * @return the memoryUsedByTasks value of the underlying status proto
+   */
+  public long getMemoryUsedByTasks() {
+    return proxy.getMemoryUsedByTasks();
+  }
+
   @InterfaceAudience.Private
   DagStatusSource getSource() {
     return this.source;
@@ -201,12 +209,12 @@
   @Override
   public String toString() {
     StringBuilder sb = new StringBuilder();
-    sb.append("status=" + getState()
-      + ", progress=" + getDAGProgress()
-      + ", diagnostics="
-      + StringUtils.join(getDiagnostics(), LINE_SEPARATOR)
-      + ", counters="
-      + (getDAGCounters() == null ? "null" : getDAGCounters().toString()));
+    sb.append("status=").append(getState());
+    sb.append(", progress=").append(getDAGProgress());
+    sb.append(", diagnostics=").append(StringUtils.join(getDiagnostics(), LINE_SEPARATOR));
+    sb.append(", memoryUsedByAM=").append(proxy.getMemoryUsedByAM());
+    sb.append(", memoryUsedByTasks=").append(proxy.getMemoryUsedByTasks());
+    sb.append(", counters=").append(getDAGCounters() == null ? "null" : getDAGCounters().toString());
     return sb.toString();
   }
 
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/client/StatusGetOpts.java b/tez-api/src/main/java/org/apache/tez/dag/api/client/StatusGetOpts.java
index 1a9df7a..3518d33 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/client/StatusGetOpts.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/client/StatusGetOpts.java
@@ -29,5 +29,7 @@
 @Evolving
 public enum StatusGetOpts {
   /** Retrieve Counters with Status */
-  GET_COUNTERS
+  GET_COUNTERS,
+  /** Retrieve memory usage of the AM and task containers with Status */
+  GET_MEMORY_USAGE
 }
diff --git a/tez-api/src/main/proto/DAGApiRecords.proto b/tez-api/src/main/proto/DAGApiRecords.proto
index 4c8c7f6..15f681d 100644
--- a/tez-api/src/main/proto/DAGApiRecords.proto
+++ b/tez-api/src/main/proto/DAGApiRecords.proto
@@ -275,6 +275,8 @@
   optional ProgressProto DAGProgress = 3;
   repeated StringProgressPairProto vertexProgress = 4;
   optional TezCountersProto dagCounters = 5;
+  optional int64 memoryUsedByAM = 6;
+  optional int64 memoryUsedByTasks = 7;
 }
 
 message PlanLocalResourcesProto {
@@ -299,6 +301,7 @@
 
 enum StatusGetOptsProto {
   GET_COUNTERS = 0;
+  GET_MEMORY_USAGE = 1;
 }
 
 message VertexLocationHintProto {
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/client/DAGStatusBuilder.java b/tez-dag/src/main/java/org/apache/tez/dag/api/client/DAGStatusBuilder.java
index 0002d8b..931c6d0 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/api/client/DAGStatusBuilder.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/api/client/DAGStatusBuilder.java
@@ -61,6 +61,13 @@
     getBuilder().addVertexProgress(builder.build());
   }
 
+  // Copies both memory figures into the proto. TODO: let this be a map of values in protobuf 3.x
+  public void setMemoryUsage(long memoryUsedByAM, long memoryUsedByTasks) {
+    getBuilder()
+        .setMemoryUsedByAM(memoryUsedByAM)
+        .setMemoryUsedByTasks(memoryUsedByTasks);
+  }
+
   public DAGStatusProto getProto() {
     return getBuilder().build();
   }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskCommunicatorManager.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskCommunicatorManager.java
index 3a99456..ac2f760 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskCommunicatorManager.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskCommunicatorManager.java
@@ -678,4 +678,12 @@
     return null;
   }
 
+  @Override
+  public long getTotalUsedMemory() {
+    long aggregate = 0; // running sum across all entries of taskCommunicators
+    for (int idx = 0; idx < taskCommunicators.length; idx++) {
+      aggregate += taskCommunicators[idx].getTaskCommunicator().getTotalUsedMemory();
+    }
+    return aggregate;
+  }
 }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskCommunicatorManagerInterface.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskCommunicatorManagerInterface.java
index 254e74c..150977a 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/TaskCommunicatorManagerInterface.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TaskCommunicatorManagerInterface.java
@@ -54,4 +54,5 @@
 
   String getCompletedLogsUrl(int taskCommId, TezTaskAttemptID attemptID, NodeId containerNodeId);
 
+  long getTotalUsedMemory();
 }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java
index 6d69d36..48aee31 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/TezTaskCommunicatorImpl.java
@@ -104,6 +104,7 @@
     Credentials credentials = null;
     boolean credentialsChanged = false;
     boolean taskPulled = false;
+    long usedMemory = 0;
 
     void reset() {
       taskSpec = null;
@@ -382,6 +383,7 @@
       response.setLastRequestId(requestId);
       containerInfo.lastRequestId = requestId;
       containerInfo.lastResponse = response;
+      containerInfo.usedMemory = request.getUsedMemory();
       return response;
     }
 
@@ -466,4 +468,15 @@
   protected ContainerId getContainerForAttempt(TezTaskAttemptID taskAttemptId) {
     return attemptToContainerMap.get(taskAttemptId);
   }
+
+  /**
+   * Sums the usedMemory value currently stored for every registered container; each value is
+   * the one carried by that container's most recently processed heartbeat request.
+   *
+   * @return the total used memory over all registered containers
+   */
+  @Override
+  public long getTotalUsedMemory() {
+    return registeredContainers.values().stream().mapToLong(c -> c.usedMemory).sum();
+  }
 }
\ No newline at end of file
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
index 026ca29..07715cd 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
@@ -19,6 +19,8 @@
 package org.apache.tez.dag.app.dag.impl;
 
 import java.io.IOException;
+import java.lang.management.ManagementFactory;
+import java.lang.management.MemoryMXBean;
 import java.net.URL;
 import java.security.PrivilegedExceptionAction;
 import java.util.ArrayList;
@@ -244,6 +246,8 @@
   private static final CommitCompletedTransition COMMIT_COMPLETED_TRANSITION =
       new CommitCompletedTransition();
 
+  private final MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
+
   protected static final
     StateMachineFactory<DAGImpl, DAGState, DAGEventType, DAGEvent>
        stateMachineFactory
@@ -940,6 +944,10 @@
       if (statusOptions.contains(StatusGetOpts.GET_COUNTERS)) {
         status.setDAGCounters(getAllCounters());
       }
+      if (statusOptions.contains(StatusGetOpts.GET_MEMORY_USAGE)) {
+        status.setMemoryUsage(memoryMXBean.getHeapMemoryUsage().getUsed(),
+            taskCommunicatorManagerInterface.getTotalUsedMemory());
+      }
       return status;
     } finally {
       readLock.unlock();
diff --git a/tez-dag/src/main/java/org/apache/tez/serviceplugins/api/TaskCommunicator.java b/tez-dag/src/main/java/org/apache/tez/serviceplugins/api/TaskCommunicator.java
index fceddf2..be6ad68 100644
--- a/tez-dag/src/main/java/org/apache/tez/serviceplugins/api/TaskCommunicator.java
+++ b/tez-dag/src/main/java/org/apache/tez/serviceplugins/api/TaskCommunicator.java
@@ -237,4 +237,13 @@
     return null;
   }
 
+  /**
+   * Return the amount of memory used by the containers. Each container is supposed to refresh
+   * its current state via heartbeat requests, and the TaskCommunicator implementation is supposed
+   * to aggregate this properly. This base implementation does no aggregation and returns 0.
+   * @return memory in bytes (the unit in which Tez task containers report their used heap)
+   */
+  public long getTotalUsedMemory() {
+    return 0;
+  }
 }
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java b/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
index 9bceaec..b3ddaa0 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/MockDAGAppMaster.java
@@ -431,7 +431,7 @@
                   EventProducerConsumerType.SYSTEM, cData.vName, "", cData.taId),
                   MockDAGAppMaster.this.getContext().getClock().getTime()));
               TezHeartbeatRequest request = new TezHeartbeatRequest(cData.numUpdates, events,
-                  cData.nextPreRoutedFromEventId, cData.cIdStr, cData.taId, cData.nextFromEventId, 50000);
+                  cData.nextPreRoutedFromEventId, cData.cIdStr, cData.taId, cData.nextFromEventId, 50000, 0);
               doHeartbeat(request, cData);
             } else if (version != null && cData.taId.getId() <= version.intValue()) {
               preemptContainer(cData);
@@ -443,7 +443,7 @@
                       EventProducerConsumerType.SYSTEM, cData.vName, "", cData.taId),
                   MockDAGAppMaster.this.getContext().getClock().getTime()));
               TezHeartbeatRequest request = new TezHeartbeatRequest(++cData.numUpdates, events,
-                  cData.nextPreRoutedFromEventId, cData.cIdStr, cData.taId, cData.nextFromEventId, 10000);
+                  cData.nextPreRoutedFromEventId, cData.cIdStr, cData.taId, cData.nextFromEventId, 10000, 0);
               doHeartbeat(request, cData);
               cData.clear();
             }
diff --git a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezHeartbeatRequest.java b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezHeartbeatRequest.java
index 7ed89f8..fd5bc17 100644
--- a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezHeartbeatRequest.java
+++ b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezHeartbeatRequest.java
@@ -39,13 +39,14 @@
   private int preRoutedStartIndex;
   private int maxEvents;
   private long requestId;
+  private long usedMemory;
 
   public TezHeartbeatRequest() {
   }
 
   public TezHeartbeatRequest(long requestId, List<TezEvent> events,
       int preRoutedStartIndex, String containerIdentifier,
-      TezTaskAttemptID taskAttemptID, int startIndex, int maxEvents) {
+      TezTaskAttemptID taskAttemptID, int startIndex, int maxEvents, long usedMemory) {
     this.containerIdentifier = containerIdentifier;
     this.requestId = requestId;
     this.events = Collections.unmodifiableList(events);
@@ -53,6 +54,7 @@
     this.preRoutedStartIndex = preRoutedStartIndex;
     this.maxEvents = maxEvents;
     this.currentTaskAttemptID = taskAttemptID;
+    this.usedMemory = usedMemory;
   }
 
   public String getContainerIdentifier() {
@@ -83,6 +85,10 @@
     return currentTaskAttemptID;
   }
 
+  public long getUsedMemory() { // memory figure set by the sender via the constructor or readFields
+    return usedMemory;
+  }
+
   @Override
   public void write(DataOutput out) throws IOException {
     if (events != null) {
@@ -105,6 +111,7 @@
     out.writeInt(maxEvents);
     out.writeLong(requestId);
     Text.writeString(out, containerIdentifier);
+    out.writeLong(usedMemory);
   }
 
   @Override
@@ -128,6 +135,7 @@
     maxEvents = in.readInt();
     requestId = in.readLong();
     containerIdentifier = Text.readString(in);
+    usedMemory = in.readLong();
   }
 
   @Override
@@ -140,6 +148,7 @@
         + ", maxEventsToGet=" + maxEvents
         + ", taskAttemptId=" + currentTaskAttemptID
         + ", eventCount=" + (events != null ? events.size() : 0)
+        + ", usedMemory=" + usedMemory
         + " }";
   }
 }
diff --git a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/task/TaskReporter.java b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/task/TaskReporter.java
index 978942d..eeb2434 100644
--- a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/task/TaskReporter.java
+++ b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/task/TaskReporter.java
@@ -19,6 +19,8 @@
 package org.apache.tez.runtime.task;
 
 import java.io.IOException;
+import java.lang.management.ManagementFactory;
+import java.lang.management.MemoryMXBean;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
@@ -136,6 +138,7 @@
 
     private static final int LOG_COUNTER_START_INTERVAL = 5000; // 5 seconds
     private static final float LOG_COUNTER_BACKOFF = 1.3f;
+    private static final int HEAP_MEMORY_USAGE_UPDATE_INTERVAL = 5000; // 5 seconds
 
     private final RuntimeTask task;
     private final EventMetaData updateEventMetadata;
@@ -157,6 +160,10 @@
     private final ReentrantLock lock = new ReentrantLock();
     private final Condition condition = lock.newCondition();
 
+    private final MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
+    private long usedMemory = 0;
+    private long heapMemoryUsageUpdatedTime = System.currentTimeMillis() - HEAP_MEMORY_USAGE_UPDATE_INTERVAL;
+
     /*
      * Keeps track of regular timed heartbeats. Is primarily used as a timing mechanism to send /
      * log counters.
@@ -263,7 +270,7 @@
       int fromPreRoutedEventId = task.getNextPreRoutedEventId();
       int maxEvents = Math.min(maxEventsToGet, task.getMaxEventsToHandle());
       TezHeartbeatRequest request = new TezHeartbeatRequest(requestId, events, fromPreRoutedEventId,
-          containerIdStr, task.getTaskAttemptID(), fromEventId, maxEvents);
+          containerIdStr, task.getTaskAttemptID(), fromEventId, maxEvents, getUsedMemory());
       LOG.debug("Sending heartbeat to AM, request={}", request);
 
       maybeLogCounters();
@@ -305,6 +312,15 @@
       return new ResponseWrapper(false, numEventsReceived);
     }
 
+    private long getUsedMemory() { // cached heap usage, re-sampled at most once per update interval
+      final long currentTimeMs = System.currentTimeMillis();
+      if (currentTimeMs - heapMemoryUsageUpdatedTime > HEAP_MEMORY_USAGE_UPDATE_INTERVAL) {
+        usedMemory = memoryMXBean.getHeapMemoryUsage().getUsed();
+        heapMemoryUsageUpdatedTime = currentTimeMs;
+      }
+      return usedMemory;
+    }
+
     public void markComplete() {
       // Notify to clear pending events, if any.
       lock.lock();
diff --git a/tez-tests/src/test/java/org/apache/tez/mapreduce/TestMRRJobsDAGApi.java b/tez-tests/src/test/java/org/apache/tez/mapreduce/TestMRRJobsDAGApi.java
index 96b7bbf..95d5bcf 100644
--- a/tez-tests/src/test/java/org/apache/tez/mapreduce/TestMRRJobsDAGApi.java
+++ b/tez-tests/src/test/java/org/apache/tez/mapreduce/TestMRRJobsDAGApi.java
@@ -20,6 +20,7 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 
 import java.io.File;
 import java.io.IOException;
@@ -211,13 +212,18 @@
           + dagStatus.getState());
       Thread.sleep(500l);
       dagStatus = dagClient.getDAGStatus(null);
+      assertTrue("Memory used by AM is supposed to be 0 if not requested", dagStatus.getMemoryUsedByAM() == 0);
+      assertTrue("Memory used by tasks is supposed to be 0 if not requested", dagStatus.getMemoryUsedByTasks() == 0);
     }
-    dagStatus = dagClient.getDAGStatus(Sets.newHashSet(StatusGetOpts.GET_COUNTERS));
+    dagStatus = dagClient.getDAGStatus(Sets.newHashSet(StatusGetOpts.GET_COUNTERS, StatusGetOpts.GET_MEMORY_USAGE));
 
     assertEquals(DAGStatus.State.SUCCEEDED, dagStatus.getState());
     assertNotNull(dagStatus.getDAGCounters());
     assertNotNull(dagStatus.getDAGCounters().getGroup(FileSystemCounter.class.getName()));
     assertNotNull(dagStatus.getDAGCounters().findCounter(TaskCounter.GC_TIME_MILLIS));
+    assertTrue("Memory used by AM is supposed to be >0", dagStatus.getMemoryUsedByAM() > 0);
+    assertTrue("Memory used by tasks is supposed to be >0", dagStatus.getMemoryUsedByTasks() > 0);
+
     ExampleDriver.printDAGStatus(dagClient, new String[] { "SleepVertex" }, true, true);
     tezSession.stop();
   }