TEZ-3363: Delete intermediate data at the vertex level for Shuffle Handler (#60) (Syed Shameerur Rahman reviewed by Laszlo Bodor)

diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
index 16d1dfc..71ebfee 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
@@ -884,6 +884,22 @@
   public static final boolean TEZ_AM_DAG_CLEANUP_ON_COMPLETION_DEFAULT = false;
 
   /**
+   * Integer value. Instructs AM to delete vertex shuffle data if a vertex and all its
+   * child vertices at a certain depth are completed. Value less than or equal to 0 indicates the feature
+   * is disabled.
+   * Let's say we have a dag  Map1 - Reduce2 - Reduce3 - Reduce4.
+   * case:1 height = 1
+   * when Reduce 2 completes all the shuffle data of Map1 will be deleted and so on for other vertex.
+   * case: 2 height = 2
+   * when Reduce 3 completes all the shuffle data of Map1 will be deleted and so on for other vertex.
+   */
+  @ConfigurationScope(Scope.AM)
+  @ConfigurationProperty(type="integer")
+  public static final String TEZ_AM_VERTEX_CLEANUP_HEIGHT = TEZ_AM_PREFIX
+      + "vertex.cleanup.height";
+  public static final int TEZ_AM_VERTEX_CLEANUP_HEIGHT_DEFAULT = 0;
+
+  /**
    * Boolean value. Instructs AM to delete intermediate attempt data for failed task attempts.
    */
   @ConfigurationScope(Scope.AM)
@@ -893,8 +909,8 @@
   public static final boolean TEZ_AM_TASK_ATTEMPT_CLEANUP_ON_FAILURE_DEFAULT = false;
 
   /**
-   * Int value. Upper limit on the number of threads used to delete DAG directories and failed task attempts
-   * directories on nodes.
+   * Int value. Upper limit on the number of threads used to delete DAG directories,
+   * Vertex directories and failed task attempts directories on nodes.
    */
   @ConfigurationScope(Scope.AM)
   @ConfigurationProperty(type="integer")
diff --git a/tez-common/src/main/java/org/apache/tez/common/DagContainerLauncher.java b/tez-common/src/main/java/org/apache/tez/common/DagContainerLauncher.java
index 6bda0a8..c2337af 100644
--- a/tez-common/src/main/java/org/apache/tez/common/DagContainerLauncher.java
+++ b/tez-common/src/main/java/org/apache/tez/common/DagContainerLauncher.java
@@ -24,9 +24,12 @@
 import org.apache.tez.common.security.JobTokenSecretManager;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.serviceplugins.api.ContainerLauncher;
 import org.apache.tez.serviceplugins.api.ContainerLauncherContext;
 
+import java.util.Set;
+
 /**
  * Plugin to allow custom container launchers to be written to launch containers that want to
  * support cleanup of DAG level directories upon DAG completion in session mode. The directories are created by
@@ -43,6 +46,9 @@
 
   public abstract void dagComplete(TezDAGID dag, JobTokenSecretManager jobTokenSecretManager);
 
+  public abstract void vertexComplete(TezVertexID vertex, JobTokenSecretManager jobTokenSecretManager,
+                                      Set<NodeId> nodeIdList);
+
   public abstract void taskAttemptFailed(TezTaskAttemptID taskAttemptID,
                                          JobTokenSecretManager jobTokenSecretManager, NodeId nodeId);
 }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
index 972fadf..5828861 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/DAGAppMaster.java
@@ -2739,6 +2739,10 @@
     return sb.toString();
   }
 
+  public void vertexComplete(TezVertexID completedVertexID, Set<NodeId> nodesList) {
+    getContainerLauncherManager().vertexComplete(completedVertexID, jobTokenSecretManager, nodesList);
+  }
+
   public void taskAttemptFailed(TezTaskAttemptID attemptID, NodeId nodeId) {
     getContainerLauncherManager().taskAttemptFailed(attemptID, jobTokenSecretManager, nodeId);
   }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/event/VertexEventType.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/event/VertexEventType.java
index 15be94d..ed32529 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/event/VertexEventType.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/event/VertexEventType.java
@@ -34,6 +34,7 @@
   V_START,
   V_SOURCE_TASK_ATTEMPT_COMPLETED,
   V_SOURCE_VERTEX_STARTED,
+  V_DELETE_SHUFFLE_DATA,
   
   //Producer:Task
   V_TASK_COMPLETED,
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/event/VertexShuffleDataDeletion.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/event/VertexShuffleDataDeletion.java
new file mode 100644
index 0000000..8ea3a15
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/event/VertexShuffleDataDeletion.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.dag.event;
+
+import org.apache.tez.dag.app.dag.Vertex;
+
+
+public class VertexShuffleDataDeletion extends VertexEvent {
+  // child vertex
+  private Vertex sourceVertex;
+  // parent vertex
+  private Vertex targetVertex;
+
+  public VertexShuffleDataDeletion(Vertex sourceVertex, Vertex targetVertex) {
+    super(targetVertex.getVertexId(), VertexEventType.V_DELETE_SHUFFLE_DATA);
+    this.sourceVertex = sourceVertex;
+    this.targetVertex = targetVertex;
+  }
+
+  public Vertex getSourceVertex() {
+    return sourceVertex;
+  }
+
+  public Vertex getTargetVertex() {
+    return targetVertex;
+  }
+}
\ No newline at end of file
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
index c9337e4..aa28e02 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
@@ -51,6 +51,7 @@
 import org.apache.tez.common.counters.LimitExceededException;
 import org.apache.tez.dag.app.dag.event.DAGEventTerminateDag;
 import org.apache.tez.dag.app.dag.event.DiagnosableEvent;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
 import org.apache.tez.state.OnStateChangedCallback;
 import org.apache.tez.state.StateMachineTez;
 import org.slf4j.Logger;
@@ -1772,6 +1773,13 @@
 
     vertex.setInputVertices(inVertices);
     vertex.setOutputVertices(outVertices);
+    boolean cleanupShuffleDataAtVertexLevel = dag.dagConf.getInt(TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT,
+        TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT_DEFAULT) > 0 && ShuffleUtils.isTezShuffleHandler(dag.dagConf);
+    if (cleanupShuffleDataAtVertexLevel) {
+      int deletionHeight = dag.dagConf.getInt(TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT,
+              TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT_DEFAULT);
+      ((VertexImpl) vertex).initShuffleDeletionContext(deletionHeight);
+    }
   }
 
   /**
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
index 934dd4e..e55b10a 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
@@ -114,6 +114,8 @@
 import org.apache.tez.dag.app.dag.RootInputInitializerManager;
 import org.apache.tez.dag.app.dag.StateChangeNotifier;
 import org.apache.tez.dag.app.dag.Task;
+import org.apache.tez.dag.app.dag.TaskAttempt;
+import org.apache.hadoop.yarn.api.records.NodeId;
 import org.apache.tez.dag.app.dag.TaskAttemptStateInternal;
 import org.apache.tez.dag.app.dag.TaskTerminationCause;
 import org.apache.tez.dag.app.dag.Vertex;
@@ -130,6 +132,7 @@
 import org.apache.tez.dag.app.dag.event.TaskEventScheduleTask;
 import org.apache.tez.dag.app.dag.event.TaskEventTermination;
 import org.apache.tez.dag.app.dag.event.TaskEventType;
+import org.apache.tez.dag.app.dag.event.VertexShuffleDataDeletion;
 import org.apache.tez.dag.app.dag.event.VertexEvent;
 import org.apache.tez.dag.app.dag.event.VertexEventCommitCompleted;
 import org.apache.tez.dag.app.dag.event.VertexEventInputDataInformation;
@@ -187,6 +190,7 @@
 import org.apache.tez.runtime.api.impl.TaskSpec;
 import org.apache.tez.runtime.api.impl.TaskStatistics;
 import org.apache.tez.runtime.api.impl.TezEvent;
+import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
 import org.apache.tez.state.OnStateChangedCallback;
 import org.apache.tez.state.StateMachineTez;
 import org.apache.tez.util.StringInterner;
@@ -556,7 +560,8 @@
                   VertexEventType.V_ROUTE_EVENT,
                   VertexEventType.V_SOURCE_TASK_ATTEMPT_COMPLETED,
                   VertexEventType.V_TASK_ATTEMPT_COMPLETED,
-                  VertexEventType.V_TASK_RESCHEDULED))
+                  VertexEventType.V_TASK_RESCHEDULED,
+                  VertexEventType.V_DELETE_SHUFFLE_DATA))
 
           // Transitions from SUCCEEDED state
           .addTransition(
@@ -592,6 +597,9 @@
           .addTransition(VertexState.SUCCEEDED, VertexState.SUCCEEDED,
               VertexEventType.V_TASK_ATTEMPT_COMPLETED,
               new TaskAttemptCompletedEventTransition())
+          .addTransition(VertexState.SUCCEEDED, VertexState.SUCCEEDED,
+              VertexEventType.V_DELETE_SHUFFLE_DATA,
+              new VertexShuffleDeleteTransition())
 
 
           // Transitions from FAILED state
@@ -613,7 +621,8 @@
                   VertexEventType.V_ROOT_INPUT_INITIALIZED,
                   VertexEventType.V_SOURCE_TASK_ATTEMPT_COMPLETED,
                   VertexEventType.V_NULL_EDGE_INITIALIZED,
-                  VertexEventType.V_INPUT_DATA_INFORMATION))
+                  VertexEventType.V_INPUT_DATA_INFORMATION,
+                  VertexEventType.V_DELETE_SHUFFLE_DATA))
 
           // Transitions from KILLED state
           .addTransition(
@@ -635,7 +644,8 @@
                   VertexEventType.V_TASK_COMPLETED,
                   VertexEventType.V_ROOT_INPUT_INITIALIZED,
                   VertexEventType.V_NULL_EDGE_INITIALIZED,
-                  VertexEventType.V_INPUT_DATA_INFORMATION))
+                  VertexEventType.V_INPUT_DATA_INFORMATION,
+                  VertexEventType.V_DELETE_SHUFFLE_DATA))
 
           // No transitions from INTERNAL_ERROR state. Ignore all.
           .addTransition(
@@ -655,7 +665,8 @@
                   VertexEventType.V_INTERNAL_ERROR,
                   VertexEventType.V_ROOT_INPUT_INITIALIZED,
                   VertexEventType.V_NULL_EDGE_INITIALIZED,
-                  VertexEventType.V_INPUT_DATA_INFORMATION))
+                  VertexEventType.V_INPUT_DATA_INFORMATION,
+                  VertexEventType.V_DELETE_SHUFFLE_DATA))
           // create the topology tables
           .installTopology();
 
@@ -729,6 +740,9 @@
   @VisibleForTesting
   Map<Vertex, Edge> sourceVertices;
   private Map<Vertex, Edge> targetVertices;
+  private boolean cleanupShuffleDataAtVertexLevel;
+  @VisibleForTesting
+  VertexShuffleDataDeletionContext vShuffleDeletionContext;
   Set<Edge> uninitializedEdges = Sets.newHashSet();
   // using a linked hash map to conveniently map edge names to a contiguous index
   LinkedHashMap<String, Integer> ioIndices = Maps.newLinkedHashMap();
@@ -1151,7 +1165,9 @@
         .append(", ContainerLauncher=").append(containerLauncherIdentifier).append(":").append(containerLauncherName)
         .append(", TaskCommunicator=").append(taskCommunicatorIdentifier).append(":").append(taskCommName);
     LOG.info(sb.toString());
-
+    cleanupShuffleDataAtVertexLevel = vertexConf.getInt(TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT,
+            TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT_DEFAULT) > 0 &&
+            ShuffleUtils.isTezShuffleHandler(vertexConf);
     stateMachine = new StateMachineTez<VertexState, VertexEventType, VertexEvent, VertexImpl>(
         stateMachineFactory.make(this), this);
     augmentStateMachine();
@@ -2306,6 +2322,12 @@
       if((vertexSucceeded || vertexFailuresBelowThreshold) && vertex.terminationCause == null) {
         if(vertexSucceeded) {
           LOG.info("All tasks have succeeded, vertex:" + vertex.logIdentifier);
+          if (vertex.cleanupShuffleDataAtVertexLevel) {
+
+            for (Vertex v : vertex.vShuffleDeletionContext.getAncestors()) {
+              vertex.eventHandler.handle(new VertexShuffleDataDeletion(vertex, v));
+            }
+          }
         } else {
           LOG.info("All tasks in the vertex " + vertex.logIdentifier + " have completed and the percentage of failed tasks (failed/total) (" + vertex.failedTaskCount + "/" + vertex.numTasks + ") is less that the threshold of " + vertex.maxFailuresPercent);
           vertex.addDiagnostic("Vertex succeeded as percentage of failed tasks (failed/total) (" + vertex.failedTaskCount + "/" + vertex.numTasks + ") is less that the threshold of " + vertex.maxFailuresPercent);
@@ -3758,6 +3780,36 @@
     }
   }
 
+  private static class VertexShuffleDeleteTransition implements
+          SingleArcTransition<VertexImpl, VertexEvent> {
+
+    @Override
+    public void transition(VertexImpl vertex, VertexEvent event) {
+      int incompleteChildrenVertices = vertex.vShuffleDeletionContext.getIncompleteChildrenVertices();
+      incompleteChildrenVertices = incompleteChildrenVertices - 1;
+      vertex.vShuffleDeletionContext.setIncompleteChildrenVertices(incompleteChildrenVertices);
+      // check if all the child vertices are completed
+      if (incompleteChildrenVertices == 0) {
+        LOG.info("Vertex shuffle data deletion for vertex name: " +
+                vertex.getName() + " with vertex id: " + vertex.getVertexId());
+        // Get nodes of all the task attempts in vertex
+        Set<NodeId> nodes = Sets.newHashSet();
+        Map<TezTaskID, Task> tasksMap = vertex.getTasks();
+        tasksMap.keySet().forEach(taskId -> {
+          Map<TezTaskAttemptID, TaskAttempt> taskAttemptMap = tasksMap.get(taskId).getAttempts();
+          taskAttemptMap.keySet().forEach(attemptId -> {
+            nodes.add(taskAttemptMap.get(attemptId).getNodeId());
+          });
+        });
+        vertex.appContext.getAppMaster().vertexComplete(
+                vertex.vertexId, nodes);
+      } else {
+        LOG.debug("The number of incomplete child vertex are {} for the vertex {}",
+            incompleteChildrenVertices, vertex.vertexId);
+      }
+    }
+  }
+
   private static class TaskCompletedAfterVertexSuccessTransition implements
     MultipleArcTransition<VertexImpl, VertexEvent, VertexState> {
     @Override
@@ -4930,4 +4982,14 @@
   public Map<String, Set<String>> getDownstreamBlamingHosts(){
     return downstreamBlamingHosts;
   }
+
+  /**
+   * Initialize context from vertex shuffle deletion.
+   * @param deletionHeight
+   */
+  public void initShuffleDeletionContext(int deletionHeight) {
+    VertexShuffleDataDeletionContext vShuffleDeletionContext = new VertexShuffleDataDeletionContext(deletionHeight);
+    vShuffleDeletionContext.setSpannedVertices(this);
+    this.vShuffleDeletionContext = vShuffleDeletionContext;
+  }
 }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexShuffleDataDeletionContext.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexShuffleDataDeletionContext.java
new file mode 100644
index 0000000..4ffdf11
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexShuffleDataDeletionContext.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.dag.impl;
+
+import org.apache.tez.dag.app.dag.Vertex;
+
+import java.util.HashSet;
+import java.util.Set;
+
+public class VertexShuffleDataDeletionContext {
+  private int deletionHeight;
+  private int incompleteChildrenVertices;
+  private Set<Vertex> ancestors;
+  private Set<Vertex> children;
+
+  public VertexShuffleDataDeletionContext(int deletionHeight) {
+    this.deletionHeight = deletionHeight;
+    this.incompleteChildrenVertices = 0;
+    this.ancestors = new HashSet<>();
+    this.children = new HashSet<>();
+  }
+
+  public void setSpannedVertices(Vertex vertex) {
+    getSpannedVerticesAncestors(vertex, ancestors, deletionHeight);
+    getSpannedVerticesChildren(vertex, children, deletionHeight);
+    setIncompleteChildrenVertices(children.size());
+  }
+
+  /**
+   * get all the ancestor vertices at a particular depth.
+   */
+  private static void getSpannedVerticesAncestors(Vertex vertex, Set<Vertex> ancestorVertices, int level) {
+    if (level == 0) {
+      ancestorVertices.add(vertex);
+      return;
+    }
+
+    if (level == 1) {
+      ancestorVertices.addAll(vertex.getInputVertices().keySet());
+      return;
+    }
+
+    vertex.getInputVertices().forEach((inVertex, edge) -> getSpannedVerticesAncestors(inVertex, ancestorVertices,
+        level - 1));
+  }
+
+  /**
+   * get all the child vertices at a particular depth.
+   */
+  private static void getSpannedVerticesChildren(Vertex vertex, Set<Vertex> childVertices, int level) {
+    if (level == 0) {
+      childVertices.add(vertex);
+      return;
+    }
+
+    if (level == 1) {
+      childVertices.addAll(vertex.getOutputVertices().keySet());
+      return;
+    }
+
+    vertex.getOutputVertices().forEach((outVertex, edge) -> getSpannedVerticesChildren(outVertex, childVertices,
+        level - 1));
+  }
+
+  public void setIncompleteChildrenVertices(int incompleteChildrenVertices) {
+    this.incompleteChildrenVertices = incompleteChildrenVertices;
+  }
+
+  public int getIncompleteChildrenVertices() {
+    return this.incompleteChildrenVertices;
+  }
+
+  public Set<Vertex> getAncestors() {
+    return this.ancestors;
+  }
+
+  public Set<Vertex> getChildren() {
+    return this.children;
+  }
+}
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherManager.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherManager.java
index b0e0f0c..65360d6 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherManager.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherManager.java
@@ -16,6 +16,7 @@
 
 import java.net.UnknownHostException;
 import java.util.List;
+import java.util.Set;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.apache.tez.common.Preconditions;
@@ -37,6 +38,7 @@
 import org.apache.tez.dag.app.dag.event.DAGAppMasterEventUserServiceFatalError;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.serviceplugins.api.ContainerLaunchRequest;
 import org.apache.tez.serviceplugins.api.ContainerLauncher;
 import org.apache.tez.serviceplugins.api.ContainerLauncherContext;
