MAPREDUCE-2492. The new MapReduce API should make available task's progress to the task. (amarrk)

git-svn-id: https://svn.apache.org/repos/asf/hadoop/mapreduce/trunk@1126591 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/CHANGES.txt b/CHANGES.txt
index 0c757f6..c42ac83 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -14,6 +14,9 @@
 
   IMPROVEMENTS
 
+    MAPREDUCE-2492. The new MapReduce API should make available task's
+    progress to the task. (amarrk)
+
     MAPREDUCE-2153. Bring in more job configuration properties in to the trace 
     file. (Rajesh Balamohan via amarrk)
 
diff --git a/src/contrib/mrunit/src/java/org/apache/hadoop/mrunit/mapreduce/mock/MockReporter.java b/src/contrib/mrunit/src/java/org/apache/hadoop/mrunit/mapreduce/mock/MockReporter.java
index cc425fe..621344a 100644
--- a/src/contrib/mrunit/src/java/org/apache/hadoop/mrunit/mapreduce/mock/MockReporter.java
+++ b/src/contrib/mrunit/src/java/org/apache/hadoop/mrunit/mapreduce/mock/MockReporter.java
@@ -58,5 +58,10 @@
 
     return counter;
   }
+  
+  @Override
+  public float getProgress() {
+    return 0;
+  }
 }
 
diff --git a/src/contrib/mrunit/src/java/org/apache/hadoop/mrunit/mock/MockReporter.java b/src/contrib/mrunit/src/java/org/apache/hadoop/mrunit/mock/MockReporter.java
index 1fb9fa1..03bbf6c 100644
--- a/src/contrib/mrunit/src/java/org/apache/hadoop/mrunit/mock/MockReporter.java
+++ b/src/contrib/mrunit/src/java/org/apache/hadoop/mrunit/mock/MockReporter.java
@@ -92,5 +92,9 @@
 
     return counter;
   }
+  
+  public float getProgress() {
+    return 0;
+  }
 }
 
diff --git a/src/java/org/apache/hadoop/mapred/MapTask.java b/src/java/org/apache/hadoop/mapred/MapTask.java
index 15f9404..0310fc4 100644
--- a/src/java/org/apache/hadoop/mapred/MapTask.java
+++ b/src/java/org/apache/hadoop/mapred/MapTask.java
@@ -59,7 +59,6 @@
 import org.apache.hadoop.mapreduce.lib.map.WrappedMapper;
 import org.apache.hadoop.mapreduce.split.JobSplit.TaskSplitIndex;
 import org.apache.hadoop.mapreduce.task.MapContextImpl;
-import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.util.IndexedSortable;
 import org.apache.hadoop.util.IndexedSorter;
 import org.apache.hadoop.util.Progress;
@@ -294,8 +293,16 @@
     this.umbilical = umbilical;
 
     if (isMapTask()) {
-      mapPhase = getProgress().addPhase("map", 0.667f);
-      sortPhase  = getProgress().addPhase("sort", 0.333f);
+      // If there are no reducers then there won't be any sort. Hence the map 
+      // phase will govern the entire attempt's progress.
+      if (conf.getNumReduceTasks() == 0) {
+        mapPhase = getProgress().addPhase("map", 1.0f);
+      } else {
+        // If there are reducers then the entire attempt's progress will be 
+        // split between the map phase (67%) and the sort phase (33%).
+        mapPhase = getProgress().addPhase("map", 0.667f);
+        sortPhase  = getProgress().addPhase("sort", 0.333f);
+      }
     }
     TaskReporter reporter = startReporter(umbilical);
  
