TEZ-4067. Tez Speculation decision is calculated on each update by the dispatcher

Signed-off-by: Jonathan Eagles <jeagles@apache.org>
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
index 7b00cf6..f087e3a 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
@@ -531,6 +531,14 @@
   public static final boolean TEZ_AM_SPECULATION_ENABLED_DEFAULT = false;
 
   /**
+   * Class used to estimate task runtimes when making speculation decisions.
+   */
+  @ConfigurationScope(Scope.VERTEX)
+  @ConfigurationProperty
+  public static final String TEZ_AM_SPECULATION_ESTIMATOR_CLASS =
+          TEZ_AM_PREFIX + "speculation.estimator.class";
+
+  /**
    * Float value. Specifies how many standard deviations away from the mean task execution time
    * should be considered as an outlier/slow task.
    */
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
index 6636fb6..f29d199 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
@@ -757,6 +757,8 @@
       String timeStamp = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(Calendar.getInstance().getTime());
       System.err.println(timeStamp + " Completed Dag: " + finishEvt.getDAGId().toString());
       System.out.println(timeStamp + " Completed Dag: " + finishEvt.getDAGId().toString());
+      // Stop vertex services if any
+      stopVertexServices(currentDAG);
       if (!isSession) {
         LOG.info("Not a session, AM will unregister as DAG has completed");
         this.taskSchedulerManager.setShouldUnregisterFlag();
@@ -1293,7 +1295,6 @@
       throw new SessionNotRunning("AM unable to accept new DAG submissions."
           + " In the process of shutting down");
     }
-
     // dag is in cleanup when dag state is completed but AM state is still RUNNING
     synchronized (idleStateLock) {
       while (currentDAG != null && currentDAG.isComplete() && state == DAGAppMasterState.RUNNING) {
@@ -1840,7 +1841,7 @@
     }
   }
 