@@ -202,6 +204,12 @@
     }
   }
 
+  public void vertexComplete(TezVertexID vertex, JobTokenSecretManager secretManager, Set<NodeId> nodeIdList) {
+    for (int i = 0; i < containerLaunchers.length; i++) {
+      containerLaunchers[i].vertexComplete(vertex, secretManager, nodeIdList);
+    }
+  }
+
   public void taskAttemptFailed(TezTaskAttemptID taskAttemptID, JobTokenSecretManager secretManager, NodeId nodeId) {
     for (int i = 0; i < containerLaunchers.length; i++) {
       containerLaunchers[i].taskAttemptFailed(taskAttemptID, secretManager, nodeId);
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherWrapper.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherWrapper.java
index 5d262bd..4703abe 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherWrapper.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/ContainerLauncherWrapper.java
@@ -14,11 +14,14 @@
 
 package org.apache.tez.dag.app.launcher;
 
+import java.util.Set;
+
 import org.apache.tez.common.DagContainerLauncher;
 import org.apache.hadoop.yarn.api.records.NodeId;
 import org.apache.tez.common.security.JobTokenSecretManager;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.serviceplugins.api.ContainerLaunchRequest;
 import org.apache.tez.serviceplugins.api.ContainerLauncher;
 import org.apache.tez.serviceplugins.api.ContainerStopRequest;
@@ -49,6 +52,12 @@
     }
   }
 
+  public void vertexComplete(TezVertexID vertex, JobTokenSecretManager jobTokenSecretManager, Set<NodeId> nodeIdList) {
+    if (real instanceof DagContainerLauncher) {
+      ((DagContainerLauncher) real).vertexComplete(vertex, jobTokenSecretManager, nodeIdList);
+    }
+  }
+
   public void taskAttemptFailed(TezTaskAttemptID taskAttemptID, JobTokenSecretManager jobTokenSecretManager,
                                 NodeId nodeId) {
     if (real instanceof DagContainerLauncher) {
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/DeletionTracker.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/DeletionTracker.java
index 87b7366..56760c8 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/DeletionTracker.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/DeletionTracker.java
@@ -18,11 +18,14 @@
 
 package org.apache.tez.dag.app.launcher;
 
+import java.util.Set;
+
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.yarn.api.records.NodeId;
 import org.apache.tez.common.security.JobTokenSecretManager;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
 
 public abstract class DeletionTracker {
 
@@ -36,6 +39,10 @@
     //do nothing
   }
 
+  public void vertexComplete(TezVertexID vertex, JobTokenSecretManager jobTokenSecretManager, Set<NodeId> nodeIdList) {
+    //do nothing
+  }
+
   public void taskAttemptFailed(TezTaskAttemptID taskAttemptID, JobTokenSecretManager jobTokenSecretManager,
                                 NodeId nodeId) {
     //do nothing
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/DeletionTrackerImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/DeletionTrackerImpl.java
index e4204bf..73eaa68 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/DeletionTrackerImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/DeletionTrackerImpl.java
@@ -21,6 +21,7 @@
 
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.RejectedExecutionException;
@@ -35,6 +36,7 @@
 import org.apache.hadoop.conf.Configuration;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.library.common.TezRuntimeUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -113,4 +115,25 @@
     }
     nodeIdShufflePortMap = null;
   }
+
+  @Override
+  public void vertexComplete(TezVertexID vertex, JobTokenSecretManager jobTokenSecretManager, Set<NodeId> nodeIdList) {
+    super.vertexComplete(vertex, jobTokenSecretManager, nodeIdList);
+    String vertexId = String.format("%02d", vertex.getId());
+    for (NodeId nodeId : nodeIdList) {
+      Integer shufflePort = null;
+      if (nodeIdShufflePortMap != null) {
+        shufflePort = nodeIdShufflePortMap.get(nodeId);
+      }
+      if (shufflePort != null) {
+        VertexDeleteRunnable vertexDeleteRunnable = new VertexDeleteRunnable(vertex, jobTokenSecretManager, nodeId,
+                shufflePort, vertexId, TezRuntimeUtils.getHttpConnectionParams(conf));
+        try {
+          dagCleanupService.submit(vertexDeleteRunnable);
+        } catch (RejectedExecutionException rejectedException) {
+          LOG.info("Ignoring deletion request for " + vertexDeleteRunnable);
+        }
+      }
+    }
+  }
 }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
index ebc8f95..47cc9f1 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/LocalContainerLauncher.java
@@ -24,6 +24,7 @@
 import java.nio.ByteBuffer;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.Callable;
 import java.util.concurrent.CancellationException;
@@ -50,6 +51,7 @@
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.common.security.JobTokenSecretManager;
 import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.library.common.TezRuntimeUtils;
 import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
 import org.apache.tez.serviceplugins.api.ContainerLaunchRequest;
@@ -96,6 +98,7 @@
   int shufflePort = TezRuntimeUtils.INVALID_PORT;
   private DeletionTracker deletionTracker;
   private boolean dagDelete;
+  private boolean vertexDelete;
   private boolean failedTaskAttemptDelete;
 
   private final ConcurrentHashMap<ContainerId, ListenableFuture<?>>
@@ -162,11 +165,14 @@
     dagDelete = ShuffleUtils.isTezShuffleHandler(conf) &&
         conf.getBoolean(TezConfiguration.TEZ_AM_DAG_CLEANUP_ON_COMPLETION,
         TezConfiguration.TEZ_AM_DAG_CLEANUP_ON_COMPLETION_DEFAULT);
+    vertexDelete = ShuffleUtils.isTezShuffleHandler(conf) &&
+            conf.getInt(TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT,
+                    TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT_DEFAULT) > 0;
     failedTaskAttemptDelete = ShuffleUtils.isTezShuffleHandler(conf) &&
         conf.getBoolean(TezConfiguration.TEZ_AM_TASK_ATTEMPT_CLEANUP_ON_FAILURE,
         TezConfiguration.TEZ_AM_TASK_ATTEMPT_CLEANUP_ON_FAILURE_DEFAULT);
 
