TEZ-4103. Progress in DAG, Vertex, and tasks is incorrect

Signed-off-by: Jonathan Eagles <jeagles@apache.org>
diff --git a/tez-api/src/main/java/org/apache/tez/common/ProgressHelper.java b/tez-api/src/main/java/org/apache/tez/common/ProgressHelper.java
index 07b066c..1518ccd 100644
--- a/tez-api/src/main/java/org/apache/tez/common/ProgressHelper.java
+++ b/tez-api/src/main/java/org/apache/tez/common/ProgressHelper.java
@@ -19,74 +19,155 @@
 package org.apache.tez.common;
 
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.atomic.AtomicReference;
 import org.apache.tez.runtime.api.AbstractLogicalInput;
 import org.apache.tez.runtime.api.LogicalInput;
 import org.apache.tez.runtime.api.ProcessorContext;
-import org.apache.tez.runtime.api.ProgressFailedException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-
 import java.util.Map;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 
 public class ProgressHelper {
-  private static final Logger LOG = LoggerFactory.getLogger(ProgressHelper.class);
-  private String processorName;
+  private static final Logger LOG =
+      LoggerFactory.getLogger(ProgressHelper.class);
+  private static final float MIN_PROGRESS_VAL = 0.0f;
+  private static final float MAX_PROGRESS_VAL = 1.0f;
+  private final String processorName;
   protected final Map<String, LogicalInput> inputs;
-  final ProcessorContext processorContext;
+  private final ProcessorContext processorContext;
+  private final AtomicReference<ScheduledFuture<?>> periodicMonitorTaskRef;
+  private long monitorExecPeriod;
+  private volatile ScheduledExecutorService scheduledExecutorService;
 
-  volatile ScheduledExecutorService scheduledExecutorService;
-  Runnable monitorProgress = new Runnable() {
-    @Override
-    public void run() {
-      try {
-        float progSum = 0.0f;
-        float progress;
-        if (inputs != null && inputs.size() != 0) {
-          for (LogicalInput input : inputs.values()) {
-            if (input instanceof AbstractLogicalInput) {
-              float inputProgress = ((AbstractLogicalInput) input).getProgress();
-              if (inputProgress >= 0.0f && inputProgress <= 1.0f) {
-                progSum += inputProgress;
-              }
-            }
-          }
-          progress = (1.0f) * progSum / inputs.size();
-        } else {
-          progress = 1.0f;
-        }
-        processorContext.setProgress(progress);
-      } catch (ProgressFailedException pe) {
-        LOG.warn("Encountered ProgressFailedException during Processor progress update"
-            + pe);
-      } catch (InterruptedException ie) {
-        LOG.warn("Encountered InterruptedException during Processor progress update"
-            + ie);
-      }
-    }
-  };
+  public static final float processProgress(float val) {
+    return (Float.isNaN(val)) ? MIN_PROGRESS_VAL
+        : Math.max(MIN_PROGRESS_VAL, Math.min(MAX_PROGRESS_VAL, val));
+  }
 
-  public ProgressHelper(Map<String, LogicalInput> _inputs, ProcessorContext context, String processorName) {
-    this.inputs = _inputs;
+  public static final boolean isProgressWithinRange(float val) {
+    return (val <= MAX_PROGRESS_VAL && val >= MIN_PROGRESS_VAL);
+  }
+
+  public ProgressHelper(Map<String, LogicalInput> inputsParam,
+      ProcessorContext context, String processorName) {
+    this.periodicMonitorTaskRef = new AtomicReference<>(null);
+    this.inputs = inputsParam;
     this.processorContext = context;
     this.processorName = processorName;
   }
 
   public void scheduleProgressTaskService(long delay, long period) {
-    scheduledExecutorService = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder()
-        .setDaemon(true).setNameFormat("TaskProgressService{" + processorName+ ":" + processorContext.getTaskVertexName()
-            + "} #%d").build());
-    scheduledExecutorService.scheduleWithFixedDelay(monitorProgress, delay, period,
-        TimeUnit.MILLISECONDS);
-  }
-
-  public void shutDownProgressTaskService() {
-    if (scheduledExecutorService != null) {
-      scheduledExecutorService.shutdownNow();
-      scheduledExecutorService = null;
+    monitorExecPeriod = period;
+    scheduledExecutorService =
+        Executors.newScheduledThreadPool(1,
+            new ThreadFactoryBuilder().setDaemon(true).setNameFormat(
+                "TaskProgressService{" + processorName + ":" + processorContext
+                    .getTaskVertexName()
+                    + "} #%d").build());
+    try {
+      createPeriodicTask(delay);
+    } catch (RejectedExecutionException | IllegalArgumentException ex) {
+      LOG.error("Could not create periodic scheduled task for processor={}",
+          processorName, ex);
     }
   }
 
+  private Runnable createRunnableMonitor() {
+    return new Runnable() {
+      @Override
+      public void run() {
+        try {
+          float progSum = MIN_PROGRESS_VAL;
+          int invalidInput = 0;
+          float progressVal = MIN_PROGRESS_VAL;
+          if (inputs != null && !inputs.isEmpty()) {
+            for (LogicalInput input : inputs.values()) {
+              if (!(input instanceof AbstractLogicalInput)) {
+                /**
+                 * According to javdoc in
+                 * {@link org.apache.tez.runtime.api.AbstractLogicalInput} all
+                 * implementations must extend AbstractLogicalInput.
+                 */
+                continue;
+              }
+              final float inputProgress =
+                  ((AbstractLogicalInput) input).getProgress();
+              if (!isProgressWithinRange(inputProgress)) {
+                final int invalidSnapshot = ++invalidInput;
+                if (LOG.isDebugEnabled()) {
+                  LOG.debug(
+                      "progress update: Incorrect value in progress helper in "
+                          + "processor={}, inputProgress={}, inputsSize={}, "
+                          + "invalidInput={}",
+                      processorName, inputProgress, inputs.size(),
+                      invalidSnapshot);
+                }
+              }
+              progSum += processProgress(inputProgress);
+            }
+            // No need to process the average within the valid range since the
+            // processorContext validates the value before being set.
+            progressVal = progSum / inputs.size();
+          }
+          // Report progress as 0.0f when if are errors.
+          processorContext.setProgress(progressVal);
+        } catch (Throwable th) {
+          if (LOG.isDebugEnabled()) {
+            LOG.debug("progress update: Encountered InterruptedException during"
+                + " Processor={}", processorName, th);
+          }
+          if (th instanceof InterruptedException) {
+            // set interrupt flag to true sand exit
+            Thread.currentThread().interrupt();
+            return;
+          }
+        }
+      }
+    };
+  }
+
+  private boolean createPeriodicTask(long delay)
+      throws RejectedExecutionException, IllegalArgumentException {
+    stopPeriodicMonitor();
+    final Runnable runnableMonitor = createRunnableMonitor();
+    ScheduledFuture<?> futureTask = scheduledExecutorService
+        .scheduleWithFixedDelay(runnableMonitor, delay, monitorExecPeriod,
+            TimeUnit.MILLISECONDS);
+    periodicMonitorTaskRef.set(futureTask);
+    return true;
+  }
+
+  private void stopPeriodicMonitor() {
+    ScheduledFuture<?> scheduledMonitorRes =
+        this.periodicMonitorTaskRef.get();
+    if (scheduledMonitorRes != null && !scheduledMonitorRes.isCancelled()) {
+      scheduledMonitorRes.cancel(true);
+      this.periodicMonitorTaskRef.set(null);
+    }
+  }
+
+  public void shutDownProgressTaskService() {
+    stopPeriodicMonitor();
+    if (scheduledExecutorService != null) {
+      scheduledExecutorService.shutdown();
+      try {
+        if (!scheduledExecutorService.awaitTermination(monitorExecPeriod,
+            TimeUnit.MILLISECONDS)) {
+          scheduledExecutorService.shutdownNow();
+        }
+      } catch (InterruptedException e) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Interrupted exception while shutting down the "
+              + "executor service for the processor name={}", processorName);
+        }
+      }
+      scheduledExecutorService.shutdownNow();
+    }
+    scheduledExecutorService = null;
+  }
 }