-  void startServices(){
+  void startServices() {
     try {
       Throwable firstError = null;
       List<ServiceThread> threads = new ArrayList<ServiceThread>();
@@ -1888,12 +1889,16 @@
   }
 
   void stopServices() {
+    Exception firstException = null;
     // stop in reverse order of start
+    if (currentDAG != null) {
+      stopVertexServices(currentDAG);
+    }
     List<Service> serviceList = new ArrayList<Service>(services.size());
     for (ServiceWithDependency sd : services.values()) {
       serviceList.add(sd.service);
     }
-    Exception firstException = null;
+
     for (int i = services.size() - 1; i >= 0; i--) {
       Service service = serviceList.get(i);
       if (LOG.isDebugEnabled()) {
@@ -1933,7 +1938,6 @@
 
   @Override
   public synchronized void serviceStart() throws Exception {
-
     //start all the components
     startServices();
     super.serviceStart();
@@ -2060,6 +2064,9 @@
         DAGEventRecoverEvent recoverDAGEvent = new DAGEventRecoverEvent(
             recoveredDAGData.recoveredDAG.getID(), recoveredDAGData);
         dagEventDispatcher.handle(recoverDAGEvent);
+        // If we reach here, then we have a recoverable DAG and we need to
+        // reinitialize the vertex services, including speculators.
+        startVertexServices(currentDAG);
         this.state = DAGAppMasterState.RUNNING;
       }
     } else {
@@ -2543,6 +2550,18 @@
     this.state = DAGAppMasterState.RUNNING;
   }
 
+  private void startVertexServices(DAG dag) {
+    for (Vertex v : dag.getVertices().values()) {
+      v.startServices(); // start the per-vertex services of each vertex in the DAG
+    }
+  }
+
+  void stopVertexServices(DAG dag) {
+    for (Vertex v: dag.getVertices().values()) { // an exception from a vertex aborts the loop
+      v.stopServices(); // stop the per-vertex services of each vertex in the DAG
+    }
+  }
+
   private void startDAGExecution(DAG dag, final Map<String, LocalResource> additionalAmResources)
       throws TezException {
     currentDAG = dag;
@@ -2574,7 +2593,8 @@
     // This is a synchronous call, not an event through dispatcher. We want
     // job-init to be done completely here.
     dagEventDispatcher.handle(initDagEvent);
-
+    // Start the vertex services
+    startVertexServices(dag);
     // All components have started, start the job.
     /** create a job-start event to get this ball rolling */
     DAGEvent startDagEvent = new DAGEventStartDag(currentDAG.getID(), additionalUrlsForClasspath);
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
index 0b2406f..f3ef72b 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
@@ -26,6 +26,7 @@
 import javax.annotation.Nullable;
 
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.service.AbstractService;
 import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.tez.common.counters.TezCounters;
 import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
@@ -73,7 +74,6 @@
   LinkedHashMap<String, Integer> getIOIndices();
   String getName();
   VertexState getState();
-
   /**
    * Get all the counters of this vertex.
    * @return aggregate task-counters
@@ -169,7 +169,10 @@
       int fromEventId, int nextPreRoutedFromEventId, int maxEvents);
   
   void handleSpeculatorEvent(SpeculatorEvent event);
-
+  AbstractService getSpeculator();
+  void initServices();
+  void startServices();
+  void stopServices();
   ProcessorDescriptor getProcessorDescriptor();
   public DAG getDAG();
   @Nullable
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
index a2ef475..9a59e88 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
@@ -49,6 +49,9 @@
 import org.apache.hadoop.classification.InterfaceAudience.Private;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.service.AbstractService;
+import org.apache.hadoop.service.ServiceOperations;
+import org.apache.hadoop.service.ServiceStateException;
 import org.apache.hadoop.util.StringInterner;
 import org.apache.hadoop.yarn.api.records.LocalResource;
 import org.apache.hadoop.yarn.api.records.Resource;
@@ -306,8 +309,10 @@
 
   @VisibleForTesting
   final List<VertexManagerEvent> pendingVmEvents = new LinkedList<>();
-  
-  LegacySpeculator speculator;
+
+  private final AtomicBoolean servicesInited;
+  private LegacySpeculator speculator;
+  private List<AbstractService> services;
 
   @VisibleForTesting
   Map<String, ListenableFuture<Void>> commitFutures = new ConcurrentHashMap<String, ListenableFuture<Void>>();
@@ -869,6 +874,94 @@
     }
   }
 
+  /** Initializes this vertex's services (the speculator, if enabled); no-op if already inited. */
+  @Override
+  public void initServices() {
+    writeLock.lock();
+    try {
+      // Test the flag while holding the lock so two callers can never both build the services.
+      if (servicesInited.get()) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Skipping Initing services for vertex because already"
+              + " Initialized, name=" + this.vertexName);
+        }
+        return;
+      }
+      List<AbstractService> servicesToAdd = new ArrayList<>();
+      if (isSpeculationEnabled()) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Initing service vertex speculator, name=" + this.vertexName);
+        }
+        speculator = new LegacySpeculator(vertexConf, getAppContext(), this);
+        speculator.init(vertexConf);
+        servicesToAdd.add(speculator);
+      }
+      services = Collections.synchronizedList(servicesToAdd);
+      servicesInited.set(true);
+    } finally {
+      writeLock.unlock();
+    }
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Initing service vertex, name=" + this.vertexName);
+    }
+  }
+
+  /** Starts this vertex's services, initializing them first if that has not happened yet. */
+  @Override
+  public void startServices() {
+    writeLock.lock();
+    try {
+      if (!servicesInited.get()) {
+        initServices();
+      }
+      for (AbstractService service : services) {
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("starting service : " + service.getName() + ", for vertex: " + getName());
+        }
+        service.start();
+      }
+    } finally {
+      writeLock.unlock();
+    }
+  }
+
+  /** Stops and clears this vertex's services, then rethrows the first stop failure (if any). */
+  @Override
+  public void stopServices() {
+    Exception firstException = null;
+    List<AbstractService> stoppedServices = new ArrayList<>();
+    writeLock.lock();
+    try {
+      if (servicesInited.get()) {
+        for (AbstractService srvc : services) {
+          if (LOG.isDebugEnabled()) {
+            LOG.debug("Stopping service : " + srvc);
+          }
+          Exception ex = ServiceOperations.stopQuietly(srvc);
+          if (ex != null) {
+            LOG.warn(String.format("Failed to stop service=(%s) for vertex name=(%s)",
+                srvc.getName(), getName()), ex);
+            firstException = (firstException == null) ? ex : firstException;
+          } else {
+            stoppedServices.add(srvc);
+          }
+        }
+        services.clear();
+      }
+      servicesInited.set(false);
+    } finally {
+      writeLock.unlock();
+    }
+    // wait only for the services that stopped without raising an exception
+    for (AbstractService srvc : stoppedServices) {
+      srvc.waitForServiceToStop(60000L);
+    }
+    // After stopping all services, rethrow the first exception raised
+    if (firstException != null) {
+      throw ServiceStateException.convert(firstException);
+    }
+  }
+
   public VertexImpl(TezVertexID vertexId, VertexPlan vertexPlan,
       String vertexName, Configuration dagConf, EventHandler eventHandler,
       TaskCommunicatorManagerInterface taskCommunicatorManagerInterface, Clock clock,
@@ -972,11 +1065,11 @@
 
     this.dagVertexGroups = dagVertexGroups;
     
-    isSpeculationEnabled = vertexConf.getBoolean(TezConfiguration.TEZ_AM_SPECULATION_ENABLED,
-        TezConfiguration.TEZ_AM_SPECULATION_ENABLED_DEFAULT);
-    if (isSpeculationEnabled()) {
-      speculator = new LegacySpeculator(vertexConf, getAppContext(), this);
-    }
+    isSpeculationEnabled =
+        vertexConf.getBoolean(TezConfiguration.TEZ_AM_SPECULATION_ENABLED,
+            TezConfiguration.TEZ_AM_SPECULATION_ENABLED_DEFAULT);
+    servicesInited = new AtomicBoolean(false);
+    initServices();
 
     maxFailuresPercent = vertexConf.getFloat(TezConfiguration.TEZ_VERTEX_FAILURES_MAXPERCENT,
             TezConfiguration.TEZ_VERTEX_FAILURES_MAXPERCENT_DEFAULT);
@@ -2329,6 +2422,11 @@
         abortVertex(VertexStatus.State.valueOf(finalState.name()));
         eventHandler.handle(new DAGEvent(getDAGId(),
             DAGEventType.INTERNAL_ERROR));
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("stopping services attached to the succeeded Vertex,"
+              + "name=" + getName());
+        }
+        stopServices();
         try {
           logJobHistoryVertexFailedEvent(finalState);
         } catch (IOException e) {
@@ -2344,6 +2442,11 @@
         abortVertex(VertexStatus.State.valueOf(finalState.name()));
         eventHandler.handle(new DAGEventVertexCompleted(getVertexId(),
             finalState, terminationCause));
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("stopping services attached to the succeeded Vertex,"
+              + "name=" + getName());
+        }
+        stopServices();
         try {
           logJobHistoryVertexFailedEvent(finalState);
         } catch (IOException e) {
@@ -2356,6 +2459,12 @@
             logJobHistoryVertexFinishedEvent();
             eventHandler.handle(new DAGEventVertexCompleted(getVertexId(),
                 finalState));
+            // Stop related services
+            if (LOG.isDebugEnabled()) {
+              LOG.debug("stopping services attached to the succeeded Vertex,"
+                  + "name=" + getName());
+            }
+            stopServices();
           } catch (LimitExceededException e) {
             LOG.error("Counter limits exceeded for vertex: " + getLogIdentifier(), e);
             finalState = VertexState.FAILED;
@@ -2374,6 +2483,12 @@
         }
         break;
       default:
+        // Stop related services
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("stopping services attached with Unexpected State,"
+              + "name=" + getName());
+        }
+        stopServices();
         throw new TezUncheckedException("Unexpected VertexState: " + finalState);
     }
     return finalState;
@@ -2458,6 +2573,8 @@
     } else {
       initedTime = clock.getTime();
     }
+    // set the vertex services to be initialized.
+    initServices();
     // Only initialize committer when it is in non-recovery mode or vertex is not recovered to completed 
     // state in recovery mode
     if (recoveryData == null || recoveryData.getVertexFinishedEvent() == null) {
@@ -3316,6 +3433,12 @@
     if (finishTime == 0) {
       setFinishTime();
     }
+    // Stop related services
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("stopping services attached to the aborted Vertex, name="
+          + getName());
+    }
+    stopServices();
   }
 
   private void mayBeConstructFinalFullCounters() {
@@ -4763,6 +4886,6 @@
     }
   }
 
-  @VisibleForTesting
-  public LegacySpeculator getSpeculator() { return speculator; }
+  @Override
+  public AbstractService getSpeculator() { return speculator; }
 }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
index 3e7c2c0..0a0e9a2 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
@@ -18,13 +18,22 @@
 
 package org.apache.tez.dag.app.dag.speculation.legacy;
 
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 
 import com.google.common.annotations.VisibleForTesting;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+import org.apache.hadoop.service.AbstractService;