-    if (dagDelete || failedTaskAttemptDelete) {
+    if (dagDelete || vertexDelete || failedTaskAttemptDelete) {
       String deletionTrackerClassName = conf.get(TezConfiguration.TEZ_AM_DELETION_TRACKER_CLASS,
           TezConfiguration.TEZ_AM_DELETION_TRACKER_CLASS_DEFAULT);
       deletionTracker = ReflectionUtils.createClazzInstance(
@@ -455,6 +461,13 @@
   }
 
   @Override
+  public void vertexComplete(TezVertexID dag, JobTokenSecretManager jobTokenSecretManager, Set<NodeId> nodeIdList) {
+    if (vertexDelete && deletionTracker != null) {
+      deletionTracker.vertexComplete(dag, jobTokenSecretManager, nodeIdList);
+    }
+  }
+
+  @Override
   public void taskAttemptFailed(TezTaskAttemptID taskAttemptID, JobTokenSecretManager jobTokenSecretManager,
                                 NodeId nodeId) {
     if (failedTaskAttemptDelete && deletionTracker != null) {
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/TezContainerLauncherImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/TezContainerLauncherImpl.java
index 88ed4f7..654224a 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/TezContainerLauncherImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/TezContainerLauncherImpl.java
@@ -21,8 +21,8 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.LinkedBlockingQueue;
@@ -43,6 +43,7 @@
 import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.library.common.TezRuntimeUtils;
 import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils;
 import org.apache.tez.serviceplugins.api.ContainerLaunchRequest;
@@ -93,6 +94,7 @@
   private AtomicBoolean serviceStopped = new AtomicBoolean(false);
   private DeletionTracker deletionTracker = null;
   private boolean dagDelete;
+  private boolean vertexDelete;
   private boolean failedTaskAttemptDelete;
 
   private Container getContainer(ContainerOp event) {
@@ -339,11 +341,14 @@
     dagDelete = ShuffleUtils.isTezShuffleHandler(conf) &&
         conf.getBoolean(TezConfiguration.TEZ_AM_DAG_CLEANUP_ON_COMPLETION,
             TezConfiguration.TEZ_AM_DAG_CLEANUP_ON_COMPLETION_DEFAULT);
+    vertexDelete = ShuffleUtils.isTezShuffleHandler(conf) &&
+            conf.getInt(TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT,
+                    TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT_DEFAULT) > 0;
     failedTaskAttemptDelete = ShuffleUtils.isTezShuffleHandler(conf) &&
         conf.getBoolean(TezConfiguration.TEZ_AM_TASK_ATTEMPT_CLEANUP_ON_FAILURE,
             TezConfiguration.TEZ_AM_TASK_ATTEMPT_CLEANUP_ON_FAILURE_DEFAULT);
 
-    if (dagDelete || failedTaskAttemptDelete) {
+    if (dagDelete || vertexDelete || failedTaskAttemptDelete) {
       String deletionTrackerClassName = conf.get(TezConfiguration.TEZ_AM_DELETION_TRACKER_CLASS,
           TezConfiguration.TEZ_AM_DELETION_TRACKER_CLASS_DEFAULT);
       deletionTracker = ReflectionUtils.createClazzInstance(
@@ -455,6 +460,13 @@
   }
 
   @Override
+  public void vertexComplete(TezVertexID vertex, JobTokenSecretManager jobTokenSecretManager, Set<NodeId> nodeIdList) {
+    if (vertexDelete && deletionTracker != null) {
+      deletionTracker.vertexComplete(vertex, jobTokenSecretManager, nodeIdList);
+    }
+  }
+
+  @Override
   public void taskAttemptFailed(TezTaskAttemptID taskAttemptID, JobTokenSecretManager jobTokenSecretManager,
                                 NodeId nodeId) {
     if (failedTaskAttemptDelete && deletionTracker != null) {
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/VertexDeleteRunnable.java b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/VertexDeleteRunnable.java
new file mode 100644
index 0000000..a8d2537
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/launcher/VertexDeleteRunnable.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.launcher;
+
+import org.apache.hadoop.yarn.api.records.NodeId;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.http.BaseHttpConnection;
+import org.apache.tez.http.HttpConnectionParams;
+import org.apache.tez.runtime.library.common.TezRuntimeUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.URL;
+
+public class VertexDeleteRunnable implements Runnable {
+  private static final Logger LOG = LoggerFactory.getLogger(VertexDeleteRunnable.class);
+  final private TezVertexID vertex;
+  final private JobTokenSecretManager jobTokenSecretManager;
+  final private NodeId nodeId;
+  final private int shufflePort;
+  final private String vertexId;
+  final private HttpConnectionParams httpConnectionParams;
+
+  VertexDeleteRunnable(TezVertexID vertex, JobTokenSecretManager jobTokenSecretManager,
+                              NodeId nodeId, int shufflePort, String vertexId,
+                              HttpConnectionParams httpConnectionParams) {
+    this.vertex = vertex;
+    this.jobTokenSecretManager = jobTokenSecretManager;
+    this.nodeId = nodeId;
+    this.shufflePort = shufflePort;
+    this.vertexId = vertexId;
+    this.httpConnectionParams = httpConnectionParams;
+  }
+
+  @Override
+  public void run() {
+    BaseHttpConnection httpConnection = null;
+    try {
+      URL baseURL = TezRuntimeUtils.constructBaseURIForShuffleHandlerVertexComplete(
+          nodeId.getHost(), shufflePort,
+          vertex.getDAGID().getApplicationId().toString(), vertex.getDAGID().getId(), vertexId, false);
+      httpConnection = TezRuntimeUtils.getHttpConnection(true, baseURL, httpConnectionParams,
+          "VertexDelete", jobTokenSecretManager);
+      httpConnection.connect();
+      httpConnection.getInputStream();
+    } catch (Exception e) {
+      LOG.warn("Could not setup HTTP Connection to the node %s " + nodeId.getHost() +
+          " for vertex shuffle delete. ", e);
+    } finally {
+      try {
+        if (httpConnection != null) {
+          httpConnection.cleanup(true);
+        }
+      } catch (IOException e) {
+        LOG.warn("Encountered IOException for " + nodeId.getHost() + " during close. ", e);
+      }
+    }
+  }
+
+  @Override
+  public String toString() {
+    return "VertexDeleteRunnable nodeId=" + nodeId + ", shufflePort=" + shufflePort + ", vertexId=" + vertexId;
+  }
+}
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
index 5cdcf49..c118110 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
@@ -20,6 +20,15 @@
 
 import java.nio.ByteBuffer;
 
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.dag.api.NamedEntityDescriptor;
+import org.apache.tez.dag.app.DAGAppMaster;
+import org.apache.tez.dag.app.dag.TaskAttempt;
+import org.apache.tez.dag.app.launcher.ContainerLauncherManager;
+import org.apache.tez.dag.app.launcher.TezContainerLauncherImpl;
+import org.apache.tez.dag.app.rm.container.AMContainer;
+import org.apache.tez.serviceplugins.api.ContainerLauncherDescriptor;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
@@ -2395,12 +2404,254 @@
         .build();
   }
 
-  private void setupVertices() {
+  /**
+   * The dag is of the following structure.
+   *    vertex1   vertex2
+   *         \  /
+   *         vertex 3
+   *         /   \
+   *     vertex4 vertex5
+   *         \   /
+   *         vertex6
+   * @return dagPlan
+   */
+
+  public DAGPlan createDAGPlanVertexShuffleDelete() {
+    LOG.info("Setting up dag plan");
+    DAGPlan dag = DAGPlan.newBuilder()
+        .setName("testverteximpl")
+        .setDagConf(DAGProtos.ConfigurationProto.newBuilder()
+            .addConfKeyValues(DAGProtos.PlanKeyValuePair.newBuilder()
+                .setKey(TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS)
+                .setValue(3 + "")))
+        .addVertex(
+            VertexPlan.newBuilder()
+                .setName("vertex1")
+                .setType(PlanVertexType.NORMAL)
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder()
+                        .addHost("host1")
+                        .addRack("rack1")
+                        .build()
+                )
+                .setTaskConfig(
+                    PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(1)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("x1.y1")
+                        .build()
+                )
+                .setVertexConf(DAGProtos.ConfigurationProto.newBuilder()
+                    .addConfKeyValues(DAGProtos.PlanKeyValuePair.newBuilder()
+                        .setKey(TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS)
+                        .setValue(2+"")))
+                .addOutEdgeId("e1")
+                .build()
+        )
+        .addVertex(
+            VertexPlan.newBuilder()
+                .setName("vertex2")
+                .setType(PlanVertexType.NORMAL)
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder()
+                        .addHost("host2")
+                        .addRack("rack2")
+                        .build()
+                )
+                .setTaskConfig(
+                    PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(2)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("x2.y2")
+                        .build()
+                )
+                .addOutEdgeId("e2")
+                .build()
+        )
+        .addVertex(
+            VertexPlan.newBuilder()
+                .setName("vertex3")
+                .setType(PlanVertexType.NORMAL)
+                .setProcessorDescriptor(TezEntityDescriptorProto.newBuilder().setClassName("x3.y3"))
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder()
+                        .addHost("host3")
+                        .addRack("rack3")
+                        .build()
+                )
+                .setTaskConfig(
+                    PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(2)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("foo")
+                        .setTaskModule("x3.y3")
+                        .build()
+                )
+                .addInEdgeId("e1")
+                .addInEdgeId("e2")
+                .addOutEdgeId("e3")
+                .addOutEdgeId("e4")
+                .build()
+        )
+        .addVertex(
+            VertexPlan.newBuilder()
+                .setName("vertex4")
+                .setType(PlanVertexType.NORMAL)
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder()
+                        .addHost("host4")
+                        .addRack("rack4")
+                        .build()
+                )
+                .setTaskConfig(
+                    PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(2)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("x4.y4")
+                        .build()
+                )
+                .addInEdgeId("e3")
+                .addOutEdgeId("e5")
+                .build()
+        )
+        .addVertex(
+            VertexPlan.newBuilder()
+                .setName("vertex5")
+                .setType(PlanVertexType.NORMAL)
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder()
+                        .addHost("host5")
+                        .addRack("rack5")
+                        .build()
+                )
+                .setTaskConfig(
+                    PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(2)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("x5.y5")
+                        .build()
+                )
+                .addInEdgeId("e4")
+                .addOutEdgeId("e6")
+                .build()
+        )
+        .addVertex(
+            VertexPlan.newBuilder()
+                .setName("vertex6")
+                .setType(PlanVertexType.NORMAL)
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder()
+                        .addHost("host6")
+                        .addRack("rack6")
+                        .build()
+                )
+                .setTaskConfig(
+                    PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(2)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("x6.y6")
+                        .build()
+                )
+                .addInEdgeId("e5")
+                .addInEdgeId("e6")
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("i3_v1"))
+                .setInputVertexName("vertex1")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("o1"))
+                .setOutputVertexName("vertex3")
+                .setDataMovementType(PlanEdgeDataMovementType.SCATTER_GATHER)
+                .setId("e1")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("i3_v2"))
+                .setInputVertexName("vertex2")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("o2"))
+                .setOutputVertexName("vertex3")
+                .setDataMovementType(PlanEdgeDataMovementType.SCATTER_GATHER)
+                .setId("e2")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("i4_v3"))
+                .setInputVertexName("vertex3")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("o3_v4"))
+                .setOutputVertexName("vertex4")
+                .setDataMovementType(PlanEdgeDataMovementType.SCATTER_GATHER)
+                .setId("e3")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("i5_v3"))
+                .setInputVertexName("vertex3")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("o3_v5"))
+                .setOutputVertexName("vertex5")
+                .setDataMovementType(PlanEdgeDataMovementType.SCATTER_GATHER)
+                .setId("e4")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("i6_v4"))
+                .setInputVertexName("vertex4")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("o4"))
+                .setOutputVertexName("vertex6")
+                .setDataMovementType(PlanEdgeDataMovementType.SCATTER_GATHER)
+                .setId("e5")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("i6_v5"))
+                .setInputVertexName("vertex5")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("o5"))
+                .setOutputVertexName("vertex6")
+                .setDataMovementType(PlanEdgeDataMovementType.SCATTER_GATHER)
+                .setId("e6")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .build();
+
+    return dag;
+  }
+
+  private void setupVertices(boolean cleanupShuffleDataAtVertexLevel) {
     int vCnt = dagPlan.getVertexCount();
     LOG.info("Setting up vertices from dag plan, verticesCnt=" + vCnt);
     vertices = new HashMap<String, VertexImpl>();
     vertexIdMap = new HashMap<TezVertexID, VertexImpl>();
     Configuration dagConf = new Configuration(false);
+    dagConf.setBoolean(TezConfiguration.TEZ_AM_DAG_CLEANUP_ON_COMPLETION, true);
+    conf.setInt(TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT, cleanupShuffleDataAtVertexLevel ? 1 : 0);
     dagConf.set("abc", "foobar");
     for (int i = 0; i < vCnt; ++i) {
       VertexPlan vPlan = dagPlan.getVertex(i);
@@ -2447,7 +2698,6 @@
 
       Map<Vertex, Edge> outVertices =
           new HashMap<Vertex, Edge>();
-
       for(String inEdgeId : vertexPlan.getInEdgeIdList()){
         EdgePlan edgePlan = edgePlans.get(inEdgeId);
         Vertex inVertex = this.vertices.get(edgePlan.getInputVertexName());
@@ -2472,8 +2722,14 @@
           + ", outputVerticesCnt=" + outVertices.size());
       vertex.setOutputVertices(outVertices);
     }
+
+    for (Map.Entry<String, VertexImpl> vertex : vertices.entrySet()) {
+      VertexImpl vertexImpl = vertex.getValue();
+      vertexImpl.initShuffleDeletionContext(2);
+    }
   }
 