diff --git a/tez-api/src/main/java/org/apache/tez/runtime/api/ProcessorContext.java b/tez-api/src/main/java/org/apache/tez/runtime/api/ProcessorContext.java
index acb2a57..3782a8d 100644
--- a/tez-api/src/main/java/org/apache/tez/runtime/api/ProcessorContext.java
+++ b/tez-api/src/main/java/org/apache/tez/runtime/api/ProcessorContext.java
@@ -22,6 +22,7 @@
 import java.util.Collection;
 
 import org.apache.hadoop.classification.InterfaceAudience.Public;
+import org.apache.tez.common.ProgressHelper;
 
 /**
  * Context handle for the Processor to initialize itself.
@@ -31,12 +32,31 @@
 public interface ProcessorContext extends TaskContext {
 
   /**
+   * validate that progress is the valid range.
+   * @param progress
+   * @return the processed value of the progress that is guaranteed to be within
+   *          the valid range.
+   */
+  static float preProcessProgress(float progress) {
+    return ProgressHelper.processProgress(progress);
+  }
+
+  /**
    * Set the overall progress of this Task Attempt.
    * This automatically results in invocation of {@link ProcessorContext#notifyProgress()} 
    * and so invoking that separately is not required.
    * @param progress Progress in the range from [0.0 - 1.0f]
    */