+import org.apache.hadoop.service.ServiceOperations;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -41,8 +50,6 @@
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezTaskID;
 
-import com.google.common.base.Preconditions;
-
 /**
  * Maintains runtime estimation statistics. Makes periodic updates
  * estimates based on progress and decides on when to trigger a 
@@ -54,7 +61,7 @@
  * because it may be likely a wasted attempt. There is a delay between
  * successive speculations.
  */
-public class LegacySpeculator {
+public class LegacySpeculator extends AbstractService {
   
   private static final long ON_SCHEDULE = Long.MIN_VALUE;
   private static final long ALREADY_SPECULATING = Long.MIN_VALUE + 1;
@@ -75,7 +82,7 @@
 
   private final ConcurrentMap<TezTaskID, Boolean> runningTasks
       = new ConcurrentHashMap<TezTaskID, Boolean>();
-
+  private ReadWriteLock lock = new ReentrantReadWriteLock();
   // Used to track any TaskAttempts that aren't heart-beating for a while, so
   // that we can aggressively speculate instead of waiting for task-timeout.
   private final ConcurrentMap<TezTaskAttemptID, TaskAttemptHistoryStatistics>
@@ -86,13 +93,17 @@
   // in progress.
   private static final long MAX_WAITTING_TIME_FOR_HEARTBEAT = 9 * 1000;
 
-  private final Set<TezTaskID> waitingToSpeculate = new HashSet<TezTaskID>();
+  private final Set<TezTaskID> mayHaveSpeculated = new HashSet<TezTaskID>();
 
   private Vertex vertex;
   private TaskRuntimeEstimator estimator;
   private final long taskTimeout;
   private final Clock clock;
   private long nextSpeculateTime = Long.MIN_VALUE;
+  private Thread speculationBackgroundThread = null;
+  private volatile boolean stopped = false;
+  /* Allow the speculator to wait on a blockingQueue in case we use it for event notification */
+  private BlockingQueue<Object> scanControl = new LinkedBlockingQueue<Object>();
 
   @VisibleForTesting
   public int getMinimumAllowedSpeculativeTasks() { return minimumAllowedSpeculativeTasks;}
@@ -119,17 +130,72 @@
   
   static private TaskRuntimeEstimator getEstimator
       (Configuration conf, Vertex vertex) {
-    TaskRuntimeEstimator estimator = new LegacyTaskRuntimeEstimator();
-    estimator.contextualize(conf, vertex);
-    
+    TaskRuntimeEstimator estimator;
+    // "tez.am.speculation.estimator.class"
+    Class<? extends TaskRuntimeEstimator> estimatorClass =
+        conf.getClass(TezConfiguration.TEZ_AM_SPECULATION_ESTIMATOR_CLASS,
+            LegacyTaskRuntimeEstimator.class,
+            TaskRuntimeEstimator.class);
+    try {
+      Constructor<? extends TaskRuntimeEstimator> estimatorConstructor
+          = estimatorClass.getConstructor();
+      estimator = estimatorConstructor.newInstance();
+      estimator.contextualize(conf, vertex);
+    } catch (NoSuchMethodException e) {
+      LOG.error("Can't make a speculation runtime estimator", e);
+      throw new RuntimeException(e);
+    } catch (IllegalAccessException e) {
+      LOG.error("Can't make a speculation runtime estimator", e);
+      throw new RuntimeException(e);
+    } catch (InstantiationException e) {
+      LOG.error("Can't make a speculation runtime estimator", e);
+      throw new RuntimeException(e);
+    } catch (InvocationTargetException e) {
+      LOG.error("Can't make a speculation runtime estimator", e);
+      throw new RuntimeException(e);
+    }
     return estimator;
   }
 
+  /** Starts the background speculation thread if none exists; exceptions are logged only. */
+  @Override
+  protected void serviceStart() throws Exception {
+    lock.writeLock().lock();
+    try {
+      assert (speculationBackgroundThread == null);
+
+      if (speculationBackgroundThread == null) {
+        Thread worker = new Thread(createThread(), "DefaultSpeculator background processing");
+        speculationBackgroundThread = worker;
+        worker.start();
+      }
+      super.serviceStart();
+    } catch (Exception e) {
+      LOG.warn("Speculator thread could not launch", e);
+    } finally {
+      lock.writeLock().unlock();
+    }
+  }
+
+  /**
+   * Reports whether the background thread exists and the service state is STARTED.
+   */
+  public boolean isStarted() {
+    lock.readLock().lock();
+    try {
+      return this.speculationBackgroundThread != null
+          && getServiceState().equals(STATE.STARTED);
+    } finally {
+      lock.readLock().unlock();
+    }
+  }
+
   // This constructor is designed to be called by other constructors.
   //  However, it's public because we do use it in the test cases.
   // Normally we figure out our own estimator.
   public LegacySpeculator
       (Configuration conf, TaskRuntimeEstimator estimator, Clock clock, Vertex vertex) {
+    super(LegacySpeculator.class.getName());
     this.vertex = vertex;
     this.estimator = estimator;
     this.clock = clock;
@@ -153,28 +219,46 @@
             TezConfiguration.TEZ_AM_MINIMUM_ALLOWED_SPECULATIVE_TASKS_DEFAULT);
   }
 
-/*   *************************************************************    */
-
-  void maybeSpeculate() {
-    long now = clock.getTime();
-    
-    if (now < nextSpeculateTime) {
-      return;
+  @Override
+  protected void serviceStop() throws Exception {
+    lock.writeLock().lock();
+    try {
+      stopped = true;
+      // this could be called before background thread is established
+      if (speculationBackgroundThread != null) {
+        speculationBackgroundThread.interrupt();
+      }
+      super.serviceStop();
+      speculationBackgroundThread = null;
+    } finally {
+      lock.writeLock().unlock();
     }
-    
-    int speculations = maybeScheduleASpeculation();
-    long mininumRecomp
-        = speculations > 0 ? soonestRetryAfterSpeculate
-                           : soonestRetryAfterNoSpeculate;
+  }
 
-    long wait = Math.max(mininumRecomp,
-          clock.getTime() - now);
-    nextSpeculateTime = now + wait;
-
-    if (speculations > 0) {
-      LOG.info("We launched " + speculations
-          + " speculations.  Waiting " + wait + " milliseconds.");
-    }
+  public Runnable createThread() {
+    return new Runnable() {
+      @Override
+      public void run() {
+        while (!stopped && !Thread.currentThread().isInterrupted()) {
+          long backgroundRunStartTime = clock.getTime();
+          try {
+            int speculations = computeSpeculations();
+            long nextRecompTime = speculations > 0 ? soonestRetryAfterSpeculate
+                : soonestRetryAfterNoSpeculate;
+            long wait = Math.max(nextRecompTime, clock.getTime() - backgroundRunStartTime);
+            if (speculations > 0) {
+              LOG.info("We launched " + speculations
+                  + " speculations.  Waiting " + wait + " milliseconds.");
+            }
+            Object pollResult = scanControl.poll(wait, TimeUnit.MILLISECONDS);
+          } catch (InterruptedException ie) {
+            if (!stopped) {
+              LOG.warn("Speculator thread interrupted", ie);
+            }
+          }
+        }
+      }
+    };
   }
 
 /*   *************************************************************    */
@@ -186,7 +270,6 @@
   public void notifyAttemptStatusUpdate(TezTaskAttemptID taId, TaskAttemptState reportedState,
       long timestamp) {
     statusUpdate(taId, reportedState, timestamp);
-    maybeSpeculate();
   }
 
   /**
@@ -197,12 +280,15 @@
    * @param timestamp the time this status corresponds to.  This matters
    *        because statuses contain progress.
    */
-  private void statusUpdate(TezTaskAttemptID attemptID, TaskAttemptState reportedState, long timestamp) {
+  private void statusUpdate(TezTaskAttemptID attemptID,
+      TaskAttemptState reportedState, long timestamp) {
 
     TezTaskID taskID = attemptID.getTaskID();
     Task task = vertex.getTask(taskID);
 
-    Preconditions.checkState(task != null, "Null task for attempt: " + attemptID);
+    if (task == null) {
+      return;
+    }
 
     estimator.updateAttempt(attemptID, reportedState, timestamp);
 
@@ -257,9 +343,19 @@
 
     // short circuit completed tasks. no need to spend time on them
     if (task.getState() == TaskState.SUCCEEDED) {
+      // remove the task from may have speculated if it exists
+      mayHaveSpeculated.remove(taskID);
       return NOT_RUNNING;
     }
 
+    if (!mayHaveSpeculated.contains(taskID) && !shouldUseTimeout) {
+      acceptableRuntime = estimator.thresholdRuntime(taskID);
+      if (acceptableRuntime == Long.MAX_VALUE) {
+        return ON_SCHEDULE;
+      }
+    }
+
+    TezTaskAttemptID runningTaskAttemptID = null;
     int numberRunningAttempts = 0;
 
     for (TaskAttempt taskAttempt : attempts.values()) {
@@ -267,36 +363,8 @@
       if (taskAttemptState == TaskAttemptState.RUNNING
           || taskAttemptState == TaskAttemptState.STARTING) {
         if (++numberRunningAttempts > 1) {
-          waitingToSpeculate.remove(taskID);
           return ALREADY_SPECULATING;
         }
-      }
-    }
-
-    // If we are here, there's at most one task attempt.
-    if (numberRunningAttempts == 0) {
-      return NOT_RUNNING;
-    }
-
-    if ((numberRunningAttempts == 1) && waitingToSpeculate.contains(taskID)) {
-      return ALREADY_SPECULATING;
-    }
-    else {
-      if (!shouldUseTimeout) {
-        acceptableRuntime = estimator.thresholdRuntime(taskID);
-        if (acceptableRuntime == Long.MAX_VALUE) {
-          return ON_SCHEDULE;
-        }
-      }
-    }
-
-    TezTaskAttemptID runningTaskAttemptID = null;
-
-    for (TaskAttempt taskAttempt : attempts.values()) {
-      TaskAttemptState taskAttemptState = taskAttempt.getState();
-      if (taskAttemptState == TaskAttemptState.RUNNING
-          || taskAttemptState == TaskAttemptState.STARTING) {
-
         runningTaskAttemptID = taskAttempt.getID();
 
         long taskAttemptStartTime
@@ -338,7 +406,8 @@
               if (data.notHeartbeatedInAWhile(now)) {
                 // Stats have stagnated for a while, simulate heart-beat.
                 // Now simulate the heart-beat
-                statusUpdate(taskAttempt.getID(), taskAttempt.getState(), clock.getTime());
+                statusUpdate(taskAttempt.getID(), taskAttempt.getState(),
+                    clock.getTime());
               }
             } else {
               // Stats have changed - update our data structure
@@ -361,6 +430,11 @@
       }
     }
 
+    // If we are here, there's at most one task attempt.
+    if (numberRunningAttempts == 0) {
+      return NOT_RUNNING;
+    }
+
     if ((acceptableRuntime == Long.MIN_VALUE) && !shouldUseTimeout) {
       acceptableRuntime = estimator.thresholdRuntime(taskID);
       if (acceptableRuntime == Long.MAX_VALUE) {
@@ -371,14 +445,14 @@
     return result;
   }
 
-  //Add attempt to a given Task.
+  // Add attempt to a given Task.
   protected void addSpeculativeAttempt(TezTaskID taskID) {
     LOG.info("DefaultSpeculator.addSpeculativeAttempt -- we are speculating " + taskID);
     vertex.scheduleSpeculativeTask(taskID);
-    waitingToSpeculate.add(taskID);
+    mayHaveSpeculated.add(taskID);
   }
 
-  private int maybeScheduleASpeculation() {
+  int computeSpeculations() {
     int successes = 0;
 
     long now = clock.getTime();
@@ -390,19 +464,18 @@
 
     int numberAllowedSpeculativeTasks
         = (int) Math.max(minimumAllowedSpeculativeTasks,
-                         proportionTotalTasksSpeculatable * tasks.size());
-
+        proportionTotalTasksSpeculatable * tasks.size());
     TezTaskID bestTaskID = null;
     long bestSpeculationValue = -1L;
     boolean shouldUseTimeout =
-            (tasks.size() <= VERTEX_SIZE_THRESHOLD_FOR_TIMEOUT_SPECULATION) &&
+        (tasks.size() <= VERTEX_SIZE_THRESHOLD_FOR_TIMEOUT_SPECULATION) &&
             (taskTimeout >= 0);
 
     // this loop is potentially pricey.
     // TODO track the tasks that are potentially worth looking at
     for (Map.Entry<TezTaskID, Task> taskEntry : tasks.entrySet()) {
       long mySpeculationValue = speculationValue(taskEntry.getValue(), now,
-              shouldUseTimeout);
+          shouldUseTimeout);
 
       if (mySpeculationValue == ALREADY_SPECULATING) {
         ++numberSpeculationsAlready;
@@ -419,7 +492,7 @@
     }
     numberAllowedSpeculativeTasks
         = (int) Math.max(numberAllowedSpeculativeTasks,
-                         proportionRunningTasksSpeculatable * numberRunningTasks);
+                        proportionRunningTasksSpeculatable * numberRunningTasks);
 
     // If we found a speculation target, fire it off
     if (bestTaskID != null
@@ -427,7 +500,6 @@
       addSpeculativeAttempt(bestTaskID);
       ++successes;
     }
-
     return successes;
   }
 
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java
index e1aa448..a81d4d3 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java
@@ -44,7 +44,6 @@
 import org.apache.tez.dag.app.dag.Task;
 import org.apache.tez.dag.app.dag.TaskAttempt;
 import org.apache.tez.dag.app.dag.impl.DAGImpl;
-import org.apache.tez.dag.app.dag.impl.VertexImpl;
 import org.apache.tez.dag.app.dag.speculation.legacy.LegacySpeculator;
 import org.apache.tez.dag.library.vertexmanager.ShuffleVertexManager;
 import org.apache.tez.dag.records.TaskAttemptTerminationCause;
@@ -78,7 +77,7 @@
       throw new RuntimeException("init failure", e);
     }
   }
-  
+
   MockTezClient createTezSession() throws Exception {
     TezConfiguration tezconf = new TezConfiguration(defaultConf);
     AtomicBoolean mockAppLauncherGoFlag = new AtomicBoolean(false);
@@ -109,11 +108,12 @@
     confToExpected.put(Long.MAX_VALUE >> 1, 1); // Really long time to speculate
     confToExpected.put(100L, 2);
     confToExpected.put(-1L, 1); // Don't speculate
-
+    defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_NO_SPECULATE, 50);
     for(Map.Entry<Long, Integer> entry : confToExpected.entrySet()) {
       defaultConf.setLong(
               TezConfiguration.TEZ_AM_LEGACY_SPECULATIVE_SINGLE_TASK_VERTEX_TIMEOUT,
               entry.getKey());
+
       DAG dag = DAG.create("test");
       Vertex vA = Vertex.create("A",
               ProcessorDescriptor.create("Proc.class"),
@@ -154,15 +154,14 @@
     defaultConf.setInt(TezConfiguration.TEZ_AM_MINIMUM_ALLOWED_SPECULATIVE_TASKS, 20);
     defaultConf.setDouble(TezConfiguration.TEZ_AM_PROPORTION_TOTAL_TASKS_SPECULATABLE, 0.2);
     defaultConf.setDouble(TezConfiguration.TEZ_AM_PROPORTION_RUNNING_TASKS_SPECULATABLE, 0.25);
-    defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_NO_SPECULATE, 2000);
-    defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_SPECULATE, 10000);
+    defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_NO_SPECULATE, 25);
+    defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_SPECULATE, 50);
 
     DAG dag = DAG.create("test");
     Vertex vA = Vertex.create("A", ProcessorDescriptor.create("Proc.class"), 5);
     dag.addVertex(vA);
 
     MockTezClient tezClient = createTezSession();
-    
     DAGClient dagClient = tezClient.submitDAG(dag);
     DAGImpl dagImpl = (DAGImpl) mockApp.getContext().getCurrentDAG();
     TezVertexID vertexId = TezVertexID.getInstance(dagImpl.getID(), 0);
@@ -195,12 +194,13 @@
           .getValue());
     }
 