+
   public void setupPreDagCreation() {
     LOG.info("____________ RESETTING CURRENT DAG ____________");
     conf = new Configuration();
@@ -2488,8 +2744,9 @@
   }
 
   @SuppressWarnings({ "unchecked", "rawtypes" })
-  public void setupPostDagCreation() throws TezException {
+  public void setupPostDagCreation(boolean cleanupShuffleDataAtVertexLevel) throws TezException {
     String dagName = "dag0";
+    taskCommunicatorManagerInterface = mock(TaskCommunicatorManagerInterface.class);
     // dispatcher may be created multiple times (setupPostDagCreation may be called multiples)
     if (dispatcher != null) {
       dispatcher.stop();
@@ -2499,6 +2756,40 @@
     when(appContext.getHadoopShim()).thenReturn(new DefaultHadoopShim());
     when(appContext.getContainerLauncherName(anyInt())).thenReturn(
         TezConstants.getTezYarnServicePluginName());
+    DAGAppMaster mockDagAppMaster = mock(DAGAppMaster.class);
+    when(appContext.getAppMaster()).thenReturn(mockDagAppMaster);
+    doCallRealMethod().when(mockDagAppMaster).vertexComplete(any(TezVertexID.class), any(Set.class));
+    List<NamedEntityDescriptor> containerDescriptors = new ArrayList<>();
+    ContainerLauncherDescriptor containerLaunchers =
+        ContainerLauncherDescriptor.create("ContainerLaunchers",
+            TezContainerLauncherImpl.class.getName());
+    conf.setBoolean(TezConfiguration.TEZ_AM_DAG_CLEANUP_ON_COMPLETION, true);
+    conf.set(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID, "tez_shuffle");
+    conf.setInt(TezConfiguration.TEZ_AM_VERTEX_CLEANUP_HEIGHT, 0);
+    try {
+      containerLaunchers.setUserPayload(UserPayload.create(
+          TezUtils.createByteStringFromConf(conf).asReadOnlyByteBuffer()));
+    } catch (IOException e) {
+      e.printStackTrace();
+    }
+    containerDescriptors.add(containerLaunchers);
+    ContainerLauncherManager mockContainerLauncherManager = spy(new ContainerLauncherManager(appContext,
+        taskCommunicatorManagerInterface, "test", containerDescriptors, false));
+    doCallRealMethod().when(mockContainerLauncherManager).vertexComplete(any(
+        TezVertexID.class), any(JobTokenSecretManager.class
+    ), any(Set.class));
+    when(appContext.getAppMaster().getContainerLauncherManager()).thenReturn(
+        mockContainerLauncherManager);
+    mockContainerLauncherManager.init(conf);
+    mockContainerLauncherManager.start();
+    AMContainerMap amContainerMap = mock(AMContainerMap.class);
+    AMContainer amContainer = mock(AMContainer.class);
+    Container mockContainer = mock(Container.class);
+    when(amContainer.getContainer()).thenReturn(mockContainer);
+    when(mockContainer.getNodeId()).thenReturn(mock(NodeId.class));
+    when(mockContainer.getNodeHttpAddress()).thenReturn("localhost:12345");
+    when(amContainerMap.get(any(ContainerId.class))).thenReturn(amContainer);
+    when(appContext.getAllContainers()).thenReturn(amContainerMap);
 
     thh = mock(TaskHeartbeatHandler.class);
     historyEventHandler = mock(HistoryEventHandler.class);
@@ -2557,7 +2848,7 @@
       updateTracker.stop();
     }
     updateTracker = new StateChangeNotifierForTest(appContext.getCurrentDAG());
-    setupVertices();
+    setupVertices(cleanupShuffleDataAtVertexLevel);
     when(dag.getVertex(any(TezVertexID.class))).thenAnswer(new Answer<Vertex>() {
       @Override
       public Vertex answer(InvocationOnMock invocation) throws Throwable {
@@ -2622,7 +2913,7 @@
     setupPreDagCreation();
     dagPlan = createTestDAGPlan();
     invalidDagPlan = createInvalidDAGPlan();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
   }
 
   @After
@@ -2750,7 +3041,7 @@
   public void testNonExistVertexManager() throws TezException {
     setupPreDagCreation();
     dagPlan = createDAGPlanWithNonExistVertexManager();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     VertexImpl v1 = vertices.get("vertex1");
     v1.handle(new VertexEvent(v1.getVertexId(), VertexEventType.V_INIT));
     Assert.assertEquals(VertexState.FAILED, v1.getState());
@@ -2763,7 +3054,7 @@
   public void testNonExistInputInitializer() throws TezException {
     setupPreDagCreation();
     dagPlan = createDAGPlanWithNonExistInputInitializer();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     VertexImpl v1 = vertices.get("vertex1");
     v1.handle(new VertexEvent(v1.getVertexId(), VertexEventType.V_INIT));
     Assert.assertEquals(VertexState.FAILED, v1.getState());
@@ -2776,7 +3067,7 @@
   public void testNonExistOutputCommitter() throws TezException {
     setupPreDagCreation();
     dagPlan = createDAGPlanWithNonExistOutputCommitter();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     VertexImpl v1 = vertices.get("vertex1");
     v1.handle(new VertexEvent(v1.getVertexId(), VertexEventType.V_INIT));
     Assert.assertEquals(VertexState.FAILED, v1.getState());
@@ -2815,7 +3106,7 @@
     setupPreDagCreation();
     // initialize() will make VM call planned() and started() will make VM call done()
     dagPlan = createDAGPlanWithVMException("TestVMStateUpdate", VMExceptionLocation.NoExceptionDoReconfigure);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     TestUpdateListener listener = new TestUpdateListener();
     updateTracker
@@ -3824,7 +4115,7 @@
     conf.setFloat(TezConfiguration.TEZ_VERTEX_FAILURES_MAXPERCENT, 50.0f);
     conf.setInt(TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS, 1);
     dagPlan = createTestDAGPlan();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     initAllVertices(VertexState.INITED);
 
     VertexImpl v4 = vertices.get("vertex4");
@@ -3879,7 +4170,7 @@
     conf.setFloat(TezConfiguration.TEZ_VERTEX_FAILURES_MAXPERCENT, 50.0f);
     conf.setInt(TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS, 1);
     dagPlan = createTestDAGPlan();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     initAllVertices(VertexState.INITED);
 
     VertexImpl v4 = vertices.get("vertex4");
@@ -3978,7 +4269,7 @@
   public void testTerminatingVertexForTaskComplete() throws Exception {
     setupPreDagCreation();
     dagPlan = createSamplerDAGPlan(false);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     VertexImpl vertex = spy(vertices.get("A"));
     initVertex(vertex);
     startVertex(vertex);
@@ -3996,7 +4287,7 @@
   public void testTerminatingVertexForVComplete() throws Exception {
     setupPreDagCreation();
     dagPlan = createSamplerDAGPlan(false);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     VertexImpl vertex = spy(vertices.get("A"));
     initVertex(vertex);
     startVertex(vertex);
@@ -4251,7 +4542,7 @@
   public void testVertexInitWithCustomVertexManager() throws Exception {
     setupPreDagCreation();
     dagPlan = createDAGWithCustomVertexManager();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     
     int numTasks = 3;
     VertexImpl v1 = vertices.get("v1");
@@ -4305,7 +4596,7 @@
   public void testVertexManagerHeuristic() throws TezException {
     setupPreDagCreation();
     dagPlan = createDAGPlanWithMixedEdges();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     initAllVertices(VertexState.INITED);
     Assert.assertEquals(ImmediateStartVertexManager.class, 
         vertices.get("vertex1").getVertexManager().getPlugin().getClass());
@@ -4330,7 +4621,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanForOneToOneSplit("TestInputInitializer", -1, true);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     
     int numTasks = 5;
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
@@ -4397,7 +4688,7 @@
     // create a diamond shaped dag with 1-1 edges. 
     setupPreDagCreation();
     dagPlan = createDAGPlanForOneToOneSplit(null, numTasks, false);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     VertexImpl v1 = vertices.get("vertex1");
     v1.vertexReconfigurationPlanned();
     initAllVertices(VertexState.INITED);
@@ -4436,7 +4727,7 @@
     // create a diamond shaped dag with 1-1 edges. 
     setupPreDagCreation();
     dagPlan = createDAGPlanForOneToOneSplit(null, numTasks, false);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     VertexImpl v1 = vertices.get("vertex1");
     v1.vertexReconfigurationPlanned();
     initAllVertices(VertexState.INITED);
@@ -4478,7 +4769,7 @@
     // create a diamond shaped dag with 1-1 edges. 
     setupPreDagCreation();
     dagPlan = createDAGPlanForOneToOneSplit(null, numTasks, false);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     VertexImpl v1 = vertices.get("vertex1");
     initAllVertices(VertexState.INITED);
     
@@ -4522,7 +4813,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithInputInitializer("TestInputInitializer");
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -4567,7 +4858,7 @@
     setupPreDagCreation();
     dagPlan =
         createDAGPlanWithInitializer0Tasks(RootInitializerSettingParallelismTo0.class.getName());
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImpl v1 = vertices.get("vertex1");
     VertexImpl v2 = vertices.get("vertex2");
@@ -4615,7 +4906,7 @@
     initializer.setNumVertexStateUpdateEvents(3);
     setupPreDagCreation();
     dagPlan = createDAGPlanWithRunningInitializer();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithRunningInputInitializer v1 =
         (VertexImplWithRunningInputInitializer) vertices.get("vertex1");
@@ -4650,7 +4941,7 @@
         (EventHandlingRootInputInitializer) customInitializer;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithRunningInitializer4();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithRunningInputInitializer v1 =
         (VertexImplWithRunningInputInitializer) vertices.get("vertex1");
@@ -4738,7 +5029,7 @@
     initializer.setNumExpectedEvents(4);
     setupPreDagCreation();
     dagPlan = createDAGPlanWithRunningInitializer4();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithRunningInputInitializer v1 =
         (VertexImplWithRunningInputInitializer) vertices.get("vertex1");
@@ -4861,7 +5152,7 @@
         (EventHandlingRootInputInitializer) customInitializer;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithRunningInitializer4();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithRunningInputInitializer v1 =
         (VertexImplWithRunningInputInitializer) vertices.get("vertex1");
@@ -4941,7 +5232,7 @@
         (EventHandlingRootInputInitializer) customInitializer;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithRunningInitializer3();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithRunningInputInitializer v1 =
         (VertexImplWithRunningInputInitializer) vertices.get("vertex1");
@@ -5027,7 +5318,7 @@
         (EventHandlingRootInputInitializer) customInitializer;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithRunningInitializer();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithRunningInputInitializer v1 =
         (VertexImplWithRunningInputInitializer) vertices.get("vertex1");
@@ -5104,7 +5395,7 @@
   public void testTaskSchedulingWithCustomEdges() throws TezException {
     setupPreDagCreation();
     dagPlan = createCustomDAGWithCustomEdges();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     /**
      *
@@ -5402,7 +5693,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithMultipleInitializers("TestInputInitializer");
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -5432,7 +5723,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithMultipleInitializers("TestInputInitializer");
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -5462,7 +5753,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithInputInitializer("TestInputInitializer");
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -5563,7 +5854,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithInputInitializer("TestInputInitializer");
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -5638,7 +5929,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithInputDistributor("TestInputInitializer");
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -5673,7 +5964,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithInputInitializer("TestInputInitializer");
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     int expectedNumTasks = RootInputSpecUpdaterVertexManager.NUM_TASKS;
     VertexImplWithControlledInitializerManager v3 = (VertexImplWithControlledInitializerManager) vertices
@@ -5703,7 +5994,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithInputInitializer("TestInputInitializer");
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     int expectedNumTasks = RootInputSpecUpdaterVertexManager.NUM_TASKS;
     VertexImplWithControlledInitializerManager v4 = (VertexImplWithControlledInitializerManager) vertices
@@ -6015,7 +6306,7 @@
   public void testVertexGroupInput() throws TezException {
     setupPreDagCreation();
     dagPlan = createVertexGroupDAGPlan();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImpl vA = vertices.get("A");
     VertexImpl vB = vertices.get("B");
@@ -6044,7 +6335,7 @@
     // been initialized
     setupPreDagCreation();
     dagPlan = createSamplerDAGPlan(true);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImpl vA = vertices.get("A");
     VertexImpl vB = vertices.get("B");
@@ -6093,7 +6384,7 @@
     // been initialized
     setupPreDagCreation();
     dagPlan = createSamplerDAGPlan(true);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
     
     VertexImpl vA = vertices.get("A");
     VertexImpl vB = vertices.get("B");
@@ -6167,7 +6458,7 @@
     // been initialized
     setupPreDagCreation();
     dagPlan = createSamplerDAGPlan(false);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImpl vA = vertices.get("A");
     VertexImpl vB = vertices.get("B");
@@ -6190,7 +6481,7 @@
     // been initialized
     setupPreDagCreation();
     dagPlan = createSamplerDAGPlan2();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImpl vA = vertices.get("A");
     VertexImpl vB = vertices.get("B");
@@ -6215,7 +6506,7 @@
   public void testTez2684() throws IOException, TezException {
     setupPreDagCreation();
     dagPlan = createSamplerDAGPlan2();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImpl vA = vertices.get("A");
     VertexImpl vB = vertices.get("B");
@@ -6255,7 +6546,7 @@
   public void testVertexGraceParallelism() throws IOException, TezException {
     setupPreDagCreation();
     dagPlan = createDAGPlanForGraceParallelism();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImpl vA = vertices.get("A");
     VertexImpl vB = vertices.get("B");
@@ -6323,7 +6614,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithCountingVM();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImpl v1 = vertices.get("vertex1");
     VertexImpl v2 = vertices.get("vertex2");
@@ -6380,7 +6671,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithVMException("TestInputInitializer", VMExceptionLocation.Initialize);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -6399,7 +6690,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithVMException("TestInputInitializer", VMExceptionLocation.OnRootVertexInitialized);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -6423,7 +6714,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithVMException("TestInputInitializer", VMExceptionLocation.OnVertexStarted);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -6450,7 +6741,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithVMException("TestInputInitializer", VMExceptionLocation.OnSourceTaskCompleted);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -6486,7 +6777,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithVMException("TestInputInitializer", VMExceptionLocation.OnVertexManagerEventReceived);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -6514,7 +6805,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithVMException("TestVMStateUpdate", VMExceptionLocation.OnVertexManagerVertexStateUpdated);
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 = (VertexImplWithControlledInitializerManager) vertices
         .get("vertex1");
@@ -6543,7 +6834,7 @@
         (EventHandlingRootInputInitializer) customInitializer;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithIIException();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithRunningInputInitializer v1 =
         (VertexImplWithRunningInputInitializer) vertices.get("vertex1");
@@ -6564,7 +6855,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithIIException();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 =
         (VertexImplWithControlledInitializerManager)vertices.get("vertex1");
@@ -6588,7 +6879,7 @@
     useCustomInitializer = true;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithIIException();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithControlledInitializerManager v1 =
         (VertexImplWithControlledInitializerManager)vertices.get("vertex1");
@@ -6616,7 +6907,7 @@
         (EventHandlingRootInputInitializer) customInitializer;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithRunningInitializer();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithRunningInputInitializer v1 =
         (VertexImplWithRunningInputInitializer) vertices.get("vertex1");
@@ -6666,7 +6957,7 @@
         (EventHandlingRootInputInitializer) customInitializer;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithRunningInitializer();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithRunningInputInitializer v1 =
         (VertexImplWithRunningInputInitializer) vertices.get("vertex1");
@@ -6695,7 +6986,7 @@
         (EventHandlingRootInputInitializer) customInitializer;
     setupPreDagCreation();
     dagPlan = createDAGPlanWithRunningInitializer();
-    setupPostDagCreation();
+    setupPostDagCreation(false);
 
     VertexImplWithRunningInputInitializer v1 =
         (VertexImplWithRunningInputInitializer) vertices.get("vertex1");
@@ -7212,4 +7503,132 @@
     Assert.assertTrue(localResourceMap.containsKey("dag lr"));
     Assert.assertTrue(localResourceMap.containsKey("vertex lr"));
   }
+
+  @Test
+  public void testVertexShuffleDelete() throws Exception {
+    setupPreDagCreation();
+    dagPlan = createDAGPlanVertexShuffleDelete();
+    setupPostDagCreation(true);
+    checkSpannedVertices();
+    runVertices();
+    Mockito.verify(appContext.getAppMaster().getContainerLauncherManager(),
+        times(3)).vertexComplete(any(TezVertexID.class),
+        any(JobTokenSecretManager.class), any(Set.class));
+  }
+
+  private void checkSpannedVertices() {
+    // vertex1 should have 0 ancestor and 2 children at height = 2
+    VertexImpl v1 = vertices.get("vertex1");
+    checkResults(v1.vShuffleDeletionContext.getAncestors(), new ArrayList<>());
+    checkResults(v1.vShuffleDeletionContext.getChildren(), Arrays.asList("vertex5", "vertex4"));
+
+    // vertex2 should have 0 ancestor and 2 children at height = 2
+    VertexImpl v2 = vertices.get("vertex2");
+    checkResults(v2.vShuffleDeletionContext.getAncestors(), new ArrayList<>());
+    checkResults(v2.vShuffleDeletionContext.getChildren(), Arrays.asList("vertex5", "vertex4"));
+
+    // vertex3 should have 0 ancestor and 1 children at height = 2
+    VertexImpl v3 = vertices.get("vertex3");
+    checkResults(v3.vShuffleDeletionContext.getAncestors(), new ArrayList<>());
+    checkResults(v3.vShuffleDeletionContext.getChildren(), Arrays.asList("vertex6"));
+
+    // vertex4 should have 2 ancestor and 0 children at height = 2
+    VertexImpl v4 = vertices.get("vertex4");
+    checkResults(v4.vShuffleDeletionContext.getAncestors(), Arrays.asList("vertex1", "vertex2"));
+    checkResults(v4.vShuffleDeletionContext.getChildren(), new ArrayList<>());
+
+    // vertex5 should have 2 ancestor and 0 children at height = 2
+    VertexImpl v5 = vertices.get("vertex5");
+    checkResults(v5.vShuffleDeletionContext.getAncestors(), Arrays.asList("vertex1", "vertex2"));
+    checkResults(v5.vShuffleDeletionContext.getChildren(), new ArrayList<>());
+
+    // vertex6 should have 1 ancestor and 0 children at height = 2
+    VertexImpl v6 = vertices.get("vertex6");
+    checkResults(v6.vShuffleDeletionContext.getAncestors(), Arrays.asList("vertex3"));
+    checkResults(v6.vShuffleDeletionContext.getChildren(), new ArrayList<>());
+  }
+
+  private void checkResults(Set<Vertex> actual, List<String> expected) {
+    assertEquals(actual.size(), expected.size());
+    for (Vertex vertex : actual) {
+      assertTrue(expected.contains(vertex.getName()));
+    }
+  }
+
+  private void runVertices() {
+    VertexImpl v1 = vertices.get("vertex1");
+    VertexImpl v2 = vertices.get("vertex2");
+    VertexImpl v3 = vertices.get("vertex3");
+    VertexImpl v4 = vertices.get("vertex4");
+    VertexImpl v5 = vertices.get("vertex5");
+    VertexImpl v6 = vertices.get("vertex6");
+    dispatcher.getEventHandler().handle(new VertexEvent(v1.getVertexId(), VertexEventType.V_INIT));
+    dispatcher.getEventHandler().handle(new VertexEvent(v2.getVertexId(), VertexEventType.V_INIT));
+    dispatcher.await();
+    dispatcher.getEventHandler().handle(new VertexEvent(v1.getVertexId(), VertexEventType.V_START));
+    dispatcher.getEventHandler().handle(new VertexEvent(v2.getVertexId(), VertexEventType.V_START));
+    dispatcher.await();
+
+    TezTaskID v1t1 = TezTaskID.getInstance(v1.getVertexId(), 0);
+    Map<TezTaskAttemptID, TaskAttempt> attempts = v1.getTask(v1t1).getAttempts();
+    startAttempts(attempts);
+    v1.handle(new VertexEventTaskCompleted(v1t1, TaskState.SUCCEEDED));
+    TezTaskID v2t1 = TezTaskID.getInstance(v2.getVertexId(), 0);
+    attempts = v2.getTask(v2t1).getAttempts();
+    startAttempts(attempts);
+    v2.handle(new VertexEventTaskCompleted(v2t1, TaskState.SUCCEEDED));
+    TezTaskID v2t2 = TezTaskID.getInstance(v2.getVertexId(), 1);
+    attempts = v2.getTask(v2t2).getAttempts();
+    startAttempts(attempts);
+    v2.handle(new VertexEventTaskCompleted(v2t2, TaskState.SUCCEEDED));
+    TezTaskID v3t1 = TezTaskID.getInstance(v3.getVertexId(), 0);
+    v3.scheduleTasks(Lists.newArrayList(ScheduleTaskRequest.create(0, null)));
+    dispatcher.await();
+    attempts = v3.getTask(v3t1).getAttempts();
+    startAttempts(attempts);
+    v3.handle(new VertexEventTaskCompleted(v3t1, TaskState.SUCCEEDED));
+    TezTaskID v3t2 = TezTaskID.getInstance(v3.getVertexId(), 1);
+    attempts = v3.getTask(v3t2).getAttempts();
+    startAttempts(attempts);
+    v3.handle(new VertexEventTaskCompleted(v3t2, TaskState.SUCCEEDED));
+    dispatcher.await();
+    TezTaskID v4t1 = TezTaskID.getInstance(v4.getVertexId(), 0);
+    attempts = v4.getTask(v4t1).getAttempts();
+    startAttempts(attempts);
+    v4.handle(new VertexEventTaskCompleted(v4t1, TaskState.SUCCEEDED));
+    TezTaskID v4t2 = TezTaskID.getInstance(v4.getVertexId(), 1);
+    attempts = v4.getTask(v4t2).getAttempts();
+    startAttempts(attempts);
+    v4.handle(new VertexEventTaskCompleted(v4t2, TaskState.SUCCEEDED));
+    TezTaskID v5t1 = TezTaskID.getInstance(v5.getVertexId(), 0);
+    attempts = v5.getTask(v5t1).getAttempts();
+    startAttempts(attempts);
+    v5.handle(new VertexEventTaskCompleted(v5t1, TaskState.SUCCEEDED));
+    TezTaskID v5t2 = TezTaskID.getInstance(v5.getVertexId(), 1);
+    attempts = v5.getTask(v5t2).getAttempts();
+    startAttempts(attempts);
+    v5.handle(new VertexEventTaskCompleted(v5t2, TaskState.SUCCEEDED));
+    TezTaskID v6t1 = TezTaskID.getInstance(v6.getVertexId(), 0);
+    attempts = v6.getTask(v6t1).getAttempts();
+    startAttempts(attempts);
+    v6.handle(new VertexEventTaskCompleted(v6t1, TaskState.SUCCEEDED));
+    TezTaskID v6t2 = TezTaskID.getInstance(v6.getVertexId(), 1);
+    attempts = v6.getTask(v6t2).getAttempts();
+    startAttempts(attempts);
+    v6.handle(new VertexEventTaskCompleted(v6t2, TaskState.SUCCEEDED));
+    dispatcher.await();
+  }
+
+  private void startAttempts(Map<TezTaskAttemptID, TaskAttempt> attempts) {
+    for (Map.Entry<TezTaskAttemptID, TaskAttempt> entry : attempts.entrySet()) {
+      TezTaskAttemptID id = entry.getKey();
+      TaskAttemptImpl taskAttempt = (TaskAttemptImpl)entry.getValue();
+      taskAttempt.handle(new TaskAttemptEventSchedule(id, 10, 10));
+      dispatcher.await();
+      ContainerId mockContainer = mock(ContainerId.class, RETURNS_DEEP_STUBS);
+      taskAttempt.handle(new TaskAttemptEventSubmitted(id, mockContainer));
+      taskAttempt.handle(new TaskAttemptEventStartedRemotely(id));
+      dispatcher.await();
+    }
+  }
 }
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherWrapper.java b/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherWrapper.java
index c4f4eff..cb7d62d 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherWrapper.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/launcher/TestContainerLauncherWrapper.java
@@ -24,7 +24,7 @@
   @Test(timeout = 5000)
   public void testDelegation() throws Exception {
     PluginWrapperTestHelpers.testDelegation(ContainerLauncherWrapper.class, ContainerLauncher.class,
-        Sets.newHashSet("getContainerLauncher", "dagComplete", "taskAttemptFailed"));
+        Sets.newHashSet("getContainerLauncher", "dagComplete", "vertexComplete", "taskAttemptFailed"));
   }
 
 }
diff --git a/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java b/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java
index 7e6fd75..0fa1c03 100644
--- a/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java
+++ b/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java
@@ -18,6 +18,7 @@
 
 package org.apache.tez.auxservices;
 
+import org.apache.hadoop.fs.RemoteIterator;
 import org.apache.hadoop.util.DiskChecker;
 import static org.fusesource.leveldbjni.JniDBFactory.asString;
 import static org.fusesource.leveldbjni.JniDBFactory.bytes;
@@ -1009,6 +1010,7 @@
       final Map<String, List<String>> q = new QueryStringDecoder(request.getUri()).parameters();
       final List<String> keepAliveList = q.get("keepAlive");
       final List<String> dagCompletedQ = q.get("dagAction");
+      final List<String> vertexCompletedQ = q.get("vertexAction");
       final List<String> taskAttemptFailedQ = q.get("taskAttemptAction");
       boolean keepAliveParam = false;
       if (keepAliveList != null && keepAliveList.size() == 1) {
@@ -1019,6 +1021,7 @@
       final Range reduceRange = splitReduces(q.get("reduce"));
       final List<String> jobQ = q.get("job");
       final List<String> dagIdQ = q.get("dag");
+      final List<String> vertexIdQ = q.get("vertex");
       if (LOG.isDebugEnabled()) {
         LOG.debug("RECV: " + request.getUri() +
             "\n  mapId: " + mapIds +
@@ -1031,6 +1034,9 @@
       if (deleteDagDirectories(ctx.channel(), dagCompletedQ, jobQ, dagIdQ))  {
         return;
       }
+      if (deleteVertexDirectories(ctx.channel(), vertexCompletedQ, jobQ, dagIdQ, vertexIdQ)) {
+        return;
+      }
       if (deleteTaskAttemptDirectories(ctx.channel(), taskAttemptFailedQ, jobQ, dagIdQ, mapIds)) {
         return;
       }
@@ -1155,6 +1161,25 @@
       return false;
     }
 
+    private boolean deleteVertexDirectories(Channel channel, List<String> vertexCompletedQ,
+                                            List<String> jobQ, List<String> dagIdQ,
+                                            List<String> vertexIdQ) {
+      if (jobQ == null || jobQ.isEmpty()) {
+        return false;
+      }
+      if (notEmptyAndContains(vertexCompletedQ, "delete") && !isNullOrEmpty(vertexIdQ)) {
+        try {
+          deleteTaskDirsOfVertex(jobQ.get(0), dagIdQ.get(0), vertexIdQ.get(0), userRsrc.get(jobQ.get(0)));
+        } catch (IOException e) {
+          LOG.warn("Encountered exception during vertex delete " + e);
+        }
+        channel.writeAndFlush(new DefaultHttpResponse(HTTP_1_1, OK))
+                .addListener(ChannelFutureListener.CLOSE);
+        return true;
+      }
+      return false;
+    }
+
     private boolean deleteTaskAttemptDirectories(Channel channel, List<String> taskAttemptFailedQ,
                                             List<String> jobQ, List<String> dagIdQ, List<String> taskAttemptIdQ) {
       if (jobQ == null || jobQ.isEmpty()) {
@@ -1256,6 +1281,29 @@
       return baseStr;
     }
 
+    /**
+     * Delete shuffle data in task directories belonging to a vertex.
+     */
+    private void deleteTaskDirsOfVertex(String jobId, String dagId, String vertexId, String user) throws IOException {
+      String baseStr = getBaseLocation(jobId, dagId, user);
+      FileContext lfc = FileContext.getLocalFSFileContext();
+      for(Path dagPath : lDirAlloc.getAllLocalPathsToRead(baseStr, conf)) {
+        RemoteIterator<FileStatus> status = lfc.listStatus(dagPath);
+        final JobID jobID = JobID.forName(jobId);
+        String taskDirPrefix = String.format("attempt%s_%s_%s_",
+            jobID.toString().replace("job", ""), dagId, vertexId);
+        while (status.hasNext()) {
+          FileStatus fileStatus = status.next();
+          Path attemptPath = fileStatus.getPath();
+          if (attemptPath.getName().startsWith(taskDirPrefix)) {
+            if(lfc.delete(attemptPath, true)) {
+              LOG.debug("deleted shuffle data in task directory: {}", attemptPath);
+            }
+          }
+        }
+      }
+    }
+
     private String getDagLocation(String jobId, String dagId, String user) {
       final JobID jobID = JobID.forName(jobId);
       final ApplicationId appID =
diff --git a/tez-plugins/tez-aux-services/src/test/java/org/apache/tez/auxservices/TestShuffleHandler.java b/tez-plugins/tez-aux-services/src/test/java/org/apache/tez/auxservices/TestShuffleHandler.java
index 45dd0ad..b91e0eb 100644
--- a/tez-plugins/tez-aux-services/src/test/java/org/apache/tez/auxservices/TestShuffleHandler.java
+++ b/tez-plugins/tez-aux-services/src/test/java/org/apache/tez/auxservices/TestShuffleHandler.java
@@ -24,6 +24,7 @@
 import static org.junit.Assert.assertTrue;
 import static io.netty.buffer.Unpooled.wrappedBuffer;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
 import static org.junit.Assume.assumeTrue;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -1312,6 +1313,83 @@
     }
   }
 
+  @Test
+  public void testVertexShuffleDelete() throws Exception {
+    final ArrayList<Throwable> failures = new ArrayList<Throwable>(1);
+    Configuration conf = new Configuration();
+    conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3);
+    conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+    conf.set(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION,
+            "simple");
+    UserGroupInformation.setConfiguration(conf);
+    File absLogDir = new File("target", TestShuffleHandler.class.
+            getSimpleName() + "LocDir").getAbsoluteFile();
+    conf.set(YarnConfiguration.NM_LOCAL_DIRS, absLogDir.getAbsolutePath());
+    ApplicationId appId = ApplicationId.newInstance(12345L, 1);
+    String appAttemptId = "attempt_12345_0001_1_00_000000_0_10003_0";
+    String user = "randomUser";
+    List<File> fileMap = new ArrayList<File>();
+    String vertexDirStr = StringUtils.join(Path.SEPARATOR, new String[] { absLogDir.getAbsolutePath(),
+        ShuffleHandler.USERCACHE, user, ShuffleHandler.APPCACHE, appId.toString(), "dag_1/output/" + appAttemptId});
+    File vertexDir = new File(vertexDirStr);
+    Assert.assertFalse("vertex directory should not be present", vertexDir.exists());
+    createShuffleHandlerFiles(absLogDir, user, appId.toString(), appAttemptId,
+            conf, fileMap);
+    ShuffleHandler shuffleHandler = new ShuffleHandler() {
+      @Override
+      protected Shuffle getShuffle(Configuration conf) {
+        // replace the shuffle handler with one stubbed for testing
+        return new Shuffle(conf) {
+          @Override
+          protected void sendError(ChannelHandlerContext ctx, String message,
+                                   HttpResponseStatus status) {
+            if (failures.size() == 0) {
+              failures.add(new Error(message));
+              ctx.channel().close();
+            }
+          }
+        };
+      }
+    };
+    shuffleHandler.init(conf);
+    try {
+      shuffleHandler.start();
+      DataOutputBuffer outputBuffer = new DataOutputBuffer();
+      outputBuffer.reset();
+      Token<JobTokenIdentifier> jt =
+              new Token<JobTokenIdentifier>("identifier".getBytes(),
+                      "password".getBytes(), new Text(user), new Text("shuffleService"));
+      jt.write(outputBuffer);
+      shuffleHandler
+              .initializeApplication(new ApplicationInitializationContext(user,
+                      appId, ByteBuffer.wrap(outputBuffer.getData(), 0,
+                      outputBuffer.getLength())));
+      URL url =
+              new URL(
+                      "http://127.0.0.1:"
+                              + shuffleHandler.getConfig().get(
+                              ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY)
+                              + "/mapOutput?vertexAction=delete&job=job_12345_0001&dag=1&vertex=00");
+      HttpURLConnection conn = (HttpURLConnection) url.openConnection();
+      conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+              ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+      conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+              ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+      Assert.assertTrue("Attempt Directory does not exist!", vertexDir.exists());
+      conn.connect();
+      try {
+        DataInputStream is = new DataInputStream(conn.getInputStream());
+        is.close();
+        Assert.assertFalse("Vertex Directory was not deleted", vertexDir.exists());
+      } catch (EOFException e) {
+        fail("Encountered Exception!" + e.getMessage());
+      }
+    } finally {
+      shuffleHandler.stop();
+      FileUtil.fullyDelete(absLogDir);
+    }
+  }
+
   @Test(timeout = 5000)
   public void testFailedTaskAttemptDelete() throws Exception {
     final ArrayList<Throwable> failures = new ArrayList<Throwable>(1);
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/TezRuntimeUtils.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/TezRuntimeUtils.java
index 48b23bc..a75925c 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/TezRuntimeUtils.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/TezRuntimeUtils.java
@@ -187,6 +187,25 @@
     return new URL(sb.toString());
   }
 
+  public static URL constructBaseURIForShuffleHandlerVertexComplete(
+       String host, int port, String appId, int dagIdentifier, String vertexIndentifier, boolean sslShuffle)
+       throws MalformedURLException {
+    String httpProtocol = (sslShuffle) ? "https://" : "http://";
+    StringBuilder sb = new StringBuilder(httpProtocol);
+    sb.append(host);
+    sb.append(":");
+    sb.append(port);
+    sb.append("/");
+    sb.append("mapOutput?vertexAction=delete");
+    sb.append("&job=");
+    sb.append(appId.replace("application", "job"));
+    sb.append("&dag=");
+    sb.append(String.valueOf(dagIdentifier));
+    sb.append("&vertex=");
+    sb.append(String.valueOf(vertexIndentifier));
+    return new URL(sb.toString());
+  }
+
   public static URL constructBaseURIForShuffleHandlerTaskAttemptFailed(
       String host, int port, String appId, int dagIdentifier, String taskAttemptIdentifier, boolean sslShuffle)
       throws MalformedURLException {