-  public void setProgress(float progress);
+  default void setProgress(float progress) {
+    setProgressInternally(preProcessProgress(progress));
+  }
+
+  /**
+   * The actual implementation of the taskAttempt progress.
+   * All implementations needs to override this method
+   * @param progress
+   */
+  void setProgressInternally(float progress);
 
   /**
    * Check whether this attempt can commit its output
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
index db51cee..18b7128 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
@@ -43,6 +43,7 @@
 import org.apache.commons.lang.StringUtils;
 import org.apache.commons.lang.exception.ExceptionUtils;
 import org.apache.tez.Utils;
+import org.apache.tez.common.ProgressHelper;
 import org.apache.tez.common.TezUtilsInternal;
 import org.apache.tez.common.counters.LimitExceededException;
 import org.apache.tez.dag.app.dag.event.DAGEventTerminateDag;
@@ -804,19 +805,30 @@
   public float getProgress() {
     this.readLock.lock();
     try {
-      float progress = 0.0f;
+      float accProg = 0.0f;
+      float dagProgress = 0.0f;
+      int verticesCount = getVertices().size();
       for (Vertex v : getVertices().values()) {
         float vertexProgress = v.getProgress();
-        if (vertexProgress >= 0.0f && vertexProgress <= 1.0f) {
-          progress += vertexProgress;
+        if (LOG.isDebugEnabled()) {
+          if (!ProgressHelper.isProgressWithinRange(vertexProgress)) {
+            LOG.debug("progress update: Vertex progress is invalid range"
+                + "; v={}, progress={}", v.getName(), vertexProgress);
+          }
+        }
+        accProg += ProgressHelper.processProgress(vertexProgress);
+      }
+      if (LOG.isDebugEnabled()) {
+        if (verticesCount == 0) {
+          LOG.debug("progress update: DAGImpl getProgress() returns 0.0f: "
+              + "vertices count is 0");
         }
       }
-      float dagProgress = progress / getTotalVertices();
-      if (dagProgress >= 0.0f && progress <= 1.0f) {
-        return dagProgress;
-      } else {
-        return 0.0f;
+      if (verticesCount > 0) {
+        dagProgress =
+            ProgressHelper.processProgress(accProg / verticesCount);
       }
+      return dagProgress;
     } finally {
       this.readLock.unlock();
     }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
index 9a59e88..52fe932 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
@@ -64,6 +64,7 @@
 import org.apache.hadoop.yarn.util.Clock;
 import org.apache.tez.client.TezClientUtils;
 import org.apache.tez.common.ATSConstants;
+import org.apache.tez.common.ProgressHelper;
 import org.apache.tez.common.ReflectionUtils;
 import org.apache.tez.common.TezUtilsInternal;
 import org.apache.tez.common.counters.AggregateTezCounters;
@@ -1572,20 +1573,31 @@
   List<EventInfo> getOnDemandRouteEvents() {
     return onDemandRouteEvents;
   }
-  
+
+  /**
+   * Updates the progress value in the vertex.
+   * This should be called only when the vertex is running state.
+   * No need to acquire the lock since this is nested inside
+   * {@link #getProgress() getProgress} method.
+   */
   private void computeProgress() {
-    this.readLock.lock();
-    try {
-      float progress = 0f;
-      for (Task task : this.tasks.values()) {
-        progress += (task.getProgress());
+
+    float accProg = 0.0f;
+    int tasksCount = this.tasks.size();
+    for (Task task : this.tasks.values()) {
+      float taskProg = task.getProgress();
+      if (LOG.isDebugEnabled()) {
+        if (!ProgressHelper.isProgressWithinRange(taskProg)) {
+          LOG.debug("progress update: vertex={}, task={} incorrect; range={}",
+              getName(), task.getTaskId().toString(), taskProg);
+        }
       }
-      if (this.numTasks != 0) {
-        progress /= this.numTasks;
-      }
-      this.progress = progress;
-    } finally {
-      this.readLock.unlock();
+      accProg += ProgressHelper.processProgress(taskProg);
+    }
+    // tasksCount is 0, do not reset the current progress.
+    if (tasksCount > 0) {
+      // force the progress to be below within the range
+      progress = ProgressHelper.processProgress(accProg / tasksCount);
     }
   }
 
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
index 0a0e9a2..23b057a 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
@@ -34,6 +34,7 @@
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 import org.apache.hadoop.service.AbstractService;
 import org.apache.hadoop.service.ServiceOperations;
+import org.apache.tez.common.ProgressHelper;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -529,6 +530,12 @@
     }
 
     public void setProgress(float progress) {
+      if (LOG.isDebugEnabled()) {
+        if (!ProgressHelper.isProgressWithinRange(progress)) {
+          LOG.debug("Progress update: speculator received progress in invalid "
+              + "range={}", progress);
+        }
+      }
       this.progress = progress;
     }
 
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java
index 7a7dfe2..92e43aa 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestDAGAppMaster.java
@@ -465,9 +465,15 @@
     when(mockVertex.getProgress()).thenReturn(-10f);
     Assert.assertEquals("Progress was negative and should be reported as 0",
         0, am.getProgress(), 0);
+    when(mockVertex.getProgress()).thenReturn(1.0000567f);
+    Assert.assertEquals(
+        "Progress was greater than 1 by a small float precision "
+            + "1.0000567 and should be reported as 1",
+        1.0f, am.getProgress(), 0.0f);
     when(mockVertex.getProgress()).thenReturn(10f);
-    Assert.assertEquals("Progress was greater than 1 and should be reported as 0",
-        0, am.getProgress(), 0);
+    Assert.assertEquals(
+        "Progress was greater than 1 and should be reported as 1",
+        1.0f, am.getProgress(), 0.0f);
   }
 
   @SuppressWarnings("deprecation")