-    LegacySpeculator speculator = ((VertexImpl) dagImpl.getVertex(vA.getName())).getSpeculator();
+    LegacySpeculator speculator =
+        (LegacySpeculator)(dagImpl.getVertex(vA.getName())).getSpeculator();
     Assert.assertEquals(20, speculator.getMinimumAllowedSpeculativeTasks());
     Assert.assertEquals(.2, speculator.getProportionTotalTasksSpeculatable(), 0);
     Assert.assertEquals(.25, speculator.getProportionRunningTasksSpeculatable(), 0);
-    Assert.assertEquals(2000, speculator.getSoonestRetryAfterNoSpeculate());
-    Assert.assertEquals(10000, speculator.getSoonestRetryAfterSpeculate());
+    Assert.assertEquals(25, speculator.getSoonestRetryAfterNoSpeculate());
+    Assert.assertEquals(50, speculator.getSoonestRetryAfterSpeculate());
 
     tezClient.stop();
   }
@@ -214,15 +214,18 @@
   public void testBasicSpeculationWithoutProgress() throws Exception {
     testBasicSpeculation(false);
   }
-  
-  @Test (timeout=10000)
+
+  @Test (timeout=100000)
   public void testBasicSpeculationPerVertexConf() throws Exception {
     DAG dag = DAG.create("test");
     String vNameNoSpec = "A";
     String vNameSpec = "B";
+    String speculatorSleepTime = "50";
     Vertex vA = Vertex.create(vNameNoSpec, ProcessorDescriptor.create("Proc.class"), 5);
     Vertex vB = Vertex.create(vNameSpec, ProcessorDescriptor.create("Proc.class"), 5);
     vA.setConf(TezConfiguration.TEZ_AM_SPECULATION_ENABLED, "false");
+    vB.setConf(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_NO_SPECULATE,
+        speculatorSleepTime);
     dag.addVertex(vA);
     dag.addVertex(vB);
     // min/max src fraction is set to 1. So vertices will run sequentially
@@ -233,14 +236,14 @@
                 InputDescriptor.create("I"))));
 
     MockTezClient tezClient = createTezSession();