@@ -388,7 +395,10 @@
     try {
       runner.run(in, new OldOutputCollector(collector, conf), reporter);
       mapPhase.complete();
-      setPhase(TaskStatus.Phase.SORT);
+      // start the sort phase only if there are reducers
+      if (numReduceTasks > 0) {
+        setPhase(TaskStatus.Phase.SORT);
+      }
       statusUpdate(umbilical);
       collector.flush();
     } finally {
diff --git a/src/java/org/apache/hadoop/mapred/ReduceTask.java b/src/java/org/apache/hadoop/mapred/ReduceTask.java
index 6acf3c2..31cc587 100644
--- a/src/java/org/apache/hadoop/mapred/ReduceTask.java
+++ b/src/java/org/apache/hadoop/mapred/ReduceTask.java
@@ -361,6 +361,8 @@
                     taskStatus, copyPhase, sortPhase, this);
       rIter = shuffle.run();
     } else {
+      // local job runner doesn't have a copy phase
+      copyPhase.complete();
       final FileSystem rfs = FileSystem.getLocal(job).getRaw();
       rIter = Merger.merge(job, rfs, job.getMapOutputKeyClass(),
                            job.getMapOutputValueClass(), codec, 
diff --git a/src/java/org/apache/hadoop/mapred/Reporter.java b/src/java/org/apache/hadoop/mapred/Reporter.java
index ea8a18f..82ba71e 100644
--- a/src/java/org/apache/hadoop/mapred/Reporter.java
+++ b/src/java/org/apache/hadoop/mapred/Reporter.java
@@ -64,6 +64,10 @@
       public InputSplit getInputSplit() throws UnsupportedOperationException {
         throw new UnsupportedOperationException("NULL reporter has no input");
       }
+      @Override
+      public float getProgress() {
+        return 0;
+      }
     };
 
   /**
@@ -120,4 +124,10 @@
    */
   public abstract InputSplit getInputSplit() 
     throws UnsupportedOperationException;
+  
+  /**
+   * Get the progress of the task. Progress is represented as a number between
+   * 0 and 1 (inclusive).
+   */
+  public float getProgress();
 }
diff --git a/src/java/org/apache/hadoop/mapred/Task.java b/src/java/org/apache/hadoop/mapred/Task.java
index 66f0d03..230ad43 100644
--- a/src/java/org/apache/hadoop/mapred/Task.java
+++ b/src/java/org/apache/hadoop/mapred/Task.java
@@ -569,6 +569,11 @@
       // indicate that progress update needs to be sent
       setProgressFlag();
     }