diff --git a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/processor/MRTaskReporter.java b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/processor/MRTaskReporter.java
index 2fa75bf..e3fdc27 100644
--- a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/processor/MRTaskReporter.java
+++ b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/processor/MRTaskReporter.java
@@ -23,6 +23,7 @@
 import org.apache.hadoop.mapred.Counters;
 import org.apache.hadoop.mapred.InputSplit;
 import org.apache.hadoop.mapred.Reporter;
+import org.apache.tez.common.ProgressHelper;
 import org.apache.tez.common.counters.TezCounter;
 import org.apache.tez.mapreduce.hadoop.mapred.MRCounters;
 import org.apache.tez.mapreduce.hadoop.mapred.MRReporter;
@@ -62,6 +63,9 @@
   }
 
   public void setProgress(float progress) {
+    // Validate that the progress is within the valid range. This guarantees
+    // that reporter and processorContext gets the same value.
+    progress = ProgressHelper.processProgress(progress);
     reporter.setProgress(progress);
     if (isProcessorContext) {
       ((ProcessorContext)context).setProgress(progress);
diff --git a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezProcessorContextImpl.java b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezProcessorContextImpl.java
index beae693..54605c8 100644
--- a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezProcessorContextImpl.java
+++ b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezProcessorContextImpl.java
@@ -93,8 +93,8 @@
   }
 
   @Override
-  public void setProgress(float progress) {
-    if (Math.abs(progress - runtimeTask.getProgress()) >= 0.001f) {
+  public void setProgressInternally(float progress) {
+    if (Float.compare(progress, runtimeTask.getProgress()) != 0) {
       runtimeTask.setProgress(progress);
       notifyProgress();
     }