-    
+
     DAGClient dagClient = tezClient.submitDAG(dag);
     DAGImpl dagImpl = (DAGImpl) mockApp.getContext().getCurrentDAG();
     TezVertexID vertexId = dagImpl.getVertex(vNameSpec).getVertexId();
     TezVertexID vertexIdNoSpec = dagImpl.getVertex(vNameNoSpec).getVertexId();
     // original attempt is killed and speculative one is successful
-    TezTaskAttemptID killedTaId = TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0),
-        0);
+    TezTaskAttemptID killedTaId =
+        TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 0);
     TezTaskAttemptID noSpecTaId = TezTaskAttemptID
         .getInstance(TezTaskID.getInstance(vertexIdNoSpec, 0), 0);
 
@@ -249,15 +252,23 @@
     mockLauncher.setStatusUpdatesForTask(noSpecTaId, 100);
 
     mockLauncher.startScheduling(true);
-    dagClient.waitForCompletion();
-    Assert.assertEquals(DAGStatus.State.SUCCEEDED, dagClient.getDAGStatus(null).getState());
     org.apache.tez.dag.app.dag.Vertex vSpec = dagImpl.getVertex(vertexId);
     org.apache.tez.dag.app.dag.Vertex vNoSpec = dagImpl.getVertex(vertexIdNoSpec);
+    // Poll until the background speculator has triggered at least one
+    // speculation on vB. If it never does, this loop does not exit and the
+    // test fails through the JUnit timeout.
+    do {
+      Thread.sleep(100);
+    } while (vSpec.getAllCounters().findCounter(TaskCounter.NUM_SPECULATIONS)
+        .getValue() <= 0);
+    dagClient.waitForCompletion();
     // speculation for vA but not for vB
-    Assert.assertTrue(vSpec.getAllCounters().findCounter(TaskCounter.NUM_SPECULATIONS)
-        .getValue() > 0);
-    Assert.assertEquals(0, vNoSpec.getAllCounters().findCounter(TaskCounter.NUM_SPECULATIONS)
-        .getValue());
+    Assert.assertTrue("Num Speculations is not higher than 0",
+        vSpec.getAllCounters().findCounter(TaskCounter.NUM_SPECULATIONS)
+            .getValue() > 0);
+    Assert.assertEquals(0,
+        vNoSpec.getAllCounters().findCounter(TaskCounter.NUM_SPECULATIONS)
+            .getValue());
 
     tezClient.stop();
   }