+    
+    public float getProgress() {
+      return taskProgress.getProgress();
+    }
+    
     public void progress() {
       // indicate that progress update needs to be sent
       setProgressFlag();
diff --git a/src/java/org/apache/hadoop/mapred/TaskAttemptContextImpl.java b/src/java/org/apache/hadoop/mapred/TaskAttemptContextImpl.java
index dd35b4f..4e064b7 100644
--- a/src/java/org/apache/hadoop/mapred/TaskAttemptContextImpl.java
+++ b/src/java/org/apache/hadoop/mapred/TaskAttemptContextImpl.java
@@ -60,6 +60,11 @@
   public JobConf getJobConf() {
     return (JobConf) getConfiguration();
   }
+  
+  @Override
+  public float getProgress() {
+    return reporter.getProgress();
+  }
 
   @Override
   public Counter getCounter(Enum<?> counterName) {
diff --git a/src/java/org/apache/hadoop/mapreduce/StatusReporter.java b/src/java/org/apache/hadoop/mapreduce/StatusReporter.java
index 8f9c2b6..5ce721e 100644
--- a/src/java/org/apache/hadoop/mapreduce/StatusReporter.java
+++ b/src/java/org/apache/hadoop/mapreduce/StatusReporter.java
@@ -24,5 +24,11 @@
   public abstract Counter getCounter(Enum<?> name);
   public abstract Counter getCounter(String group, String name);
   public abstract void progress();
+  /**
+   * Get the current progress.
+   * @return a number between 0.0 and 1.0 (inclusive) indicating the attempt's 
+   * progress.
+   */
+  public abstract float getProgress();
   public abstract void setStatus(String status);
 }
diff --git a/src/java/org/apache/hadoop/mapreduce/TaskAttemptContext.java b/src/java/org/apache/hadoop/mapreduce/TaskAttemptContext.java
index 52335ca..7df6c36 100644
--- a/src/java/org/apache/hadoop/mapreduce/TaskAttemptContext.java
+++ b/src/java/org/apache/hadoop/mapreduce/TaskAttemptContext.java
@@ -44,6 +44,13 @@
    * @return the current status message
    */
   public String getStatus();
+  
+  /**
+   * The current progress of the task attempt.
+   * @return a number between 0.0 and 1.0 (inclusive) indicating the attempt's
+   * progress.
+   */
+  public abstract float getProgress();
 
   /**
    * Get the {@link Counter} for the given <code>counterName</code>.
diff --git a/src/java/org/apache/hadoop/mapreduce/lib/chain/ChainMapContextImpl.java b/src/java/org/apache/hadoop/mapreduce/lib/chain/ChainMapContextImpl.java
index 3efb0e2..598bb93 100644
--- a/src/java/org/apache/hadoop/mapreduce/lib/chain/ChainMapContextImpl.java
+++ b/src/java/org/apache/hadoop/mapreduce/lib/chain/ChainMapContextImpl.java
@@ -316,4 +316,8 @@
     return base.getCredentials();
   }
 
+  @Override
+  public float getProgress() {
+    return base.getProgress();
+  }
 }
diff --git a/src/java/org/apache/hadoop/mapreduce/lib/chain/ChainReduceContextImpl.java b/src/java/org/apache/hadoop/mapreduce/lib/chain/ChainReduceContextImpl.java
index c56096b..8d66484 100644
--- a/src/java/org/apache/hadoop/mapreduce/lib/chain/ChainReduceContextImpl.java
+++ b/src/java/org/apache/hadoop/mapreduce/lib/chain/ChainReduceContextImpl.java
@@ -308,4 +308,9 @@
   public Credentials getCredentials() {
     return base.getCredentials();
   }
+  
+  @Override
+  public float getProgress() {
+    return base.getProgress();
+  }
 }
diff --git a/src/java/org/apache/hadoop/mapreduce/lib/join/Parser.java b/src/java/org/apache/hadoop/mapreduce/lib/join/Parser.java
index 30dca47..275272b 100644
--- a/src/java/org/apache/hadoop/mapreduce/lib/join/Parser.java
+++ b/src/java/org/apache/hadoop/mapreduce/lib/join/Parser.java
@@ -386,6 +386,11 @@
     }
 
     @Override
+    public float getProgress() {
+      return context.getProgress();
+    }
+    
+    @Override
     public void setStatus(String status) {
       context.setStatus(status);
     }
diff --git a/src/java/org/apache/hadoop/mapreduce/lib/map/MultithreadedMapper.java b/src/java/org/apache/hadoop/mapreduce/lib/map/MultithreadedMapper.java
index e4b0fca..814e494 100644
--- a/src/java/org/apache/hadoop/mapreduce/lib/map/MultithreadedMapper.java
+++ b/src/java/org/apache/hadoop/mapreduce/lib/map/MultithreadedMapper.java
@@ -240,6 +240,10 @@
       outer.setStatus(status);
     }
     
+    @Override
+    public float getProgress() {
+      return outer.getProgress();
+    }
   }
 
   private class MapRunner extends Thread {
diff --git a/src/java/org/apache/hadoop/mapreduce/lib/map/WrappedMapper.java b/src/java/org/apache/hadoop/mapreduce/lib/map/WrappedMapper.java
index 5be33be..10761c1 100644
--- a/src/java/org/apache/hadoop/mapreduce/lib/map/WrappedMapper.java
+++ b/src/java/org/apache/hadoop/mapreduce/lib/map/WrappedMapper.java
@@ -317,5 +317,10 @@
     public Credentials getCredentials() {
       return mapContext.getCredentials();
     }
+    
+    @Override
+    public float getProgress() {
+      return mapContext.getProgress();
+    }
   }
 }
diff --git a/src/java/org/apache/hadoop/mapreduce/lib/output/MultipleOutputs.java b/src/java/org/apache/hadoop/mapreduce/lib/output/MultipleOutputs.java
index f0cb9bc..31dd281 100644
--- a/src/java/org/apache/hadoop/mapreduce/lib/output/MultipleOutputs.java
+++ b/src/java/org/apache/hadoop/mapreduce/lib/output/MultipleOutputs.java
@@ -471,6 +471,11 @@
     }
 
     @Override
+    public float getProgress() {
+      return context.getProgress();
+    }
+    
+    @Override
     public void setStatus(String status) {
       context.setStatus(status);
     }
diff --git a/src/java/org/apache/hadoop/mapreduce/lib/reduce/WrappedReducer.java b/src/java/org/apache/hadoop/mapreduce/lib/reduce/WrappedReducer.java
index f8ce5a9..5be02cb 100644
--- a/src/java/org/apache/hadoop/mapreduce/lib/reduce/WrappedReducer.java
+++ b/src/java/org/apache/hadoop/mapreduce/lib/reduce/WrappedReducer.java
@@ -321,5 +321,10 @@
     public Credentials getCredentials() {
       return reduceContext.getCredentials();
     }
+    
+    @Override
+    public float getProgress() {
+      return reduceContext.getProgress();
+    }
   }
 }
diff --git a/src/java/org/apache/hadoop/mapreduce/task/TaskAttemptContextImpl.java b/src/java/org/apache/hadoop/mapreduce/task/TaskAttemptContextImpl.java
index 16746c8..9b039b0 100644
--- a/src/java/org/apache/hadoop/mapreduce/task/TaskAttemptContextImpl.java
+++ b/src/java/org/apache/hadoop/mapreduce/task/TaskAttemptContextImpl.java
@@ -107,5 +107,13 @@
     public Counter getCounter(String group, String name) {
       return new Counters().findCounter(group, name);
     }
+    public float getProgress() {
+      return 0f;
+    }
+  }
+  
+  @Override
+  public float getProgress() {
+    return reporter.getProgress();
   }
 }
\ No newline at end of file
diff --git a/src/test/mapred/org/apache/hadoop/mapred/TestMapProgress.java b/src/test/mapred/org/apache/hadoop/mapred/TestMapProgress.java
index 90d1911..372a5fc 100644
--- a/src/test/mapred/org/apache/hadoop/mapred/TestMapProgress.java
+++ b/src/test/mapred/org/apache/hadoop/mapred/TestMapProgress.java
@@ -94,7 +94,7 @@
       }
       // validate map task progress when the map task is in map phase
       assertTrue("Map progress is not the expected value.",
-                 Math.abs(mapTaskProgress - ((0.667/3)*recordNum)) < 0.001);
+                 Math.abs(mapTaskProgress - ((float)recordNum/3)) < 0.001);
     }
   }
 
diff --git a/src/test/mapred/org/apache/hadoop/mapred/TestReporter.java b/src/test/mapred/org/apache/hadoop/mapred/TestReporter.java
new file mode 100644
index 0000000..43b1a1d
--- /dev/null
+++ b/src/test/mapred/org/apache/hadoop/mapred/TestReporter.java
@@ -0,0 +1,189 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.mapred;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+/**
+ * Tests the old mapred APIs with {@link Reporter#getProgress()}.
+ */
+public class TestReporter {
+  private static final Path rootTempDir =
+    new Path(System.getProperty("test.build.data", "/tmp"));
+  private static final Path testRootTempDir = 
+    new Path(rootTempDir, "TestReporter");
+  
+  private static FileSystem fs = null;
+
+  @BeforeClass
+  public static void setup() throws Exception {
+    fs = FileSystem.getLocal(new Configuration());
+    fs.delete(testRootTempDir, true);
+    fs.mkdirs(testRootTempDir);
+  }
+
+  @AfterClass
+  public static void cleanup() throws Exception {
+    fs.delete(testRootTempDir, true);
+  }
+  
+  // an input with 4 lines
+  private static final String INPUT = "Hi\nHi\nHi\nHi\n";
+  private static final int INPUT_LINES = INPUT.split("\n").length;
+  
+  @SuppressWarnings("deprecation")
+  static class ProgressTesterMapper extends MapReduceBase 
+  implements Mapper<LongWritable, Text, Text, Text> {
+    private float progressRange = 0;
+    private int numRecords = 0;
+    private Reporter reporter = null;
+    
+    @Override
+    public void configure(JobConf job) {
+      super.configure(job);
+      // set the progress range accordingly
+      if (job.getNumReduceTasks() == 0) {
+        progressRange = 1f;
+      } else {
+        progressRange = 0.667f;
+      }
+    }
+    
+    @Override
+    public void map(LongWritable key, Text value, 
+                    OutputCollector<Text, Text> output, Reporter reporter) 
+    throws IOException {
+      this.reporter = reporter;
+      
+      // calculate the actual map progress
+      float mapProgress = ((float)++numRecords)/INPUT_LINES;
+      // calculate the attempt progress based on the progress range
+      float attemptProgress = progressRange * mapProgress;
+      assertEquals("Invalid progress in map", 
+                   attemptProgress, reporter.getProgress(), 0f);
+      output.collect(new Text(value.toString() + numRecords), value);
+    }
+    
+    @Override
+    public void close() throws IOException {
+      super.close();
+      assertEquals("Invalid progress in map cleanup", 
+                   progressRange, reporter.getProgress(), 0f);
+    }
+  }
+  
+  /**
+   * Test {@link Reporter}'s progress for a map-only job.
+   * This will make sure that only the map phase decides the attempt's progress.
+   */
+  @SuppressWarnings("deprecation")
+  @Test
+  public void testReporterProgressForMapOnlyJob() throws IOException {
+    Path test = new Path(testRootTempDir, "testReporterProgressForMapOnlyJob");
+    
+    JobConf conf = new JobConf();
+    conf.setMapperClass(ProgressTesterMapper.class);
+    conf.setMapOutputKeyClass(Text.class);
+    // fail early
+    conf.setMaxMapAttempts(1);
+    conf.setMaxReduceAttempts(0);
+    
+    RunningJob job = 
+      UtilsForTests.runJob(conf, new Path(test, "in"), new Path(test, "out"), 
+                           1, 0, INPUT);
+    job.waitForCompletion();
+    
+    assertTrue("Job failed", job.isSuccessful());
+  }
+  
+  /**
+   * A {@link Reducer} implementation that checks the progress on every call
+   * to {@link Reducer#reduce(Object, Iterator, OutputCollector, Reporter)}.
+   */
+  @SuppressWarnings("deprecation")
+  static class ProgressTestingReducer extends MapReduceBase 
+  implements Reducer<Text, Text, Text, Text> {
+    private int recordCount = 0;
+    private Reporter reporter = null;
+    // reduce task has a fixed split of progress amongst the copy (shuffle),
+    // sort and reduce phases.
+    private final float REDUCE_PROGRESS_RANGE = 1.0f/3;
+    private final float SHUFFLE_PROGRESS_RANGE = 1 - REDUCE_PROGRESS_RANGE;
+    
+    @Override
+    public void configure(JobConf job) {
+      super.configure(job);
+    }
+    
+    @Override
+    public void reduce(Text key, Iterator<Text> values,
+        OutputCollector<Text, Text> output, Reporter reporter)
+    throws IOException {
+      float reducePhaseProgress = ((float)++recordCount)/INPUT_LINES;
+      float weightedReducePhaseProgress = 
+              reducePhaseProgress * REDUCE_PROGRESS_RANGE;
+      assertEquals("Invalid progress in reduce", 
+                   SHUFFLE_PROGRESS_RANGE + weightedReducePhaseProgress, 
+                   reporter.getProgress(), 0.02f);
+      this.reporter = reporter;
+    }
+    
+    @Override
+    public void close() throws IOException {
+      super.close();
+      assertEquals("Invalid progress in reduce cleanup", 
+                   1.0f, reporter.getProgress(), 0f);
+    }
+  }
+  
+  /**
+   * Test {@link Reporter}'s progress for map-reduce job.
+   */
+  @SuppressWarnings("deprecation")
+  @Test
+  public void testReporterProgressForMRJob() throws IOException {
+    Path test = new Path(testRootTempDir, "testReporterProgressForMRJob");
+    
+    JobConf conf = new JobConf();
+    conf.setMapperClass(ProgressTesterMapper.class);
+    conf.setReducerClass(ProgressTestingReducer.class);
+    conf.setMapOutputKeyClass(Text.class);
+    // fail early
+    conf.setMaxMapAttempts(1);
+    conf.setMaxReduceAttempts(1);
+
+    RunningJob job = 
+      UtilsForTests.runJob(conf, new Path(test, "in"), new Path(test, "out"), 
+                           1, 1, INPUT);
+    job.waitForCompletion();
+    
+    assertTrue("Job failed", job.isSuccessful());
+  }
+}
\ No newline at end of file
diff --git a/src/test/mapred/org/apache/hadoop/mapred/UtilsForTests.java b/src/test/mapred/org/apache/hadoop/mapred/UtilsForTests.java
index 08c81b7..aa5f47e 100644
--- a/src/test/mapred/org/apache/hadoop/mapred/UtilsForTests.java
+++ b/src/test/mapred/org/apache/hadoop/mapred/UtilsForTests.java
@@ -559,6 +559,16 @@
   static RunningJob runJob(JobConf conf, Path inDir, Path outDir, int numMaps, 
                            int numReds) throws IOException {
 
+    String input = "The quick brown fox\n" + "has many silly\n"
+                   + "red fox sox\n";
+    
+    // submit the job and wait for it to complete
+    return runJob(conf, inDir, outDir, numMaps, numReds, input);
+  }
+  
+  // Start a job with the specified input and return its RunningJob object
+  static RunningJob runJob(JobConf conf, Path inDir, Path outDir, int numMaps, 
+                           int numReds, String input) throws IOException {
     FileSystem fs = FileSystem.get(conf);
     if (fs.exists(outDir)) {
       fs.delete(outDir, true);
@@ -566,8 +576,7 @@
     if (!fs.exists(inDir)) {
       fs.mkdirs(inDir);
     }
-    String input = "The quick brown fox\n" + "has many silly\n"
-        + "red fox sox\n";
+    
     for (int i = 0; i < numMaps; ++i) {
       DataOutputStream file = fs.create(new Path(inDir, "part-" + i));
       file.writeBytes(input);
diff --git a/src/test/mapred/org/apache/hadoop/mapreduce/MapReduceTestUtil.java b/src/test/mapred/org/apache/hadoop/mapreduce/MapReduceTestUtil.java
index 65462b9..8351b53 100644
--- a/src/test/mapred/org/apache/hadoop/mapreduce/MapReduceTestUtil.java
+++ b/src/test/mapred/org/apache/hadoop/mapreduce/MapReduceTestUtil.java
@@ -388,6 +388,10 @@
       }
       public void progress() {
       }
+      @Override
+      public float getProgress() {
+        return 0;
+      }
       public Counter getCounter(Enum<?> name) {
         return new Counters().findCounter(name);
       }
diff --git a/src/test/mapred/org/apache/hadoop/mapreduce/TestTaskContext.java b/src/test/mapred/org/apache/hadoop/mapreduce/TestTaskContext.java
index bf78609..6ab42b4 100644
--- a/src/test/mapred/org/apache/hadoop/mapreduce/TestTaskContext.java
+++ b/src/test/mapred/org/apache/hadoop/mapreduce/TestTaskContext.java
@@ -18,16 +18,45 @@
 package org.apache.hadoop.mapreduce;
 
 import java.io.IOException;
+import java.util.Iterator;
 
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.HadoopTestCase;
+import org.apache.hadoop.mapreduce.MapReduceTestUtil.DataCopyMapper;
+import org.apache.hadoop.mapreduce.MapReduceTestUtil.DataCopyReducer;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
 
 /**
- * Tests context api. 
+ * Tests the context API and {@link StatusReporter#getProgress()} via 
+ * {@link TaskAttemptContext#getProgress()}.
  */
 public class TestTaskContext extends HadoopTestCase {
+  private static final Path rootTempDir =
+    new Path(System.getProperty("test.build.data", "/tmp"));
+  private static final Path testRootTempDir = 
+    new Path(rootTempDir, "TestTaskContext");
+  
+  private static FileSystem fs = null;
+
+  @BeforeClass
+  public static void setup() throws Exception {
+    fs = FileSystem.getLocal(new Configuration());
+    fs.delete(testRootTempDir, true);
+    fs.mkdirs(testRootTempDir);
+  }
+
+  @AfterClass
+  public static void cleanup() throws Exception {
+    fs.delete(testRootTempDir, true);
+  }
+    
   public TestTaskContext() throws IOException {
     super(HadoopTestCase.CLUSTER_MR , HadoopTestCase.LOCAL_FS, 1, 1);
   }
@@ -48,16 +77,188 @@
    * @throws InterruptedException
    * @throws ClassNotFoundException
    */
+  @Test
   public void testContextStatus()
       throws IOException, InterruptedException, ClassNotFoundException {
+    Path test = new Path(testRootTempDir, "testContextStatus");
+    
+    // test with 1 map and 0 reducers
+    // test with custom task status
     int numMaps = 1;
-    Job job = MapReduceTestUtil.createJob(createJobConf(), new Path("in"),
-        new Path("out"), numMaps, 0);
+    Job job = MapReduceTestUtil.createJob(createJobConf(), 
+                new Path(test, "in"), new Path(test, "out"), numMaps, 0);
     job.setMapperClass(MyMapper.class);
     job.waitForCompletion(true);
     assertTrue("Job failed", job.isSuccessful());
     TaskReport[] reports = job.getTaskReports(TaskType.MAP);
     assertEquals(numMaps, reports.length);
-    assertEquals(myStatus + " > sort", reports[0].getState());
+    assertEquals(myStatus, reports[0].getState());
+    
+    // test with 1 map and 1 reducer
+    // test with default task status
+    int numReduces = 1;
+    job = MapReduceTestUtil.createJob(createJobConf(), 
+            new Path(test, "in"), new Path(test, "out"), numMaps, numReduces);
+    job.setMapperClass(DataCopyMapper.class);
+    job.setReducerClass(DataCopyReducer.class);
+    job.setMapOutputKeyClass(Text.class);
+    job.setMapOutputValueClass(Text.class);
+    job.setOutputKeyClass(Text.class);
+    job.setOutputValueClass(Text.class);
+    
+    // fail early
+    job.setMaxMapAttempts(1);
+    job.setMaxReduceAttempts(0);
+    
+    // run the job and wait for completion
+    job.waitForCompletion(true);
+    assertTrue("Job failed", job.isSuccessful());
+    
+    // check map task reports
+    reports = job.getTaskReports(TaskType.MAP);
+    assertEquals(numMaps, reports.length);
+    assertEquals("map > sort", reports[0].getState());
+    
+    // check reduce task reports
+    reports = job.getTaskReports(TaskType.REDUCE);
+    assertEquals(numReduces, reports.length);
+    assertEquals("reduce > reduce", reports[0].getState());
+  }
+  
+  // an input with 4 lines
+  private static final String INPUT = "Hi\nHi\nHi\nHi\n";
+  private static final int INPUT_LINES = INPUT.split("\n").length;
+  
+  @SuppressWarnings("unchecked")
+  static class ProgressCheckerMapper 
+  extends Mapper<LongWritable, Text, Text, Text> {
+    private int recordCount = 0;
+    private float progressRange = 0;
+    
+    @Override
+    protected void setup(Context context) throws IOException {
+      // check if the map task attempt progress is 0
+      assertEquals("Invalid progress in map setup", 
+                   0.0f, context.getProgress(), 0f);
+      
+      // define the progress boundaries
+      if (context.getNumReduceTasks() == 0) {
+        progressRange = 1f;
+      } else {
+        progressRange = 0.667f;
+      }
+    }
+    
+    @Override
+    protected void map(LongWritable key, Text value, 
+        org.apache.hadoop.mapreduce.Mapper.Context context) 
+    throws IOException ,InterruptedException {
+      // get the map phase progress
+      float mapPhaseProgress = ((float)++recordCount)/INPUT_LINES;
+      // get the weighted map phase progress
+      float weightedMapProgress = progressRange * mapPhaseProgress;
+      // check the map progress
+      assertEquals("Invalid progress in map", 
+                   weightedMapProgress, context.getProgress(), 0f);
+      
+      context.write(new Text(value.toString() + recordCount), value);
+    }
+    
+    protected void cleanup(Mapper.Context context) 
+    throws IOException, InterruptedException {
+      // check if the attempt progress is at the progress boundary 
+      assertEquals("Invalid progress in map cleanup", 
+                   progressRange, context.getProgress(), 0f);
+    }
+  }
+  
+  /**
+   * Tests new MapReduce map task's context.getProgress() method.
+   * 
+   * @throws IOException
+   * @throws InterruptedException
+   * @throws ClassNotFoundException
+   */
+  public void testMapContextProgress()
+      throws IOException, InterruptedException, ClassNotFoundException {
+    int numMaps = 1;
+    
+    Path test = new Path(testRootTempDir, "testMapContextProgress");
+    
+    Job job = MapReduceTestUtil.createJob(createJobConf(), 
+                new Path(test, "in"), new Path(test, "out"), numMaps, 0, INPUT);
+    job.setMapperClass(ProgressCheckerMapper.class);
+    job.setMapOutputKeyClass(Text.class);
+    
+    // fail early
+    job.setMaxMapAttempts(1);
+    
+    job.waitForCompletion(true);
+    assertTrue("Job failed", job.isSuccessful());
+  }
+  
+  @SuppressWarnings("unchecked")
+  static class ProgressCheckerReducer extends Reducer<Text, Text, 
+                                                      Text, Text> {
+    private int recordCount = 0;
+    private final float REDUCE_PROGRESS_RANGE = 1.0f/3;
+    private final float SHUFFLE_PROGRESS_RANGE = 1 - REDUCE_PROGRESS_RANGE;
+    
+    protected void setup(final Reducer.Context context) 
+    throws IOException, InterruptedException {
+      // Note that the reduce will read some segments before calling setup()
+      float reducePhaseProgress =  ((float)++recordCount)/INPUT_LINES;
+      float weightedReducePhaseProgress = 
+        REDUCE_PROGRESS_RANGE * reducePhaseProgress;
+      // check that the shuffle phase progress is accounted for
+      assertEquals("Invalid progress in reduce setup",
+                   SHUFFLE_PROGRESS_RANGE + weightedReducePhaseProgress, 
+                   context.getProgress(), 0.01f);
+    }
+    
+    public void reduce(Text key, Iterator<Text> values, Context context)
+    throws IOException, InterruptedException {
+      float reducePhaseProgress =  ((float)++recordCount)/INPUT_LINES;
+      float weightedReducePhaseProgress = 
+        REDUCE_PROGRESS_RANGE * reducePhaseProgress;
+      assertEquals("Invalid progress in reduce", 
+                   SHUFFLE_PROGRESS_RANGE + weightedReducePhaseProgress, 
+                   context.getProgress(), 0.01f);
+    }
+    
+    protected void cleanup(Reducer.Context context) 
+    throws IOException, InterruptedException {
+      // check if the reduce task has progress of 1 in the end
+      assertEquals("Invalid progress in reduce cleanup", 
+                   1.0f, context.getProgress(), 0f);
+    }
+  }
+  
+  /**
+   * Tests new MapReduce reduce task's context.getProgress() method.
+   * 
+   * @throws IOException
+   * @throws InterruptedException
+   * @throws ClassNotFoundException
+   */
+  @Test
+  public void testReduceContextProgress()
+      throws IOException, InterruptedException, ClassNotFoundException {
+    int numTasks = 1;
+    Path test = new Path(testRootTempDir, "testReduceContextProgress");
+    
+    Job job = MapReduceTestUtil.createJob(createJobConf(), 
+                new Path(test, "in"), new Path(test, "out"), numTasks, numTasks,
+                INPUT);
+    job.setMapperClass(ProgressCheckerMapper.class);
+    job.setReducerClass(ProgressCheckerReducer.class);
+    job.setMapOutputKeyClass(Text.class);
+    
+    // fail early
+    job.setMaxMapAttempts(1);
+    job.setMaxReduceAttempts(1);
+    
+    job.waitForCompletion(true);
+    assertTrue("Job failed", job.isSuccessful());
   }
 }