MRUNIT-205 - Add support for MultipleInputs (Jason E Tedor via Brock)
diff --git a/src/main/java/org/apache/hadoop/mrunit/MapDriverBase.java b/src/main/java/org/apache/hadoop/mrunit/MapDriverBase.java
index 2691648..73c1293 100644
--- a/src/main/java/org/apache/hadoop/mrunit/MapDriverBase.java
+++ b/src/main/java/org/apache/hadoop/mrunit/MapDriverBase.java
@@ -29,14 +29,14 @@
 import org.apache.hadoop.mrunit.types.Pair;
 
 /**
- * Harness that allows you to test a Mapper instance. You provide the input
- * (k, v)* pairs that should be sent to the Mapper, and outputs you expect to be
+ * Harness that allows you to test a Mapper instance. You provide the input (k,
+ * v)* pairs that should be sent to the Mapper, and outputs you expect to be
  * sent by the Mapper to the collector for those inputs. By calling runTest(),
  * the harness will deliver the input to the Mapper and will check its outputs
  * against the expected results.
  */
 public abstract class MapDriverBase<K1, V1, K2, V2, T extends MapDriverBase<K1, V1, K2, V2, T>>
-    extends TestDriver<K1, V1, K2, V2, T> {
+    extends TestDriver<K2, V2, T> {
 
   public static final Log LOG = LogFactory.getLog(MapDriverBase.class);
 
@@ -53,16 +53,17 @@
 
   /**
    * Sets the input key to send to the mapper
-   *
+   * 
    * @param key
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
-   *             input (k, v)*. Replaced by {@link #setInput},
-   *             {@link #addInput}, and {@link #addAll}
+   *             input (k, v)*. Replaced by {@link #setInput}, {@link #addInput}
+   *             , and {@link #addAll}
    */
   @Deprecated
   public void setInputKey(final K1 key) {
     inputKey = copy(key);
   }
+
   @Deprecated
   public K1 getInputKey() {
     return inputKey;
@@ -70,16 +71,17 @@
 
   /**
    * Sets the input value to send to the mapper
-   *
+   * 
    * @param val
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
-   *             input (k, v)*. Replaced by {@link #setInput},
-   *             {@link #addInput}, and {@link #addAll}
+   *             input (k, v)*. Replaced by {@link #setInput}, {@link #addInput}
+   *             , and {@link #addAll}
    */
   @Deprecated
   public void setInputValue(final V1 val) {
     inputVal = copy(val);
   }
+
   @Deprecated
   public V1 getInputValue() {
     return inputVal;
@@ -87,15 +89,15 @@
 
   /**
    * Sets the input to send to the mapper
-   *
+   * 
    */
   public void setInput(final K1 key, final V1 val) {
-  	setInput(new Pair<K1, V1>(key, val));
+    setInput(new Pair<K1, V1>(key, val));
   }
 
   /**
    * Sets the input to send to the mapper
-   *
+   * 
    * @param inputRecord
    *          a (key, val) pair
    */
@@ -110,7 +112,7 @@
 
   /**
    * Adds an input to send to the mapper
-   *
+   * 
    * @param key
    * @param val
    */
@@ -120,7 +122,7 @@
 
   /**
    * Adds an input to send to the mapper
-   *
+   * 
    * @param input
    *          a (K, V) pair
    */
@@ -130,7 +132,7 @@
 
   /**
    * Adds list of inputs to send to the mapper
-   *
+   * 
    * @param inputs
    *          list of (K, V) pairs
    */
@@ -150,7 +152,7 @@
   /**
    * Expects an input of the form "key \t val" Forces the Mapper input types to
    * Text.
-   *
+   * 
    * @param input
    *          A string of the form "key \t val".
    * @deprecated No replacement due to lack of type safety and incompatibility
@@ -171,7 +173,7 @@
 
   /**
    * Identical to setInputKey() but with fluent programming style
-   *
+   * 
    * @return this
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
    *             input (k, v)*. Replaced by {@link #withInput} and
@@ -185,7 +187,7 @@
 
   /**
    * Identical to setInputValue() but with fluent programming style
-   *
+   * 
    * @param val
    * @return this
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
@@ -199,9 +201,9 @@
   }
 
   /**
-   * Similar to setInput() but uses addInput() instead so accumulates values, and returns
-   * the class instance, conforming to the fluent programming style
-   *
+   * Similar to setInput() but uses addInput() instead so accumulates values,
+   * and returns the class instance, conforming to the fluent programming style
+   * 
    * @return this
    */
   public T withInput(final K1 key, final V1 val) {
@@ -211,7 +213,7 @@
 
   /**
    * Identical to setInput() but returns self for fluent programming style
-   *
+   * 
    * @param inputRecord
    * @return this
    */
@@ -222,7 +224,7 @@
 
   /**
    * Identical to setInputFromString, but with a fluent programming style
-   *
+   * 
    * @param input
    *          A string of the form "key \t val". Trims any whitespace.
    * @return this
@@ -237,7 +239,7 @@
 
   /**
    * Identical to addAll() but returns self for fluent programming style
-   *
+   * 
    * @param inputRecords
    * @return this
    */
@@ -254,7 +256,8 @@
   }
 
   /**
-   * @param mapInputPath Path which is to be passed to the mappers InputSplit
+   * @param mapInputPath
+   *          Path which is to be passed to the mappers InputSplit
    */
   public void setMapInputPath(Path mapInputPath) {
     this.mapInputPath = mapInputPath;
@@ -262,7 +265,7 @@
 
   /**
    * @param mapInputPath
-   *       The Path object which will be given to the mapper
+   *          The Path object which will be given to the mapper
    * @return this
    */
   public final T withMapInputPath(Path mapInputPath) {
@@ -289,9 +292,8 @@
 
     if (driverReused()) {
       throw new IllegalStateException("Driver reuse not allowed");
-    }
-    else {
-     setUsedOnceStatus();
+    } else {
+      setUsedOnceStatus();
     }
   }
 
@@ -301,7 +303,8 @@
   @Override
   protected void printPreTestDebugLog() {
     for (Pair<K1, V1> input : inputs) {
-      LOG.debug("Mapping input (" + input.getFirst() + ", " + input.getSecond() + ")");
+      LOG.debug("Mapping input (" + input.getFirst() + ", " + input.getSecond()
+          + ")");
     }
   }
 
diff --git a/src/main/java/org/apache/hadoop/mrunit/MapOutputShuffler.java b/src/main/java/org/apache/hadoop/mrunit/MapOutputShuffler.java
new file mode 100644
index 0000000..fec593f
--- /dev/null
+++ b/src/main/java/org/apache/hadoop/mrunit/MapOutputShuffler.java
@@ -0,0 +1,113 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mrunit;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mrunit.types.Pair;
+
+import java.util.*;
+
+public class MapOutputShuffler<K, V> {
+  private final Configuration configuration;
+  private final Comparator<K> outputKeyComparator;
+  private final Comparator<K> outputValueGroupingComparator;
+
+  public MapOutputShuffler(final Configuration configuration,
+      final Comparator<K> outputKeyComparator,
+      final Comparator<K> outputValueGroupingComparator) {
+    this.configuration = configuration;
+    this.outputKeyComparator = outputKeyComparator;
+    this.outputValueGroupingComparator = outputValueGroupingComparator;
+  }
+
+  public List<Pair<K, List<V>>> shuffle(final List<Pair<K, V>> mapOutputs) {
+
+    final Comparator<K> keyOrderComparator;
+    final Comparator<K> keyGroupComparator;
+
+    if (mapOutputs.isEmpty()) {
+      return Collections.emptyList();
+    }
+
+    // JobConf needs the map output key class to work out the
+    // comparator to use
+
+    JobConf conf = new JobConf(configuration != null ? configuration
+        : new Configuration());
+    K firstKey = mapOutputs.get(0).getFirst();
+    conf.setMapOutputKeyClass(firstKey.getClass());
+
+    // get the ordering comparator or work out from conf
+    if (outputKeyComparator == null) {
+      keyOrderComparator = conf.getOutputKeyComparator();
+    } else {
+      keyOrderComparator = outputKeyComparator;
+    }
+
+    // get the grouping comparator or work out from conf
+    if (outputValueGroupingComparator == null) {
+      keyGroupComparator = conf.getOutputValueGroupingComparator();
+    } else {
+      keyGroupComparator = outputValueGroupingComparator;
+    }
+
+    // sort the map outputs according to their keys
+    Collections.sort(mapOutputs, new Comparator<Pair<K, V>>() {
+      public int compare(final Pair<K, V> o1, final Pair<K, V> o2) {
+        return keyOrderComparator.compare(o1.getFirst(), o2.getFirst());
+      }
+    });
+
+    // apply grouping comparator to create groups
+    final Map<K, List<Pair<K, V>>> groupedByKey = new LinkedHashMap<K, List<Pair<K, V>>>();
+
+    List<Pair<K, V>> groupedKeyList = null;
+    Pair<K, V> previous = null;
+
+    for (final Pair<K, V> mapOutput : mapOutputs) {
+      if (previous == null
+          || keyGroupComparator.compare(previous.getFirst(),
+              mapOutput.getFirst()) != 0) {
+        groupedKeyList = new ArrayList<Pair<K, V>>();
+        groupedByKey.put(mapOutput.getFirst(), groupedKeyList);
+      }
+      groupedKeyList.add(mapOutput);
+      previous = mapOutput;
+    }
+
+    // populate output list
+    final List<Pair<K, List<V>>> outputKeyValuesList = new ArrayList<Pair<K, List<V>>>();
+    for (final Map.Entry<K, List<Pair<K, V>>> groupedByKeyEntry : groupedByKey
+        .entrySet()) {
+
+      // create list to hold values for the grouped key
+      final List<V> valuesList = new ArrayList<V>();
+      for (final Pair<K, V> pair : groupedByKeyEntry.getValue()) {
+        valuesList.add(pair.getSecond());
+      }
+
+      // add key and values to output list
+      outputKeyValuesList.add(new Pair<K, List<V>>(groupedByKeyEntry.getKey(),
+          valuesList));
+    }
+
+    return outputKeyValuesList;
+  }
+}
diff --git a/src/main/java/org/apache/hadoop/mrunit/MapReduceDriver.java b/src/main/java/org/apache/hadoop/mrunit/MapReduceDriver.java
index e9f76c5..518584c 100644
--- a/src/main/java/org/apache/hadoop/mrunit/MapReduceDriver.java
+++ b/src/main/java/org/apache/hadoop/mrunit/MapReduceDriver.java
@@ -17,21 +17,17 @@
  */
 package org.apache.hadoop.mrunit;
 
-import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.mapred.*;
+import org.apache.hadoop.mrunit.internal.counters.CounterWrapper;
+import org.apache.hadoop.mrunit.types.Pair;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.mapred.Counters;
-import org.apache.hadoop.mapred.InputFormat;
-import org.apache.hadoop.mapred.Mapper;
-import org.apache.hadoop.mapred.OutputFormat;
-import org.apache.hadoop.mapred.Reducer;
-import org.apache.hadoop.mrunit.internal.counters.CounterWrapper;
-import org.apache.hadoop.mrunit.types.Pair;
+import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
 
 /**
  * Harness that allows you to test a Mapper and a Reducer instance together
@@ -42,7 +38,7 @@
  * Reducer (without checking them), and will check the Reducer's outputs against
  * the expected results. This is designed to handle the (k, v)* -> (k, v)* case
  * from the Mapper/Reducer pair, representing a single unit test.
- *
+ * 
  * If a combiner is specified, then it will be run exactly once after the Mapper
  * and before the Reducer.
  */
@@ -86,7 +82,7 @@
 
   /**
    * Sets the counters object to use for this test.
-   *
+   * 
    * @param ctrs
    *          The counters object to use.
    */
@@ -104,7 +100,7 @@
 
   /**
    * Set the Mapper instance to use with this test driver
-   *
+   * 
    * @param m
    *          the Mapper instance to use
    */
@@ -128,7 +124,7 @@
 
   /**
    * Sets the reducer object to use for this test
-   *
+   * 
    * @param r
    *          The reducer object to use
    */
@@ -138,7 +134,7 @@
 
   /**
    * Identical to setReducer(), but with fluent programming style
-   *
+   * 
    * @param r
    *          The Reducer to use
    * @return this
@@ -158,7 +154,7 @@
 
   /**
    * Sets the reducer object to use as a combiner for this test
-   *
+   * 
    * @param c
    *          The combiner object to use
    */
@@ -168,7 +164,7 @@
 
   /**
    * Identical to setCombiner(), but with fluent programming style
-   *
+   * 
    * @param c
    *          The Combiner to use
    * @return this
@@ -189,7 +185,7 @@
   /**
    * Configure {@link Reducer} to output with a real {@link OutputFormat}. Set
    * {@link InputFormat} to read output back in for use with run* methods
-   *
+   * 
    * @param outputFormatClass
    * @param inputFormatClass
    * @return this for fluent style
@@ -203,47 +199,6 @@
     return this;
   }
 
-  /**
-   * The private class to manage starting the reduce phase is used for type
-   * genericity reasons. This class is used in the run() method.
-   */
-  private class ReducePhaseRunner<OUTKEY, OUTVAL> {
-    private List<Pair<OUTKEY, OUTVAL>> runReduce(
-        final List<Pair<K2, List<V2>>> inputs,
-        final Reducer<K2, V2, OUTKEY, OUTVAL> reducer) throws IOException {
-
-      final List<Pair<OUTKEY, OUTVAL>> reduceOutputs = new ArrayList<Pair<OUTKEY, OUTVAL>>();
-
-      if (!inputs.isEmpty()) {
-        if (LOG.isDebugEnabled()) {
-          final StringBuilder sb = new StringBuilder();
-          for (Pair<K2, List<V2>> input : inputs) {
-            formatValueList(input.getSecond(), sb);
-            LOG.debug("Reducing input (" + input.getFirst() + ", " + sb + ")");
-            sb.delete(0, sb.length());
-          }
-        }
-
-        final ReduceDriver<K2, V2, OUTKEY, OUTVAL> reduceDriver = ReduceDriver
-            .newReduceDriver(reducer).withCounters(getCounters())
-            .withConfiguration(getConfiguration()).withAll(inputs);
-
-        if (getOutputSerializationConfiguration() != null) {
-          reduceDriver
-              .withOutputSerializationConfiguration(getOutputSerializationConfiguration());
-        }
-
-        if (outputFormatClass != null) {
-          reduceDriver.withOutputFormat(outputFormatClass, inputFormatClass);
-        }
-
-        reduceOutputs.addAll(reduceDriver.run());
-      }
-
-      return reduceOutputs;
-    }
-  }
-
   @Override
   public List<Pair<K3, V3>> run() throws IOException {
     try {
@@ -251,6 +206,9 @@
       initDistributedCache();
       List<Pair<K2, V2>> mapOutputs = new ArrayList<Pair<K2, V2>>();
 
+      MapOutputShuffler<K2, V2> shuffler = new MapOutputShuffler<K2, V2>(
+          getConfiguration(), keyValueOrderComparator, keyGroupComparator);
+
       // run map component
       LOG.debug("Starting map phase with mapper: " + myMapper);
       mapOutputs.addAll(MapDriver.newMapDriver(myMapper)
@@ -258,18 +216,22 @@
           .withAll(inputList).withMapInputPath(getMapInputPath()).run());
 
       if (myCombiner != null) {
-        // User has specified a combiner. Run this and replace the mapper outputs
+        // User has specified a combiner. Run this and replace the mapper
+        // outputs
         // with the result of the combiner.
         LOG.debug("Starting combine phase with combiner: " + myCombiner);
-        mapOutputs = new ReducePhaseRunner<K2, V2>().runReduce(
-            shuffle(mapOutputs), myCombiner);
+        mapOutputs = new ReducePhaseRunner<K2, V2, K2, V2>(inputFormatClass,
+            getConfiguration(), counters,
+            getOutputSerializationConfiguration(), outputFormatClass)
+            .runReduce(shuffler.shuffle(mapOutputs), myCombiner);
       }
 
       // Run the reduce phase.
       LOG.debug("Starting reduce phase with reducer: " + myReducer);
 
-      return new ReducePhaseRunner<K3, V3>()
-          .runReduce(shuffle(mapOutputs),myReducer);
+      return new ReducePhaseRunner<K2, V2, K3, V3>(inputFormatClass,
+          getConfiguration(), counters, getOutputSerializationConfiguration(),
+          outputFormatClass).runReduce(shuffler.shuffle(mapOutputs), myReducer);
     } finally {
       cleanupDistributedCache();
     }
@@ -283,7 +245,7 @@
   /**
    * Returns a new MapReduceDriver without having to specify the generic types
    * on the right hand side of the object create statement.
-   *
+   * 
    * @return new MapReduceDriver
    */
   public static <K1, V1, K2, V2, K3, V3> MapReduceDriver<K1, V1, K2, V2, K3, V3> newMapReduceDriver() {
@@ -293,7 +255,7 @@
   /**
    * Returns a new MapReduceDriver without having to specify the generic types
    * on the right hand side of the object create statement.
-   *
+   * 
    * @param mapper
    *          passed to MapReduceDriver constructor
    * @param reducer
@@ -308,7 +270,7 @@
   /**
    * Returns a new MapReduceDriver without having to specify the generic types
    * on the right hand side of the object create statement.
-   *
+   * 
    * @param mapper
    *          passed to MapReduceDriver constructor
    * @param reducer
diff --git a/src/main/java/org/apache/hadoop/mrunit/MapReduceDriverBase.java b/src/main/java/org/apache/hadoop/mrunit/MapReduceDriverBase.java
index 25cc274..66545cc 100644
--- a/src/main/java/org/apache/hadoop/mrunit/MapReduceDriverBase.java
+++ b/src/main/java/org/apache/hadoop/mrunit/MapReduceDriverBase.java
@@ -17,15 +17,6 @@
  */
 package org.apache.hadoop.mrunit;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
-
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.fs.Path;
@@ -35,6 +26,11 @@
 import org.apache.hadoop.mrunit.types.Pair;
 import org.apache.hadoop.util.ReflectionUtils;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+
 /**
  * Harness that allows you to test a Mapper and a Reducer instance together You
  * provide the input key and value that should be sent to the Mapper, and
@@ -46,7 +42,7 @@
  * pair, representing a single unit test.
  */
 public abstract class MapReduceDriverBase<K1, V1, K2, V2, K3, V3, T extends MapReduceDriverBase<K1, V1, K2, V2, K3, V3, T>>
-    extends TestDriver<K1, V1, K3, V3, T> {
+    extends TestDriver<K3, V3, T> {
 
   public static final Log LOG = LogFactory.getLog(MapReduceDriverBase.class);
 
@@ -62,7 +58,7 @@
 
   /**
    * Adds an input to send to the mapper
-   *
+   * 
    * @param key
    * @param val
    */
@@ -72,7 +68,7 @@
 
   /**
    * Adds an input to send to the Mapper
-   *
+   * 
    * @param input
    *          The (k, v) pair to add to the input list.
    */
@@ -82,7 +78,7 @@
 
   /**
    * Adds input to send to the mapper
-   *
+   * 
    * @param inputs
    *          List of (k, v) pairs to add to the input list
    */
@@ -95,7 +91,7 @@
   /**
    * Expects an input of the form "key \t val" Forces the Mapper input types to
    * Text.
-   *
+   * 
    * @param input
    *          A string of the form "key \t val". Trims any whitespace.
    * @deprecated No replacement due to lack of type safety and incompatibility
@@ -114,33 +110,31 @@
 
   /**
    * Identical to addInput() but returns self for fluent programming style
-   *
+   * 
    * @param key
    * @param val
    * @return this
    */
-  public T withInput(final K1 key,
-      final V1 val) {
+  public T withInput(final K1 key, final V1 val) {
     addInput(key, val);
     return thisAsMapReduceDriver();
   }
 
   /**
    * Identical to addInput() but returns self for fluent programming style
-   *
+   * 
    * @param input
    *          The (k, v) pair to add
    * @return this
    */
-  public T withInput(
-      final Pair<K1, V1> input) {
+  public T withInput(final Pair<K1, V1> input) {
     addInput(input);
     return thisAsMapReduceDriver();
   }
 
   /**
    * Identical to addInputFromString, but with a fluent programming style
-   *
+   * 
    * @param input
    *          A string of the form "key \t val". Trims any whitespace.
    * @return this
@@ -148,21 +142,19 @@
    *             with non Text Writables
    */
   @Deprecated
-  public T withInputFromString(
-      final String input) {
+  public T withInputFromString(final String input) {
     addInputFromString(input);
     return thisAsMapReduceDriver();
   }
 
   /**
    * Identical to addAll() but returns self for fluent programming style
-   *
+   * 
    * @param inputs
    *          List of (k, v) pairs to add
    * @return this
    */
-  public T withAll(
-      final List<Pair<K1, V1>> inputs) {
+  public T withAll(final List<Pair<K1, V1>> inputs) {
     addAll(inputs);
     return thisAsMapReduceDriver();
   }
@@ -175,7 +167,8 @@
   }
 
   /**
-   * @param mapInputPath Path which is to be passed to the mappers InputSplit
+   * @param mapInputPath
+   *          Path which is to be passed to the mappers InputSplit
    */
   public void setMapInputPath(Path mapInputPath) {
     this.mapInputPath = mapInputPath;
@@ -183,7 +176,7 @@
 
   /**
    * @param mapInputPath
-   *       The Path object which will be given to the mapper
+   *          The Path object which will be given to the mapper
    * @return this
    */
   public final T withMapInputPath(Path mapInputPath) {
@@ -203,8 +196,7 @@
     }
     if (driverReused()) {
       throw new IllegalStateException("Driver reuse not allowed");
-    }
-    else {
+    } else {
       setUsedOnceStatus();
     }
   }
@@ -213,93 +205,13 @@
   public abstract List<Pair<K3, V3>> run() throws IOException;
 
   /**
-   * Take the outputs from the Mapper, combine all values for the same key, and
-   * sort them by key.
-   *
-   * @param mapOutputs
-   *          An unordered list of (key, val) pairs from the mapper
-   * @return the sorted list of (key, list(val))'s to present to the reducer
-   */
-  public List<Pair<K2, List<V2>>> shuffle(final List<Pair<K2, V2>> mapOutputs) {
-
-    final Comparator<K2> keyOrderComparator;
-    final Comparator<K2> keyGroupComparator;
-
-    if (mapOutputs.isEmpty()) {
-      return Collections.emptyList();
-    }
-
-    // JobConf needs the map output key class to work out the
-    // comparator to use
-    JobConf conf = new JobConf(getConfiguration());
-    K2 firstKey = mapOutputs.get(0).getFirst();
-    conf.setMapOutputKeyClass(firstKey.getClass());
-
-    // get the ordering comparator or work out from conf
-    if (keyValueOrderComparator == null) {
-      keyOrderComparator = conf.getOutputKeyComparator();
-    } else {
-      keyOrderComparator = this.keyValueOrderComparator;
-    }
-
-    // get the grouping comparator or work out from conf
-    if (this.keyGroupComparator == null) {
-      keyGroupComparator = conf.getOutputValueGroupingComparator();
-    } else {
-      keyGroupComparator = this.keyGroupComparator;
-    }
-
-    // sort the map outputs according to their keys
-    Collections.sort(mapOutputs, new Comparator<Pair<K2, V2>>() {
-      public int compare(final Pair<K2, V2> o1, final Pair<K2, V2> o2) {
-        return keyOrderComparator.compare(o1.getFirst(), o2.getFirst());
-      }
-    });
-
-    // apply grouping comparator to create groups
-    final Map<K2, List<Pair<K2, V2>>> groupedByKey =
-        new LinkedHashMap<K2, List<Pair<K2, V2>>>();
-
-    List<Pair<K2, V2>> groupedKeyList = null;
-    Pair<K2,V2> previous = null;
-
-    for (final Pair<K2, V2> mapOutput : mapOutputs) {
-      if (previous == null || keyGroupComparator
-          .compare(previous.getFirst(), mapOutput.getFirst()) != 0) {
-        groupedKeyList = new ArrayList<Pair<K2, V2>>();
-        groupedByKey.put(mapOutput.getFirst(), groupedKeyList);
-      }
-      groupedKeyList.add(mapOutput);
-      previous = mapOutput;
-    }
-
-    // populate output list
-    final List<Pair<K2, List<V2>>> outputKeyValuesList = new ArrayList<Pair<K2, List<V2>>>();
-    for (final Entry<K2, List<Pair<K2, V2>>> groupedByKeyEntry :
-            groupedByKey.entrySet()) {
-
-      // create list to hold values for the grouped key
-      final List<V2> valuesList = new ArrayList<V2>();
-      for (final Pair<K2, V2> pair : groupedByKeyEntry.getValue()) {
-        valuesList.add(pair.getSecond());
-      }
-
-      // add key and values to output list
-      outputKeyValuesList.add(new Pair<K2, List<V2>>(
-          groupedByKeyEntry.getKey(), valuesList));
-    }
-
-    return outputKeyValuesList;
-  }
-
-  /**
    * Set the key grouping comparator, similar to calling the following API calls
    * but passing a real instance rather than just the class:
    * <UL>
    * <LI>pre 0.20.1 API: {@link JobConf#setOutputValueGroupingComparator(Class)}
    * <LI>0.20.1+ API: {@link Job#setGroupingComparatorClass(Class)}
    * </UL>
-   *
+   * 
    * @param groupingComparator
    */
   public void setKeyGroupingComparator(
@@ -315,7 +227,7 @@
    * <LI>pre 0.20.1 API: {@link JobConf#setOutputKeyComparatorClass(Class)}
    * <LI>0.20.1+ API: {@link Job#setSortComparatorClass(Class)}
    * </UL>
-   *
+   * 
    * @param orderComparator
    */
   public void setKeyOrderComparator(final RawComparator<K2> orderComparator) {
@@ -326,7 +238,7 @@
   /**
    * Identical to {@link #setKeyGroupingComparator(RawComparator)}, but with a
    * fluent programming style
-   *
+   * 
    * @param groupingComparator
    *          Comparator to use in the shuffle stage for key grouping
    * @return this
@@ -339,7 +251,7 @@
   /**
    * Identical to {@link #setKeyOrderComparator(RawComparator)}, but with a
    * fluent programming style
-   *
+   * 
    * @param orderComparator
    *          Comparator to use in the shuffle stage for key value ordering
    * @return this
diff --git a/src/main/java/org/apache/hadoop/mrunit/MultipleInputsMapReduceDriver.java b/src/main/java/org/apache/hadoop/mrunit/MultipleInputsMapReduceDriver.java
new file mode 100644
index 0000000..e155e08
--- /dev/null
+++ b/src/main/java/org/apache/hadoop/mrunit/MultipleInputsMapReduceDriver.java
@@ -0,0 +1,482 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mrunit;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.mapred.*;
+import org.apache.hadoop.mrunit.internal.counters.CounterWrapper;
+import org.apache.hadoop.mrunit.internal.driver.MultipleInputsMapReduceDriverBase;
+import org.apache.hadoop.mrunit.types.Pair;
+
+import java.io.IOException;
+import java.util.*;
+
+import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
+
+/**
+ * Harness that allows you to test multiple Mappers and a Reducer instance
+ * together (along with an optional combiner). You provide the input keys and
+ * values that should be sent to each Mapper, and outputs you expect to be sent
+ * by the Reducer to the collector for those inputs. By calling runTest(), the
+ * harness will deliver the inputs to the respective Mappers, feed the
+ * intermediate results to the Reducer (without checking them), and will check
+ * the Reducer's outputs against the expected results.
+ * 
+ * If a combiner is specified, it will run exactly once after all the Mappers
+ * and before the Reducer
+ * 
+ * @param <K1>
+ *          The common map output key type
+ * @param <V1>
+ *          The common map output value type
+ * @param <K2>
+ *          The reduce output key type
+ * @param <V2>
+ *          The reduce output value type
+ */
+public class MultipleInputsMapReduceDriver<K1, V1, K2, V2>
+    extends
+        MultipleInputsMapReduceDriverBase<Mapper, K1, V1, K2, V2, MultipleInputsMapReduceDriver<K1, V1, K2, V2>> {
+  public static final Log LOG = LogFactory
+      .getLog(MultipleInputsMapReduceDriver.class);
+
+  private Set<Mapper> mappers = new HashSet<Mapper>();
+
+  /**
+   * Add a mapper to use with this test driver
+   * 
+   * @param mapper
+   *          The mapper instance to add
+   * @param <K>
+   *          The input key type to the mapper
+   * @param <V>
+   *          The input value type to the mapper
+   */
+  public <K, V> void addMapper(final Mapper<K, V, K1, V1> mapper) {
+    this.mappers.add(returnNonNull(mapper));
+  }
+
+  /**
+   * Identical to addMapper but supports a fluent programming style
+   * 
+   * @param mapper
+   *          The mapper instance to add
+   * @param <K>
+   *          The input key type to the mapper
+   * @param <V>
+   *          The input value type to the mapper
+   * @return this
+   */
+  public <K, V> MultipleInputsMapReduceDriver<K1, V1, K2, V2> withMapper(
+      final Mapper<K, V, K1, V1> mapper) {
+    addMapper(mapper);
+    return this;
+  }
+
+  /**
+   * @return The Mapper instances being used by this test
+   */
+  public Collection<Mapper> getMappers() {
+    return mappers;
+  }
+
+  private Reducer<K1, V1, K1, V1> combiner;
+
+  /**
+   * Set the combiner to use with this test driver
+   * 
+   * @param combiner
+   *          The combiner instance to use
+   */
+  public void setCombiner(final Reducer<K1, V1, K1, V1> combiner) {
+    this.combiner = returnNonNull(combiner);
+  }
+
+  /**
+   * Identical to setCombiner but supports a fluent programming style
+   * 
+   * @param combiner
+   *          The combiner instance to use
+   * @return this
+   */
+  public MultipleInputsMapReduceDriver<K1, V1, K2, V2> withCombiner(
+      final Reducer<K1, V1, K1, V1> combiner) {
+    setCombiner(combiner);
+    return this;
+  }
+
+  /**
+   * @return The combiner instance being used by this test
+   */
+  public Reducer<K1, V1, K1, V1> getCombiner() {
+    return combiner;
+  }
+
+  private Reducer<K1, V1, K2, V2> reducer;
+
+  /**
+   * Set the reducer to use with this test driver
+   * 
+   * @param reducer
+   *          The reducer instance to use
+   */
+  public void setReducer(final Reducer<K1, V1, K2, V2> reducer) {
+    this.reducer = returnNonNull(reducer);
+  }
+
+  /**
+   * Identical to setReducer but supports a fluent programming style
+   * 
+   * @param reducer
+   *          The reducer instance to use
+   * @return this
+   */
+  public MultipleInputsMapReduceDriver<K1, V1, K2, V2> withReducer(
+      final Reducer<K1, V1, K2, V2> reducer) {
+    setReducer(reducer);
+    return this;
+  }
+
+  /**
+   * @return The reducer instance being used by this test
+   */
+  public Reducer<K1, V1, K2, V2> getReducer() {
+    return reducer;
+  }
+
+  private Counters counters;
+
+  /**
+   * @return The counters used in this test
+   */
+  public Counters getCounters() {
+    return counters;
+  }
+
+  /**
+   * Sets the counters object to use for this test
+   * 
+   * @param counters
+   *          The counters object to use
+   */
+  public void setCounters(Counters counters) {
+    this.counters = counters;
+    counterWrapper = new CounterWrapper(counters);
+  }
+
+  /**
+   * Identical to setCounters but supports a fluent programming style
+   * 
+   * @param counters
+   *          The counters object to use
+   * @return this
+   */
+  public MultipleInputsMapReduceDriver<K1, V1, K2, V2> withCounter(
+      Counters counters) {
+    setCounters(counters);
+    return this;
+  }
+
+  private Class<? extends OutputFormat> outputFormatClass;
+
+  /**
+   * Configure {@link Reducer} to output with a real {@link OutputFormat}.
+   * 
+   * @param outputFormatClass
+   *          The OutputFormat class
+   * @return this
+   */
+  public MultipleInputsMapReduceDriver<K1, V1, K2, V2> withOutputFormat(
+      final Class<? extends OutputFormat> outputFormatClass) {
+    this.outputFormatClass = returnNonNull(outputFormatClass);
+    return this;
+  }
+
+  private Class<? extends InputFormat> inputFormatClass;
+
+  /**
+   * Set the InputFormat
+   * 
+   * @param inputFormatClass
+   *          The InputFormat class
+   * @return this
+   */
+  public MultipleInputsMapReduceDriver<K1, V1, K2, V2> withInputFormat(
+      final Class<? extends InputFormat> inputFormatClass) {
+    this.inputFormatClass = returnNonNull(inputFormatClass);
+    return this;
+  }
+
+  /**
+   * Construct a driver with the specified Reducer. Note that a Combiner can be
+   * set separately.
+   * 
+   * @param reducer
+   *          The reducer to use
+   */
+  public MultipleInputsMapReduceDriver(Reducer<K1, V1, K2, V2> reducer) {
+    this();
+    this.reducer = reducer;
+  }
+
+  /**
+   * Construct a driver with the specified Combiner and Reducers
+   * 
+   * @param combiner
+   *          The combiner to use
+   * @param reducer
+   *          The reducer to use
+   */
+  public MultipleInputsMapReduceDriver(Reducer<K1, V1, K1, V1> combiner,
+                                       Reducer<K1, V1, K2, V2> reducer) {
+    this(reducer);
+    this.combiner = combiner;
+  }
+
+  /**
+   * Construct a driver without specifying a Combiner or a Reducer. Note that
+   * these can be set with the appropriate set methods and that at least the
+   * Reducer must be set.
+   */
+  public MultipleInputsMapReduceDriver() {
+    setCounters(new Counters());
+  }
+
+  /**
+   * Static factory-style method to construct a driver instance with the
+   * specified Combiner and Reducer
+   * 
+   * @param combiner
+   *          The combiner to use
+   * @param reducer
+   *          The reducer to use
+   * @param <K1>
+   *          The common output key type of the mappers
+   * @param <V1>
+   *          The common output value type of the mappers
+   * @param <K2>
+   *          The output key type of the reducer
+   * @param <V2>
+   *          The output value type of the reducer
+   * @return a new driver instance, to support a fluent programming style
+   */
+  public static <K1, V1, K2, V2> MultipleInputsMapReduceDriver<K1, V1, K2, V2> newMultipleInputMapReduceDriver(
+      final Reducer<K1, V1, K1, V1> combiner,
+      final Reducer<K1, V1, K2, V2> reducer) {
+    return new MultipleInputsMapReduceDriver<K1, V1, K2, V2>(combiner, reducer);
+  }
+
+  /**
+   * Static factory-style method to construct a driver instance with the
+   * specified Reducer
+   * 
+   * @param reducer
+   *          The reducer to use
+   * @param <K1>
+   *          The common output key type of the mappers
+   * @param <V1>
+   *          The common output value type of the mappers
+   * @param <K2>
+   *          The output key type of the reducer
+   * @param <V2>
+   *          The output value type of the reducer
+   * @return a new driver instance, to support a fluent programming style
+   */
+  public static <K1, V1, K2, V2> MultipleInputsMapReduceDriver<K1, V1, K2, V2> newMultipleInputMapReduceDriver(
+      final Reducer<K1, V1, K2, V2> reducer) {
+    return new MultipleInputsMapReduceDriver<K1, V1, K2, V2>(reducer);
+  }
+
+  /**
+   * Static factory-style method to construct a driver instance without
+   * specifying a Combiner or a Reducer. Note that these can be set separately
+   * by using the appropriate set (or with) methods and that at least a Reducer
+   * must be set
+   * 
+   * @param <K1>
+   *          The common output key type of the mappers
+   * @param <V1>
+   *          The common output value type of the mappers
+   * @param <K2>
+   *          The output key type of the reducer
+   * @param <V2>
+   *          The output value type of the reducer
+   * @return a new driver instance, to support a fluent programming style
+   */
+  public static <K1, V1, K2, V2> MultipleInputsMapReduceDriver<K1, V1, K2, V2> newMultipleInputMapReduceDriver() {
+    return new MultipleInputsMapReduceDriver<K1, V1, K2, V2>();
+  }
+
+  /**
+   * Add the specified (key, val) pair to the specified mapper
+   * 
+   * @param mapper
+   *          The mapper to add the input pair to
+   * @param key
+   *          The key
+   * @param val
+   *          The value
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   */
+  public <K, V> void addInput(final Mapper<K, V, K1, V1> mapper, final K key,
+      final V val) {
+    super.addInput(mapper, key, val);
+  }
+
+  /**
+   * Add the specified input pair to the specified mapper
+   * 
+   * @param mapper
+   *          The mapper to add the input pair to
+   * @param input
+   *          The (k,v) pair to add
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   */
+  public <K, V> void addInput(final Mapper<K, V, K1, V1> mapper,
+      final Pair<K, V> input) {
+    super.addInput(mapper, input);
+  }
+
+  /**
+   * Add the specified input pairs to the specified mapper
+   * 
+   * @param mapper
+   *          The mapper to add the input pairs to
+   * @param inputs
+   *          The (k, v) pairs to add
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   */
+  public <K, V> void addAll(final Mapper<K, V, K1, V1> mapper,
+      final List<Pair<K, V>> inputs) {
+    super.addAll(mapper, inputs);
+  }
+
+  /**
+   * Identical to addInput but supports fluent programming style
+   * 
+   * @param mapper
+   *          The mapper to add the input pair to
+   * @param key
+   *          The key
+   * @param val
+   *          The value
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   * @return this
+   */
+  public <K, V> MultipleInputsMapReduceDriver<K1, V1, K2, V2> withInput(
+      final Mapper<K, V, K1, V1> mapper, final K key, final V val) {
+    return super.withInput(mapper, key, val);
+  }
+
+  /**
+   * Identical to addInput but supports fluent programming style
+   * 
+   * @param mapper
+   *          The mapper to add the input pair to
+   * @param input
+   *          The (k, v) pair to add
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   * @return this
+   */
+  public <K, V> MultipleInputsMapReduceDriver<K1, V1, K2, V2> withInput(
+      final Mapper<K, V, K1, V1> mapper, final Pair<K, V> input) {
+    return super.withInput(mapper, input);
+  }
+
+  /**
+   * Identical to addAll but supports fluent programming style
+   * 
+   * @param mapper
+   *          The mapper to add the input pairs to
+   * @param inputs
+   *          The (k, v) pairs to add
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   * @return this
+   */
+  public <K, V> MultipleInputsMapReduceDriver<K1, V1, K2, V2> withAll(
+      final Mapper<K, V, K1, V1> mapper, final List<Pair<K, V>> inputs) {
+    return super.withAll(mapper, inputs);
+  }
+
+  @Override
+  protected void preRunChecks(Set<Mapper> mappers, Object reducer) {
+    if (mappers.isEmpty()) {
+      throw new IllegalStateException("No mappers were provided");
+    }
+    super.preRunChecks(mappers, reducer);
+  }
+
+  @SuppressWarnings("unchecked")
+  @Override
+  public List<Pair<K2, V2>> run() throws IOException {
+    try {
+      preRunChecks(mappers, reducer);
+      initDistributedCache();
+
+      List<Pair<K1, V1>> outputs = new ArrayList<Pair<K1, V1>>();
+
+      MapOutputShuffler<K1, V1> shuffler = new MapOutputShuffler<K1, V1>(
+          getConfiguration(), keyValueOrderComparator, keyGroupComparator);
+
+      for (Mapper mapper : mappers) {
+        MapDriver mapDriver = MapDriver.newMapDriver(mapper);
+        mapDriver.setCounters(counters);
+        mapDriver.setConfiguration(getConfiguration());
+        mapDriver.addAll(inputs.get(mapper));
+        mapDriver.withMapInputPath(getMapInputPath(mapper));
+        outputs.addAll(mapDriver.run());
+      }
+
+      if (combiner != null) {
+        LOG.debug("Starting combine phase with combiner: " + combiner);
+        outputs = new ReducePhaseRunner<K1, V1, K1, V1>(inputFormatClass,
+            getConfiguration(), counters,
+            getOutputSerializationConfiguration(), outputFormatClass)
+            .runReduce(shuffler.shuffle(outputs), combiner);
+      }
+
+      LOG.debug("Starting reduce phase with reducer: " + reducer);
+
+      return new ReducePhaseRunner<K1, V1, K2, V2>(inputFormatClass,
+          getConfiguration(), counters, getOutputSerializationConfiguration(),
+          outputFormatClass).runReduce(shuffler.shuffle(outputs), reducer);
+    } finally {
+      cleanupDistributedCache();
+    }
+  }
+}
diff --git a/src/main/java/org/apache/hadoop/mrunit/PipelineMapReduceDriver.java b/src/main/java/org/apache/hadoop/mrunit/PipelineMapReduceDriver.java
index 9cf4e42..3c9d297 100644
--- a/src/main/java/org/apache/hadoop/mrunit/PipelineMapReduceDriver.java
+++ b/src/main/java/org/apache/hadoop/mrunit/PipelineMapReduceDriver.java
@@ -38,23 +38,23 @@
  * workflow, as well as a set of (key, value) pairs to pass in to the first
  * Mapper. You can also specify the outputs you expect to be sent to the final
  * Reducer in the pipeline.
- *
+ * 
  * By calling runTest(), the harness will deliver the input to the first Mapper,
  * feed the intermediate results to the first Reducer (without checking them),
  * and proceed to forward this data along to subsequent Mapper/Reducer jobs in
  * the pipeline until the final Reducer. The last Reducer's outputs are checked
  * against the expected results.
- *
+ * 
  * This is designed for slightly more complicated integration tests than the
  * MapReduceDriver, which is for smaller unit tests.
- *
+ * 
  * (K1, V1) in the type signature refer to the types associated with the inputs
  * to the first Mapper. (K2, V2) refer to the types associated with the final
  * Reducer's output. No intermediate types are specified.
  */
 @SuppressWarnings("rawtypes")
 public class PipelineMapReduceDriver<K1, V1, K2, V2> extends
-    TestDriver<K1, V1, K2, V2, PipelineMapReduceDriver<K1, V1, K2, V2>> {
+    TestDriver<K2, V2, PipelineMapReduceDriver<K1, V1, K2, V2>> {
 
   public static final Log LOG = LogFactory
       .getLog(PipelineMapReduceDriver.class);
@@ -83,7 +83,7 @@
 
   /**
    * Sets the counters object to use for this test.
-   *
+   * 
    * @param ctrs
    *          The counters object to use.
    */
@@ -102,7 +102,7 @@
   /**
    * Add a Mapper and Reducer instance to the pipeline to use with this test
    * driver
-   *
+   * 
    * @param m
    *          The Mapper instance to add to the pipeline
    * @param r
@@ -115,7 +115,7 @@
   /**
    * Add a Mapper and Reducer instance to the pipeline to use with this test
    * driver
-   *
+   * 
    * @param p
    *          The Mapper and Reducer instances to add to the pipeline
    */
@@ -126,7 +126,7 @@
   /**
    * Add a Mapper and Reducer instance to the pipeline to use with this test
    * driver using fluent style
-   *
+   * 
    * @param m
    *          The Mapper instance to use
    * @param r
@@ -141,7 +141,7 @@
   /**
    * Add a Mapper and Reducer instance to the pipeline to use with this test
    * driver using fluent style
-   *
+   * 
    * @param p
    *          The Mapper and Reducer instances to add to the pipeline
    */
@@ -160,7 +160,7 @@
 
   /**
    * Adds an input to send to the mapper
-   *
+   * 
    * @param key
    * @param val
    */
@@ -170,7 +170,7 @@
 
   /**
    * Adds list of inputs to send to the mapper
-   *
+   * 
    * @param inputs
    *          list of (K, V) pairs
    */
@@ -182,7 +182,7 @@
 
   /**
    * Identical to addInput() but returns self for fluent programming style
-   *
+   * 
    * @param key
    * @param val
    * @return this
@@ -195,7 +195,7 @@
 
   /**
    * Adds an input to send to the Mapper
-   *
+   * 
    * @param input
    *          The (k, v) pair to add to the input list.
    */
@@ -205,7 +205,7 @@
 
   /**
    * Identical to addInput() but returns self for fluent programming style
-   *
+   * 
    * @param input
    *          The (k, v) pair to add
    * @return this
@@ -219,7 +219,7 @@
   /**
    * Expects an input of the form "key \t val" Forces the Mapper input types to
    * Text.
-   *
+   * 
    * @param input
    *          A string of the form "key \t val". Trims any whitespace.
    * @deprecated No replacement due to lack of type safety and incompatibility
@@ -233,7 +233,7 @@
 
   /**
    * Identical to addInputFromString, but with a fluent programming style
-   *
+   * 
    * @param input
    *          A string of the form "key \t val". Trims any whitespace.
    * @return this
@@ -249,8 +249,9 @@
 
   /**
    * Identical to addAll() but returns self for fluent programming style
-   *
-   * @param inputRecords input key/value pairs
+   * 
+   * @param inputRecords
+   *          input key/value pairs
    * @return this
    */
   public PipelineMapReduceDriver<K1, V1, K2, V2> withAll(
@@ -267,7 +268,8 @@
   }
 
   /**
-   * @param mapInputPath Path which is to be passed to the mappers InputSplit
+   * @param mapInputPath
+   *          Path which is to be passed to the mappers InputSplit
    */
   public void setMapInputPath(Path mapInputPath) {
     this.mapInputPath = mapInputPath;
@@ -275,10 +277,11 @@
 
   /**
    * @param mapInputPath
-   *       The Path object which will be given to the mapper
+   *          The Path object which will be given to the mapper
    * @return this
    */
-  public final PipelineMapReduceDriver<K1, V1, K2, V2> withMapInputPath(Path mapInputPath) {
+  public final PipelineMapReduceDriver<K1, V1, K2, V2> withMapInputPath(
+      Path mapInputPath) {
     setMapInputPath(mapInputPath);
     return this;
   }
@@ -297,8 +300,7 @@
     }
     if (driverReused()) {
       throw new IllegalStateException("Driver reuse not allowed");
-    }
-    else {
+    } else {
       setUsedOnceStatus();
     }
 
@@ -331,7 +333,7 @@
   /**
    * Returns a new PipelineMapReduceDriver without having to specify the generic
    * types on the right hand side of the object create statement.
-   *
+   * 
    * @return new PipelineMapReduceDriver
    */
   public static <K1, V1, K2, V2> PipelineMapReduceDriver<K1, V1, K2, V2> newPipelineMapReduceDriver() {
@@ -341,7 +343,7 @@
   /**
    * Returns a new PipelineMapReduceDriver without having to specify the generic
    * types on the right hand side of the object create statement.
-   *
+   * 
    * @param pipeline
    *          passed to PipelineMapReduceDriver constructor
    * @return new PipelineMapReduceDriver
diff --git a/src/main/java/org/apache/hadoop/mrunit/ReduceDriverBase.java b/src/main/java/org/apache/hadoop/mrunit/ReduceDriverBase.java
index c66087f..e24b6dd 100644
--- a/src/main/java/org/apache/hadoop/mrunit/ReduceDriverBase.java
+++ b/src/main/java/org/apache/hadoop/mrunit/ReduceDriverBase.java
@@ -17,17 +17,17 @@
  */
 package org.apache.hadoop.mrunit;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mrunit.internal.io.Serialization;
 import org.apache.hadoop.mrunit.internal.output.MockOutputCreator;
 import org.apache.hadoop.mrunit.types.Pair;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
 /**
  * Harness that allows you to test a Reducer instance. You provide a key and a
  * set of intermediate values for that key that represent inputs that should be
@@ -37,7 +37,7 @@
  * expected results.
  */
 public abstract class ReduceDriverBase<K1, V1, K2, V2, T extends ReduceDriverBase<K1, V1, K2, V2, T>>
-    extends TestDriver<K1, V1, K2, V2, T> {
+    extends TestDriver<K2, V2, T> {
 
   protected List<Pair<K1, List<V1>>> inputs = new ArrayList<Pair<K1, List<V1>>>();
   @Deprecated
@@ -53,7 +53,7 @@
 
   /**
    * Returns a list of values.
-   *
+   * 
    * @return List of values
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
    *             input (k, v*)*. Replaced by {@link #getInputValues(Object)}
@@ -65,7 +65,7 @@
 
   /**
    * Returns a list of values for the given key
-   *
+   * 
    * @param key
    * @return List for the given key, or null if key does not exist
    */
@@ -80,7 +80,7 @@
 
   /**
    * Sets the input key to send to the Reducer
-   *
+   * 
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
    *             input (k, v*)*. Replaced by {@link #setInput},
    *             {@link #addInput}, and {@link #addAll}
@@ -92,7 +92,7 @@
 
   /**
    * adds an input value to send to the reducer
-   *
+   * 
    * @param val
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
    *             input (k, v*)*. Replaced by {@link #setInput},
@@ -105,7 +105,7 @@
 
   /**
    * Sets the input values to send to the reducer; overwrites existing ones
-   *
+   * 
    * @param values
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
    *             input (k, v*)*. Replaced by {@link #setInput},
@@ -119,7 +119,7 @@
 
   /**
    * Adds a set of input values to send to the reducer
-   *
+   * 
    * @param values
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
    *             input (k, v*)*. Replaced by {@link #setInput},
@@ -134,7 +134,7 @@
 
   /**
    * Sets the input to send to the reducer
-   *
+   * 
    * @param key
    * @param values
    */
@@ -155,7 +155,7 @@
 
   /**
    * Add input (K, V*) to send to the Reducer
-   *
+   * 
    * @param key
    *          The key too add
    * @param values
@@ -173,7 +173,7 @@
 
   /**
    * Add input (K, V*) to send to the Reducer
-   *
+   * 
    * @param input
    *          input pair
    */
@@ -183,7 +183,7 @@
 
   /**
    * Adds input to send to the Reducer
-   *
+   * 
    * @param inputs
    *          list of (K, V*) pairs
    */
@@ -196,7 +196,7 @@
   /**
    * Expects an input of the form "key \t val, val, val..." Forces the Reducer
    * input types to Text.
-   *
+   * 
    * @param input
    *          A string of the form "key \t val,val,val". Trims any whitespace.
    * @deprecated No replacement due to lack of type safety and incompatibility
@@ -218,7 +218,7 @@
 
   /**
    * Identical to setInputKey() but with fluent programming style
-   *
+   * 
    * @return this
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
    *             input (k, v*)*. Replaced by {@link #withInput(Object, List)},
@@ -232,7 +232,7 @@
 
   /**
    * Identical to addInputValue() but with fluent programming style
-   *
+   * 
    * @param val
    * @return this
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
@@ -247,7 +247,7 @@
 
   /**
    * Identical to addInputValues() but with fluent programming style
-   *
+   * 
    * @param values
    * @return this
    * @deprecated MRUNIT-64. Moved to list implementation to support multiple
@@ -262,18 +262,17 @@
 
   /**
    * Identical to setInput() but returns self for fluent programming style
-   *
+   * 
    * @return this
    */
-  public T withInput(final K1 key,
-      final List<V1> values) {
+  public T withInput(final K1 key, final List<V1> values) {
     addInput(key, values);
     return thisAsReduceDriver();
   }
 
   /**
    * Identical to setInput, but with a fluent programming style
-   *
+   * 
    * @param input
    *          A string of the form "key \t val". Trims any whitespace.
    * @return this
@@ -288,7 +287,7 @@
 
   /**
    * Identical to addInput() but returns self for fluent programming style
-   *
+   * 
    * @param input
    * @return this
    */
@@ -299,12 +298,11 @@
 
   /**
    * Identical to addAll() but returns self for fluent programming style
-   *
+   * 
    * @param inputs
    * @return this
    */
-  public T withAll(
-      final List<Pair<K1, List<V1>>> inputs) {
+  public T withAll(final List<Pair<K1, List<V1>>> inputs) {
     addAll(inputs);
     return thisAsReduceDriver();
   }
@@ -327,8 +325,7 @@
     }
     if (driverReused()) {
       throw new IllegalStateException("Driver reuse not allowed");
-    }
-    else {
+    } else {
       setUsedOnceStatus();
     }
   }
diff --git a/src/main/java/org/apache/hadoop/mrunit/ReducePhaseRunner.java b/src/main/java/org/apache/hadoop/mrunit/ReducePhaseRunner.java
new file mode 100644
index 0000000..76da743
--- /dev/null
+++ b/src/main/java/org/apache/hadoop/mrunit/ReducePhaseRunner.java
@@ -0,0 +1,92 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mrunit;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapred.Counters;
+import org.apache.hadoop.mapred.InputFormat;
+import org.apache.hadoop.mapred.OutputFormat;
+import org.apache.hadoop.mapred.Reducer;
+import org.apache.hadoop.mrunit.types.Pair;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Helper that runs the reduce phase; kept as a separate class for type
+ * genericity reasons. This class is used in the run() method.
+ */
+class ReducePhaseRunner<INKEY, INVAL, OUTKEY, OUTVAL> {
+  public static final Log LOG = LogFactory.getLog(ReducePhaseRunner.class);
+
+  private final Configuration configuration;
+  private final Counters counters;
+  private Configuration outputSerializationConfiguration;
+  private Class<? extends OutputFormat> outputFormatClass;
+  private Class<? extends InputFormat> inputFormatClass;
+
+  ReducePhaseRunner(Class<? extends InputFormat> inputFormatClass,
+      Configuration configuration, Counters counters,
+      Configuration outputSerializationConfiguration,
+      Class<? extends OutputFormat> outputFormatClass) {
+    this.inputFormatClass = inputFormatClass;
+    this.configuration = configuration;
+    this.counters = counters;
+    this.outputSerializationConfiguration = outputSerializationConfiguration;
+    this.outputFormatClass = outputFormatClass;
+  }
+
+  public List<Pair<OUTKEY, OUTVAL>> runReduce(
+      final List<Pair<INKEY, List<INVAL>>> inputs,
+      final Reducer<INKEY, INVAL, OUTKEY, OUTVAL> reducer) throws IOException {
+
+    final List<Pair<OUTKEY, OUTVAL>> reduceOutputs = new ArrayList<Pair<OUTKEY, OUTVAL>>();
+
+    if (!inputs.isEmpty()) {
+      if (LOG.isDebugEnabled()) {
+        final StringBuilder sb = new StringBuilder();
+        for (Pair<INKEY, List<INVAL>> input : inputs) {
+          TestDriver.formatValueList(input.getSecond(), sb);
+          LOG.debug("Reducing input (" + input.getFirst() + ", " + sb + ")");
+          sb.delete(0, sb.length());
+        }
+      }
+
+      final ReduceDriver<INKEY, INVAL, OUTKEY, OUTVAL> reduceDriver = ReduceDriver
+          .newReduceDriver(reducer).withCounters(counters)
+          .withConfiguration(configuration).withAll(inputs);
+
+      if (outputSerializationConfiguration != null) {
+        reduceDriver
+            .withOutputSerializationConfiguration(outputSerializationConfiguration);
+      }
+
+      if (outputFormatClass != null) {
+        reduceDriver.withOutputFormat(outputFormatClass, inputFormatClass);
+      }
+
+      reduceOutputs.addAll(reduceDriver.run());
+    }
+
+    return reduceOutputs;
+  }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/hadoop/mrunit/TestDriver.java b/src/main/java/org/apache/hadoop/mrunit/TestDriver.java
index be4d67e..a2d9536 100644
--- a/src/main/java/org/apache/hadoop/mrunit/TestDriver.java
+++ b/src/main/java/org/apache/hadoop/mrunit/TestDriver.java
@@ -17,14 +17,6 @@
  */
 package org.apache.hadoop.mrunit;
 
-import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
-
-import java.io.File;
-import java.io.IOException;
-import java.net.URI;
-import java.util.*;
-
-import com.google.common.collect.Lists;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
@@ -41,31 +33,38 @@
 import org.apache.hadoop.mrunit.internal.util.StringUtils;
 import org.apache.hadoop.mrunit.types.Pair;
 
-public abstract class TestDriver<K1, V1, K2, V2, T extends TestDriver<K1, V1, K2, V2, T>> {
+import java.io.File;
+import java.io.IOException;
+import java.net.URI;
+import java.util.*;
+
+import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
+
+public abstract class TestDriver<K, V, T extends TestDriver<K, V, T>> {
 
   public static final Log LOG = LogFactory.getLog(TestDriver.class);
 
-  protected List<Pair<K2, V2>> expectedOutputs;
+  protected List<Pair<K, V>> expectedOutputs;
 
   private boolean strictCountersChecking = false;
   protected List<Pair<Enum<?>, Long>> expectedEnumCounters;
   protected List<Pair<Pair<String, String>, Long>> expectedStringCounters;
   /**
-   * Configuration object, do not use directly, always use the
-   * the getter as it lazily creates the object in the case
-   * the setConfiguration() method will be used by the user.
+   * Configuration object, do not use directly, always use the getter as it
+   * lazily creates the object in case the setConfiguration() method will be
+   * used by the user.
    */
   private Configuration configuration;
   /**
-   * Serialization object, do not use directly, always use the
-   * the getter as it lazily creates the object in the case
-   * the setConfiguration() method will be used by the user.
+   * Serialization object, do not use directly, always use the getter as it
+   * lazily creates the object in case the setConfiguration() method will be
+   * used by the user.
    */
   private Serialization serialization;
 
   private Configuration outputSerializationConfiguration;
-  private Comparator<K2> keyComparator;
-  private Comparator<V2> valueComparator;
+  private Comparator<K> keyComparator;
+  private Comparator<V> valueComparator;
   private File tmpDistCacheDir;
   protected CounterWrapper counterWrapper;
   protected MockMultipleOutputs mos;
@@ -73,18 +72,18 @@
   protected Map<String, List<Pair<?, ?>>> expectedPathOutputs;
   private boolean hasRun = false;
 
-
   public TestDriver() {
-    expectedOutputs = new ArrayList<Pair<K2, V2>>();
+    expectedOutputs = new ArrayList<Pair<K, V>>();
     expectedEnumCounters = new ArrayList<Pair<Enum<?>, Long>>();
     expectedStringCounters = new ArrayList<Pair<Pair<String, String>, Long>>();
-    expectedMultipleOutputs = new HashMap<String, List<Pair<?, ? >>>();
-	expectedPathOutputs = new HashMap<String, List<Pair<?, ?>>>();
+    expectedMultipleOutputs = new HashMap<String, List<Pair<?, ?>>>();
+    expectedPathOutputs = new HashMap<String, List<Pair<?, ?>>>();
   }
 
   /**
    * Check to see if this driver is being reused
-   * @return  boolean - true if run() has been called more than once
+   * 
+   * @return boolean - true if run() has been called more than once
    */
   protected boolean driverReused() {
     return this.hasRun;
@@ -96,74 +95,76 @@
   protected void setUsedOnceStatus() {
     this.hasRun = true;
   }
+
   /**
    * Adds output (k, v)* pairs we expect
-   *
+   * 
    * @param outputRecords
    *          The (k, v)* pairs to add
    */
-  public void addAllOutput(final List<Pair<K2, V2>> outputRecords) {
-    for (Pair<K2, V2> output : outputRecords) {
+  public void addAllOutput(final List<Pair<K, V>> outputRecords) {
+    for (Pair<K, V> output : outputRecords) {
       addOutput(output);
     }
   }
 
   /**
    * Functions like addAllOutput() but returns self for fluent programming style
-   *
+   * 
    * @param outputRecords
    * @return this
    */
-  public T withAllOutput(
-      final List<Pair<K2, V2>> outputRecords) {
+  public T withAllOutput(final List<Pair<K, V>> outputRecords) {
     addAllOutput(outputRecords);
     return thisAsTestDriver();
   }
 
   /**
    * Adds an output (k, v) pair we expect
-   *
+   * 
    * @param outputRecord
    *          The (k, v) pair to add
    */
-  public void addOutput(final Pair<K2, V2> outputRecord) {
+  public void addOutput(final Pair<K, V> outputRecord) {
     addOutput(outputRecord.getFirst(), outputRecord.getSecond());
   }
 
   /**
    * Adds a (k, v) pair we expect as output
-   * @param key the key
-   * @param val the value
+   * 
+   * @param key
+   *          the key
+   * @param val
+   *          the value
    */
-  public void addOutput(final K2 key, final V2 val) {
+  public void addOutput(final K key, final V val) {
     expectedOutputs.add(copyPair(key, val));
   }
 
   /**
    * Works like addOutput(), but returns self for fluent style
-   *
+   * 
    * @param outputRecord
    * @return this
    */
-  public T withOutput(final Pair<K2, V2> outputRecord) {
+  public T withOutput(final Pair<K, V> outputRecord) {
     addOutput(outputRecord);
     return thisAsTestDriver();
   }
 
   /**
    * Works like addOutput() but returns self for fluent programming style
-   *
+   * 
    * @return this
    */
-  public T withOutput(final K2 key, final V2 val) {
+  public T withOutput(final K key, final V val) {
     addOutput(key, val);
     return thisAsTestDriver();
   }
 
   /**
-   * Expects an input of the form "key \t val" Forces the output types to
-   * Text.
-   *
+   * Expects an input of the form "key \t val". Forces the output types to Text.
+   * 
    * @param output
    *          A string of the form "key \t val". Trims any whitespace.
    * @deprecated No replacement due to lack of type safety and incompatibility
@@ -172,12 +173,12 @@
   @Deprecated
   @SuppressWarnings("unchecked")
   public void addOutputFromString(final String output) {
-    addOutput((Pair<K2, V2>) parseTabbedPair(output));
+    addOutput((Pair<K, V>) parseTabbedPair(output));
   }
 
   /**
    * Identical to addOutputFromString, but with a fluent programming style
-   *
+   * 
    * @param output
    *          A string of the form "key \t val". Trims any whitespace.
    * @return this
@@ -193,7 +194,7 @@
   /**
    * @return the list of (k, v) pairs expected as output from this driver
    */
-  public List<Pair<K2, V2>> getExpectedOutputs() {
+  public List<Pair<K, V>> getExpectedOutputs() {
     return expectedOutputs;
   }
 
@@ -233,22 +234,21 @@
 
   /**
    * Register expected enumeration based counter value
-   *
+   * 
    * @param e
    *          Enumeration based counter
    * @param expectedValue
    *          Expected value
    * @return this
    */
-  public T withCounter(final Enum<?> e,
-      final long expectedValue) {
+  public T withCounter(final Enum<?> e, final long expectedValue) {
     expectedEnumCounters.add(new Pair<Enum<?>, Long>(e, expectedValue));
     return thisAsTestDriver();
   }
 
   /**
    * Register expected name based counter value
-   *
+   * 
    * @param group
    *          Counter group
    * @param name
@@ -257,8 +257,8 @@
    *          Expected value
    * @return this
    */
-  public T withCounter(final String group,
-      final String name, final long expectedValue) {
+  public T withCounter(final String group, final String name,
+      final long expectedValue) {
     expectedStringCounters.add(new Pair<Pair<String, String>, Long>(
         new Pair<String, String>(group, name), expectedValue));
     return thisAsTestDriver();
@@ -268,7 +268,7 @@
    * Change counter checking. After this method is called, the test will fail if
    * an actual counter is not matched by an expected counter. By default, the
    * test only check that every expected counter is there.
-   *
+   * 
    * This mode allows you to ensure that no unexpected counters has been
    * declared.
    */
@@ -282,7 +282,7 @@
    *         reducer associated with the driver
    */
   public Configuration getConfiguration() {
-    if(configuration == null) {
+    if (configuration == null) {
       configuration = new Configuration();
     }
     return configuration;
@@ -291,14 +291,14 @@
   /**
    * @return The comparator for output keys or null of none has been set
    */
-  public Comparator<K2> getKeyComparator() {
+  public Comparator<K> getKeyComparator() {
     return this.keyComparator;
   }
 
   /**
    * @return The comparator for output values or null of none has been set
    */
-  public Comparator<V2> getValueComparator() {
+  public Comparator<V> getValueComparator() {
     return this.valueComparator;
   }
 
@@ -306,11 +306,11 @@
    * @param configuration
    *          The configuration object that will given to the mapper and/or
    *          reducer associated with the driver. This method should only be
-   *          called directly after the constructor as the internal state
-   *          of the driver depends on the configuration object
-   * @deprecated
-   *          Use getConfiguration() to set configuration items as opposed to
-   *          overriding the entire configuration object as it's used internally.
+   *          called directly after the constructor as the internal state of the
+   *          driver depends on the configuration object
+   * @deprecated Use getConfiguration() to set configuration items as opposed to
+   *             overriding the entire configuration object as it's used
+   *             internally.
    */
   @Deprecated
   public void setConfiguration(final Configuration configuration) {
@@ -323,14 +323,13 @@
    *          with the driver. This method should only be called directly after
    *          the constructor as the internal state of the driver depends on the
    *          configuration object
-   * @deprecated
-   *          Use getConfiguration() to set configuration items as opposed to
-   *          overriding the entire configuration object as it's used internally.
+   * @deprecated Use getConfiguration() to set configuration items as opposed to
+   *             overriding the entire configuration object as it's used
+   *             internally.
    * @return this object for fluent coding
    */
   @Deprecated
-  public T withConfiguration(
-      final Configuration configuration) {
+  public T withConfiguration(final Configuration configuration) {
     setConfiguration(configuration);
     return thisAsTestDriver();
   }
@@ -339,7 +338,7 @@
    * Get the {@link Configuration} to use when copying output for use with run*
    * methods or for the InputFormat when reading output back in when setting a
    * real OutputFormat.
-   *
+   * 
    * @return outputSerializationConfiguration, null when no
    *         outputSerializationConfiguration is set
    */
@@ -353,7 +352,7 @@
    * real OutputFormat. When this configuration is not set, MRUnit will use the
    * configuration set with {@link #withConfiguration(Configuration)} or
    * {@link #setConfiguration(Configuration)}
-   *
+   * 
    * @param configuration
    */
   public void setOutputSerializationConfiguration(
@@ -367,22 +366,21 @@
    * real OutputFormat. When this configuration is not set, MRUnit will use the
    * configuration set with {@link #withConfiguration(Configuration)} or
    * {@link #setConfiguration(Configuration)}
-   *
+   * 
    * @param configuration
    * @return this for fluent style
    */
-  public T withOutputSerializationConfiguration(
-      Configuration configuration) {
+  public T withOutputSerializationConfiguration(Configuration configuration) {
     setOutputSerializationConfiguration(configuration);
     return thisAsTestDriver();
   }
 
   /**
-   * Adds a file to be put on the distributed cache.
-   * The path may be relative and will try to be resolved from
-   * the classpath of the test.
-   *
-   * @param path path to the file
+   * Adds a file to be put on the distributed cache. The path may be relative
+   * and will try to be resolved from the classpath of the test.
+   * 
+   * @param path
+   *          path to the file
    */
   public void addCacheFile(String path) {
     addCacheFile(DistCacheUtils.findResource(path));
@@ -390,7 +388,9 @@
 
   /**
    * Adds a file to be put on the distributed cache.
-   * @param uri uri of the file
+   * 
+   * @param uri
+   *          uri of the file
    */
   public void addCacheFile(URI uri) {
     DistributedCache.addCacheFile(uri, getConfiguration());
@@ -398,7 +398,9 @@
 
   /**
    * Set the list of files to put on the distributed cache
-   * @param files list of URIs
+   * 
+   * @param files
+   *          list of URIs
    */
   public void setCacheFiles(URI[] files) {
     DistributedCache.setCacheFiles(files, getConfiguration());
@@ -406,26 +408,30 @@
 
   /**
    * Set the output key comparator
-   * @param keyComparator the key comparator
+   * 
+   * @param keyComparator
+   *          the key comparator
    */
-  public void setKeyComparator(Comparator<K2> keyComparator) {
+  public void setKeyComparator(Comparator<K> keyComparator) {
     this.keyComparator = keyComparator;
   }
 
   /**
    * Set the output value comparator
-   * @param valueComparator the value comparator
+   * 
+   * @param valueComparator
+   *          the value comparator
    */
-  public void setValueComparator(Comparator<V2> valueComparator) {
+  public void setValueComparator(Comparator<V> valueComparator) {
     this.valueComparator = valueComparator;
   }
 
   /**
-   * Adds an archive to be put on the distributed cache.
-   * The path may be relative and will try to be resolved from
-   * the classpath of the test.
-   *
-   * @param path path to the archive
+   * Adds an archive to be put on the distributed cache. The path may be
+   * relative and will try to be resolved from the classpath of the test.
+   * 
+   * @param path
+   *          path to the archive
    */
   public void addCacheArchive(String path) {
     addCacheArchive(DistCacheUtils.findResource(path));
@@ -433,7 +439,9 @@
 
   /**
    * Adds an archive to be put on the distributed cache.
-   * @param uri uri of the archive
+   * 
+   * @param uri
+   *          uri of the archive
    */
   public void addCacheArchive(URI uri) {
     DistributedCache.addCacheArchive(uri, getConfiguration());
@@ -441,18 +449,20 @@
 
   /**
    * Set the list of archives to put on the distributed cache
-   * @param archives list of URIs
+   * 
+   * @param archives
+   *          list of URIs
    */
   public void setCacheArchives(URI[] archives) {
     DistributedCache.setCacheArchives(archives, getConfiguration());
   }
 
   /**
-   * Adds a file to be put on the distributed cache.
-   * The path may be relative and will try to be resolved from
-   * the classpath of the test.
-   *
-   * @param file path to the file
+   * Adds a file to be put on the distributed cache. The path may be relative
+   * and will try to be resolved from the classpath of the test.
+   * 
+   * @param file
+   *          path to the file
    * @return the driver
    */
   public T withCacheFile(String file) {
@@ -462,7 +472,9 @@
 
   /**
    * Adds a file to be put on the distributed cache.
-   * @param file uri of the file
+   * 
+   * @param file
+   *          uri of the file
    * @return the driver
    */
   public T withCacheFile(URI file) {
@@ -471,11 +483,11 @@
   }
 
   /**
-   * Adds an archive to be put on the distributed cache.
-   * The path may be relative and will try to be resolved from
-   * the classpath of the test.
-   *
-   * @param archive path to the archive
+   * Adds an archive to be put on the distributed cache. The path may be
+   * relative and will try to be resolved from the classpath of the test.
+   * 
+   * @param archive
+   *          path to the archive
    * @return the driver
    */
   public T withCacheArchive(String archive) {
@@ -485,7 +497,9 @@
 
   /**
    * Adds an archive to be put on the distributed cache.
-   * @param archive uri of the archive
+   * 
+   * @param archive
+   *          uri of the archive
    * @return the driver
    */
   public T withCacheArchive(URI archive) {
@@ -496,14 +510,15 @@
   /**
    * Runs the test but returns the result set instead of validating it (ignores
    * any addOutput(), etc calls made before this).
-   *
+   * 
    * Also optionally performs counter validation.
-   *
-   * @param validateCounters whether to run automatic counter validation
+   * 
+   * @param validateCounters
+   *          whether to run automatic counter validation
    * @return the list of (k, v) pairs returned as output from the test
    */
-  public List<Pair<K2, V2>> run(boolean validateCounters) throws IOException {
-    final List<Pair<K2, V2>> outputs = run();
+  public List<Pair<K, V>> run(boolean validateCounters) throws IOException {
+    final List<Pair<K, V>> outputs = run();
     if (validateCounters) {
       validate(counterWrapper);
     }
@@ -511,20 +526,20 @@
   }
 
   private Serialization getSerialization() {
-    if(serialization == null) {
+    if (serialization == null) {
       serialization = new Serialization(getConfiguration());
     }
     return serialization;
   }
 
   /**
-   * Initialises the test distributed cache if required. This
-   * process is referred to as "localizing" by Hadoop, but since
-   * this is a unit test all files/archives are already local.
-   *
-   * Cached files are not moved but cached archives are extracted
-   * into a temporary directory.
-   *
+   * Initialises the test distributed cache if required. This process is
+   * referred to as "localizing" by Hadoop, but since this is a unit test all
+   * files/archives are already local.
+   * 
+   * Cached files are not moved but cached archives are extracted into a
+   * temporary directory.
+   * 
    * @throws IOException
    */
   protected void initDistributedCache() throws IOException {
@@ -539,7 +554,7 @@
     List<Path> localFiles = new ArrayList<Path>();
 
     if (DistributedCache.getCacheFiles(conf) != null) {
-      for (URI uri: DistributedCache.getCacheFiles(conf)) {
+      for (URI uri : DistributedCache.getCacheFiles(conf)) {
         Path filePath = new Path(uri.getPath());
         localFiles.add(filePath);
       }
@@ -549,13 +564,13 @@
       }
     }
     if (DistributedCache.getCacheArchives(conf) != null) {
-      for (URI uri: DistributedCache.getCacheArchives(conf)) {
+      for (URI uri : DistributedCache.getCacheArchives(conf)) {
         Path archivePath = new Path(uri.getPath());
         if (tmpDistCacheDir == null) {
           tmpDistCacheDir = DistCacheUtils.createTempDirectory();
         }
-        localArchives.add(DistCacheUtils.extractArchiveToTemp(
-            archivePath, tmpDistCacheDir));
+        localArchives.add(DistCacheUtils.extractArchiveToTemp(archivePath,
+            tmpDistCacheDir));
       }
       if (!localArchives.isEmpty()) {
         DistCacheUtils.addLocalArchives(conf,
@@ -565,28 +580,28 @@
   }
 
   /**
-   * Checks whether the distributed cache has been "localized", i.e.
-   * archives extracted and paths moved so that they can be accessed
-   * through {@link DistributedCache#getLocalCacheArchives()} and
+   * Checks whether the distributed cache has been "localized", i.e. archives
+   * extracted and paths moved so that they can be accessed through
+   * {@link DistributedCache#getLocalCacheArchives()} and
    * {@link DistributedCache#getLocalCacheFiles()}
-   *
-   * @param conf the configuration
+   * 
+   * @param conf
+   *          the configuration
    * @return true if the cache is initialised
    * @throws IOException
    */
   private boolean isDistributedCacheInitialised(Configuration conf)
       throws IOException {
-    return DistributedCache.getLocalCacheArchives(conf) != null ||
-        DistributedCache.getLocalCacheFiles(conf) != null;
+    return DistributedCache.getLocalCacheArchives(conf) != null
+        || DistributedCache.getLocalCacheFiles(conf) != null;
   }
 
   /**
-   * Cleans up the distributed cache test by deleting the
-   * temporary directory and any extracted cache archives
-   * contained within
-   *
+   * Cleans up the distributed cache test by deleting the temporary directory
+   * and any extracted cache archives contained within
+   * 
    * @throws IOException
-   *  if the local fs handle cannot be retrieved
+   *           if the local fs handle cannot be retrieved
    */
   protected void cleanupDistributedCache() throws IOException {
     if (tmpDistCacheDir != null) {
@@ -600,10 +615,10 @@
   /**
    * Runs the test but returns the result set instead of validating it (ignores
    * any addOutput(), etc calls made before this)
-   *
+   * 
    * @return the list of (k, v) pairs returned as output from the test
    */
-  public abstract List<Pair<K2, V2>> run() throws IOException;
+  public abstract List<Pair<K, V>> run() throws IOException;
 
   /**
    * Runs the test and validates the results
@@ -614,7 +629,7 @@
 
   /**
    * Runs the test and validates the results
-   *
+   * 
    * @param orderMatters
    *          Whether or not output ordering is important
    */
@@ -622,7 +637,7 @@
     if (LOG.isDebugEnabled()) {
       printPreTestDebugLog();
     }
-    final List<Pair<K2, V2>> outputs = run();
+    final List<Pair<K, V>> outputs = run();
     validate(outputs, orderMatters);
     validate(counterWrapper);
     validate(mos);
@@ -637,7 +652,7 @@
 
   /**
    * Split "key \t val" into Pair(Text(key), Text(val))
-   *
+   * 
    * @param tabSeparatedPair
    * @return (k,v)
    */
@@ -647,7 +662,7 @@
 
   /**
    * Split "val,val,val,val..." into a List of Text(val) objects.
-   *
+   * 
    * @param commaDelimList
    *          A list of values separated by commas
    */
@@ -666,52 +681,55 @@
 
   /**
    * check the outputs against the expected inputs in record
-   *
+   * 
    * @param outputs
    *          The actual output (k, v) pairs
    * @param orderMatters
    *          Whether or not output ordering is important when validating test
    *          result
    */
-  protected void validate(final List<Pair<K2, V2>> outputs,
+  protected void validate(final List<Pair<K, V>> outputs,
       final boolean orderMatters) {
     // expected nothing and got nothing, everything is fine
     if (outputs.isEmpty() && expectedOutputs.isEmpty()) {
-        return;
+      return;
     }
 
     final Errors errors = new Errors(LOG);
     // expected nothing but got something
     if (!outputs.isEmpty() && expectedOutputs.isEmpty()) {
-        errors.record("Expected no output; got %d output(s).", outputs.size());
-        errors.assertNone();
+      errors.record("Expected no output; got %d output(s).", outputs.size());
+      errors.assertNone();
     }
     // expected something but got nothing
     if (outputs.isEmpty() && !expectedOutputs.isEmpty()) {
-        errors.record("Expected %d output(s); got no output.", expectedOutputs.size());
-        errors.assertNone();
+      errors.record("Expected %d output(s); got no output.",
+          expectedOutputs.size());
+      errors.assertNone();
     }
 
     // now, the smart test needs to be done
-    // check that user's key and value writables implement equals, hashCode, toString
+    // check that user's key and value writables implement equals, hashCode,
+    // toString
     checkOverrides(outputs, expectedOutputs);
 
-    final PairEquality<K2, V2> equality = new PairEquality<K2, V2>(
-            keyComparator, valueComparator);
+    final PairEquality<K, V> equality = new PairEquality<K, V>(keyComparator,
+        valueComparator);
     if (orderMatters) {
-        validateWithOrder(outputs, errors, equality);
+      validateWithOrder(outputs, errors, equality);
     } else {
-        validateWithoutOrder(outputs, errors, equality);
+      validateWithoutOrder(outputs, errors, equality);
     }
 
-    // if there are errors, it might be due to types and not clear from the message
-    if(!errors.isEmpty()) {
+    // if there are errors, it might be due to types and not clear from the
+    // message
+    if (!errors.isEmpty()) {
       Class<?> outputKeyClass = null;
       Class<?> outputValueClass = null;
       Class<?> expectedKeyClass = null;
       Class<?> expectedValueClass = null;
 
-      for (Pair<K2, V2> output : outputs) {
+      for (Pair<K, V> output : outputs) {
         if (output.getFirst() != null) {
           outputKeyClass = output.getFirst().getClass();
         }
@@ -723,7 +741,7 @@
         }
       }
 
-      for (Pair<K2, V2> expected : expectedOutputs) {
+      for (Pair<K, V> expected : expectedOutputs) {
         if (expected.getFirst() != null) {
           expectedKeyClass = expected.getFirst().getClass();
         }
@@ -735,13 +753,13 @@
         }
       }
 
-      if (outputKeyClass != null && expectedKeyClass !=null
+      if (outputKeyClass != null && expectedKeyClass != null
           && !outputKeyClass.equals(expectedKeyClass)) {
         errors.record("Mismatch in key class: expected: %s actual: %s",
             expectedKeyClass, outputKeyClass);
       }
 
-      if (outputValueClass != null && expectedValueClass !=null
+      if (outputValueClass != null && expectedValueClass != null
           && !outputValueClass.equals(expectedValueClass)) {
         errors.record("Mismatch in value class: expected: %s actual: %s",
             expectedValueClass, outputValueClass);
@@ -750,85 +768,86 @@
     errors.assertNone();
   }
 
-  private void validateWithoutOrder(final List<Pair<K2, V2>> outputs,
-      final Errors errors, final PairEquality<K2, V2> equality) {
+  private void validateWithoutOrder(final List<Pair<K, V>> outputs,
+      final Errors errors, final PairEquality<K, V> equality) {
     Set<Integer> verifiedExpecteds = new HashSet<Integer>();
     Set<Integer> unverifiedOutputs = new HashSet<Integer>();
     for (int i = 0; i < outputs.size(); i++) {
-        Pair<K2, V2> output = outputs.get(i);
-        boolean found = false;
-        for (int j = 0; j < expectedOutputs.size(); j++) {
-            if (verifiedExpecteds.contains(j)) {
-                continue;
-            }
-            Pair<K2, V2> expected = expectedOutputs.get(j);
-            if (equality.isTrueFor(output, expected)) {
-                found = true;
-                verifiedExpecteds.add(j);
-                LOG.debug(String.format("Matched expected output %s no %d at "
-                        + "position %d", output, j, i));
-                break;
-            }
+      Pair<K, V> output = outputs.get(i);
+      boolean found = false;
+      for (int j = 0; j < expectedOutputs.size(); j++) {
+        if (verifiedExpecteds.contains(j)) {
+          continue;
         }
-        if (!found) {
-            unverifiedOutputs.add(i);
+        Pair<K, V> expected = expectedOutputs.get(j);
+        if (equality.isTrueFor(output, expected)) {
+          found = true;
+          verifiedExpecteds.add(j);
+          LOG.debug(String.format("Matched expected output %s no %d at "
+              + "position %d", output, j, i));
+          break;
         }
+      }
+      if (!found) {
+        unverifiedOutputs.add(i);
+      }
     }
     for (int j = 0; j < expectedOutputs.size(); j++) {
-        if (!verifiedExpecteds.contains(j)) {
-            errors.record("Missing expected output %s", expectedOutputs.get(j));
-        }
+      if (!verifiedExpecteds.contains(j)) {
+        errors.record("Missing expected output %s", expectedOutputs.get(j));
+      }
     }
     for (int i = 0; i < outputs.size(); i++) {
-        if (unverifiedOutputs.contains(i)) {
-            errors.record("Received unexpected output %s", outputs.get(i));
-        }
+      if (unverifiedOutputs.contains(i)) {
+        errors.record("Received unexpected output %s", outputs.get(i));
+      }
     }
   }
 
-  private void validateWithOrder(final List<Pair<K2, V2>> outputs,
-      final Errors errors, final PairEquality<K2, V2> equality) {
+  private void validateWithOrder(final List<Pair<K, V>> outputs,
+      final Errors errors, final PairEquality<K, V> equality) {
     int i = 0;
-    for (i = 0; i < Math.min(outputs.size(),expectedOutputs.size()); i++) {
-        Pair<K2, V2> output = outputs.get(i);
-        Pair<K2, V2> expected = expectedOutputs.get(i);
-        if (equality.isTrueFor(output, expected)) {
-            LOG.debug(String.format("Matched expected output %s at "
-                    + "position %d", expected, i));
-        } else {
-            errors.record("Missing expected output %s at position %d, got %s.",
-                    expected, i, output);
-        }
+    for (i = 0; i < Math.min(outputs.size(), expectedOutputs.size()); i++) {
+      Pair<K, V> output = outputs.get(i);
+      Pair<K, V> expected = expectedOutputs.get(i);
+      if (equality.isTrueFor(output, expected)) {
+        LOG.debug(String.format("Matched expected output %s at "
+            + "position %d", expected, i));
+      } else {
+        errors.record("Missing expected output %s at position %d, got %s.",
+            expected, i, output);
+      }
     }
-    for(int j=i; j < outputs.size(); j++) {
-        errors.record("Received unexpected output %s at position %d.",
-                outputs.get(j), j);
+    for (int j = i; j < outputs.size(); j++) {
+      errors.record("Received unexpected output %s at position %d.",
+          outputs.get(j), j);
     }
-    for(int j=i; j < expectedOutputs.size(); j++) {
-        errors.record("Missing expected output %s at position %d.",
-                expectedOutputs.get(j), j);
+    for (int j = i; j < expectedOutputs.size(); j++) {
+      errors.record("Missing expected output %s at position %d.",
+          expectedOutputs.get(j), j);
     }
   }
 
-  private void checkOverrides(final List<Pair<K2,V2>> outputPairs, final List<Pair<K2,V2>> expectedOutputPairs) {
+  private void checkOverrides(final List<Pair<K, V>> outputPairs,
+      final List<Pair<K, V>> expectedOutputPairs) {
     Class<?> keyClass = null;
     Class<?> valueClass = null;
     // key or value could be null, try to find a class
-    for (Pair<K2,V2> pair : outputPairs) {
-        if (keyClass == null && pair.getFirst() != null) {
-            keyClass = pair.getFirst().getClass();
-        }
-        if (valueClass == null && pair.getSecond() != null) {
-        	valueClass = pair.getSecond().getClass();
-        }
+    for (Pair<K, V> pair : outputPairs) {
+      if (keyClass == null && pair.getFirst() != null) {
+        keyClass = pair.getFirst().getClass();
+      }
+      if (valueClass == null && pair.getSecond() != null) {
+        valueClass = pair.getSecond().getClass();
+      }
     }
-    for (Pair<K2,V2> pair : expectedOutputPairs) {
-        if (keyClass == null && pair.getFirst() != null) {
-            keyClass = pair.getFirst().getClass();
-        }
-        if (valueClass == null && pair.getSecond() != null) {
-        	valueClass = pair.getSecond().getClass();
-        }
+    for (Pair<K, V> pair : expectedOutputPairs) {
+      if (keyClass == null && pair.getFirst() != null) {
+        keyClass = pair.getFirst().getClass();
+      }
+      if (valueClass == null && pair.getSecond() != null) {
+        valueClass = pair.getSecond().getClass();
+      }
     }
     checkOverride(keyClass);
     checkOverride(valueClass);
@@ -836,20 +855,21 @@
 
   private void checkOverride(final Class<?> clazz) {
     if (clazz == null) {
-        return;
+      return;
     }
     try {
       if (clazz.getMethod("equals", Object.class).getDeclaringClass() != clazz) {
-        LOG.warn(clazz.getCanonicalName() + ".equals(Object) " +
-            "is not being overridden - tests may fail!");
+        LOG.warn(clazz.getCanonicalName() + ".equals(Object) "
+            + "is not being overridden - tests may fail!");
       }
       if (clazz.getMethod("hashCode").getDeclaringClass() != clazz) {
-        LOG.warn(clazz.getCanonicalName() + ".hashCode() " +
-            "is not being overridden - tests may fail!");
+        LOG.warn(clazz.getCanonicalName() + ".hashCode() "
+            + "is not being overridden - tests may fail!");
       }
       if (clazz.getMethod("toString").getDeclaringClass() != clazz) {
-        LOG.warn(clazz.getCanonicalName() + ".toString() " +
-            "is not being overridden - test failures may be difficult to diagnose.");
+        LOG.warn(clazz.getCanonicalName()
+            + ".toString() "
+            + "is not being overridden - test failures may be difficult to diagnose.");
         LOG.warn("Consider executing test using run() to access outputs");
       }
     } catch (SecurityException e) {
@@ -859,12 +879,12 @@
     }
   }
 
-  private Map<Pair<K2, V2>, List<Integer>> buildPositionMap(
-      final List<Pair<K2, V2>> values, Comparator<Pair<K2, V2>> comparator) {
-    final Map<Pair<K2, V2>, List<Integer>> valuePositions =
-        new TreeMap<Pair<K2, V2>, List<Integer>>(comparator);
+  private Map<Pair<K, V>, List<Integer>> buildPositionMap(
+      final List<Pair<K, V>> values, Comparator<Pair<K, V>> comparator) {
+    final Map<Pair<K, V>, List<Integer>> valuePositions = new TreeMap<Pair<K, V>, List<Integer>>(
+        comparator);
     for (int i = 0; i < values.size(); i++) {
-      final Pair<K2, V2> output = values.get(i);
+      final Pair<K, V> output = values.get(i);
       List<Integer> positions;
       if (valuePositions.containsKey(output)) {
         positions = valuePositions.get(output);
@@ -877,7 +897,6 @@
     return valuePositions;
   }
 
-
   /**
    * Check counters.
    */
@@ -885,6 +904,7 @@
     validateExpectedAgainstActual(counterWrapper);
     validateActualAgainstExpected(counterWrapper);
   }
+
   /**
    * Check Multiple Outputs.
    */
@@ -1025,10 +1045,10 @@
    */
   private Collection<Pair<String, String>> findExpectedCounterValues() {
     Collection<Pair<String, String>> results = new ArrayList<Pair<String, String>>();
-    for (Pair<Pair<String, String>,Long> counterAndCount : expectedStringCounters) {
+    for (Pair<Pair<String, String>, Long> counterAndCount : expectedStringCounters) {
       results.add(counterAndCount.getFirst());
     }
-    for (Pair<Enum<?>,Long> counterAndCount : expectedEnumCounters) {
+    for (Pair<Enum<?>, Long> counterAndCount : expectedEnumCounters) {
       Enum<?> first = counterAndCount.getFirst();
       String groupName = first.getDeclaringClass().getName();
       String counterName = first.name();
@@ -1038,13 +1058,12 @@
   }
 
   /**
-   * Check that provided actual counters contain all expected counters with proper
-   * values.
-   *
+   * Check that provided actual counters contain all expected counters with
+   * proper values.
+   * 
    * @param counterWrapper
    */
-  private void validateExpectedAgainstActual(
-      final CounterWrapper counterWrapper) {
+  private void validateExpectedAgainstActual(final CounterWrapper counterWrapper) {
     final Errors errors = new Errors(LOG);
 
     // Firstly check enumeration based counters
@@ -1080,16 +1099,17 @@
 
   /**
    * Check that provided actual counters are all expected.
-   *
+   * 
    * @param counterWrapper
    */
   private void validateActualAgainstExpected(final CounterWrapper counterWrapper) {
     if (strictCountersChecking) {
       final Errors errors = new Errors(LOG);
-      Collection<Pair<String, String>> unmatchedCounters = counterWrapper.findCounterValues();
+      Collection<Pair<String, String>> unmatchedCounters = counterWrapper
+          .findCounterValues();
       Collection<Pair<String, String>> findExpectedCounterValues = findExpectedCounterValues();
       unmatchedCounters.removeAll(findExpectedCounterValues);
-      if(!unmatchedCounters.isEmpty()) {
+      if (!unmatchedCounters.isEmpty()) {
         for (Pair<String, String> unmatcherCounter : unmatchedCounters) {
           errors
               .record(
@@ -1101,34 +1121,37 @@
     }
   }
 
-  protected static void formatValueList(final List<?> values,
+  public static void formatValueList(final List<?> values,
       final StringBuilder sb) {
     StringUtils.formatValueList(values, sb);
   }
 
-  protected static <KEYIN, VALUEIN> void formatPairList(final List<Pair<KEYIN,VALUEIN>> pairs,
-      final StringBuilder sb) {
+  protected static <KEYIN, VALUEIN> void formatPairList(
+      final List<Pair<KEYIN, VALUEIN>> pairs, final StringBuilder sb) {
     StringUtils.formatPairList(pairs, sb);
   }
 
   /**
    * Adds an output (k, v) pair we expect as Multiple output
-   *
+   * 
    * @param namedOutput
    * @param outputRecord
    */
-  public <K, V> void addMultiOutput(String namedOutput, final Pair<K, V> outputRecord) {
-    addMultiOutput(namedOutput, outputRecord.getFirst(), outputRecord.getSecond());
+  public <K, V> void addMultiOutput(String namedOutput,
+      final Pair<K, V> outputRecord) {
+    addMultiOutput(namedOutput, outputRecord.getFirst(),
+        outputRecord.getSecond());
   }
 
   /**
    * add a (k, v) pair we expect as Multiple output
-   *
+   * 
    * @param namedOutput
    * @param key
    * @param val
    */
-  public <K, V> void addMultiOutput(final String namedOutput, final K key, final V val) {
+  public <K, V> void addMultiOutput(final String namedOutput, final K key,
+      final V val) {
     List<Pair<?, ?>> outputs = expectedMultipleOutputs.get(namedOutput);
     if (outputs == null) {
       outputs = new ArrayList<Pair<?, ?>>();
@@ -1139,20 +1162,21 @@
 
   /**
    * works like addMultiOutput() but returns self for fluent programming style
-   *
+   * 
    * @param namedOutput
    * @param key
    * @param value
    * @return this
    */
-  public <K extends Comparable, V extends Comparable> T withMultiOutput(final String namedOutput, final K key, final V value) {
+  public <K extends Comparable, V extends Comparable> T withMultiOutput(
+      final String namedOutput, final K key, final V value) {
     addMultiOutput(namedOutput, key, value);
     return thisAsTestDriver();
   }
 
   /**
    * Works like addMultiOutput(), but returns self for fluent programming style
-   *
+   * 
    * @param namedOutput
    * @param outputRecord
    * @return this
diff --git a/src/main/java/org/apache/hadoop/mrunit/internal/driver/MultipleInputsMapReduceDriverBase.java b/src/main/java/org/apache/hadoop/mrunit/internal/driver/MultipleInputsMapReduceDriverBase.java
new file mode 100644
index 0000000..726aaeb
--- /dev/null
+++ b/src/main/java/org/apache/hadoop/mrunit/internal/driver/MultipleInputsMapReduceDriverBase.java
@@ -0,0 +1,307 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mrunit.internal.driver;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.mrunit.TestDriver;
+import org.apache.hadoop.mrunit.types.Pair;
+import org.apache.hadoop.util.ReflectionUtils;
+
+import java.util.*;
+
+/**
+ * Harness that allows you to test multiple Mappers and a Reducer instance
+ * together. You provide the input keys and values that should be sent to each
+ * Mapper, and outputs you expect to be sent by the Reducer to the collector for
+ * those inputs. By calling runTest(), the harness will deliver the inputs to
+ * the respective Mappers, feed the intermediate results to the Reducer (without
+ * checking them), and will check the Reducer's outputs against the expected
+ * results.
+ * 
+ * @param <M>
+ *          The type of the Mapper (to support mapred and mapreduce API)
+ * @param <K1>
+ *          The common map output key type
+ * @param <V1>
+ *          The common map output value type
+ * @param <K2>
+ *          The reduce output key type
+ * @param <V2>
+ *          The reduce output value type
+ * @param <T>
+ *          The type of the MultipleInputsMapReduceDriver implementation
+ */
+public abstract class MultipleInputsMapReduceDriverBase<M, K1, V1, K2, V2, T extends MultipleInputsMapReduceDriverBase<M, K1, V1, K2, V2, T>>
+    extends TestDriver<K2, V2, T> {
+  public static final Log LOG = LogFactory
+      .getLog(MultipleInputsMapReduceDriverBase.class);
+
+  protected Map<M, Path> mapInputPaths = new HashMap<M, Path>();
+
+  /**
+   * The path passed to the specified mapper InputSplit
+   * 
+   * @param mapper
+   *          The mapper to get the input path for
+   * @return The path
+   */
+  public Path getMapInputPath(final M mapper) {
+    return mapInputPaths.get(mapper);
+  }
+
+  /**
+   * Path that is passed to the mapper InputSplit
+   * 
+   * @param mapper
+   *          The mapper to set the input path for
+   * @param mapInputPath
+   *          The path
+   */
+  public void setMapInputPath(final M mapper, Path mapInputPath) {
+    mapInputPaths.put(mapper, mapInputPath);
+  }
+
+  /**
+   * Identical to setMapInputPath but supports a fluent programming style
+   * 
+   * @param mapper
+   *          The mapper to set the input path for
+   * @param mapInputPath
+   *          The path
+   * @return this
+   */
+  public final T withMapInputPath(final M mapper, Path mapInputPath) {
+    this.setMapInputPath(mapper, mapInputPath);
+    return thisAsMapReduceDriver();
+  }
+
+  /**
+   * Key group comparator
+   */
+  protected Comparator<K1> keyGroupComparator;
+
+  /**
+   * Set the key grouping comparator, similar to calling the following API calls
+   * but passing a real instance rather than just the class:
+   * <UL>
+   * <LI>pre 0.20.1 API:
+   * {@link org.apache.hadoop.mapred.JobConf#setOutputValueGroupingComparator(Class)}
+   * <LI>0.20.1+ API:
+   * {@link org.apache.hadoop.mapreduce.Job#setGroupingComparatorClass(Class)}
+   * </UL>
+   * 
+   * @param groupingComparator
+   */
+  @SuppressWarnings("unchecked")
+  public void setKeyGroupingComparator(
+      final RawComparator<K2> groupingComparator) {
+    keyGroupComparator = ReflectionUtils.newInstance(
+        groupingComparator.getClass(), getConfiguration());
+  }
+
+  /**
+   * Identical to {@link #setKeyGroupingComparator(RawComparator)}, but with a
+   * fluent programming style
+   * 
+   * @param groupingComparator
+   *          Comparator to use in the shuffle stage for key grouping
+   * @return this
+   */
+  public T withKeyGroupingComparator(final RawComparator<K2> groupingComparator) {
+    setKeyGroupingComparator(groupingComparator);
+    return thisAsMapReduceDriver();
+  }
+
+  /**
+   * Key value order comparator
+   */
+  protected Comparator<K1> keyValueOrderComparator;
+
+  /**
+   * Set the key value order comparator, similar to calling the following API
+   * calls but passing a real instance rather than just the class:
+   * <UL>
+   * <LI>pre 0.20.1 API:
+   * {@link org.apache.hadoop.mapred.JobConf#setOutputKeyComparatorClass(Class)}
+   * <LI>0.20.1+ API:
+   * {@link org.apache.hadoop.mapreduce.Job#setSortComparatorClass(Class)}
+   * </UL>
+   * 
+   * @param orderComparator
+   */
+  @SuppressWarnings("unchecked")
+  public void setKeyOrderComparator(final RawComparator<K2> orderComparator) {
+    keyValueOrderComparator = ReflectionUtils.newInstance(
+        orderComparator.getClass(), getConfiguration());
+  }
+
+  /**
+   * Identical to {@link #setKeyOrderComparator(RawComparator)}, but with a
+   * fluent programming style
+   * 
+   * @param orderComparator
+   *          Comparator to use in the shuffle stage for key value ordering
+   * @return this
+   */
+  public T withKeyOrderComparator(final RawComparator<K2> orderComparator) {
+    setKeyOrderComparator(orderComparator);
+    return thisAsMapReduceDriver();
+  }
+
+  @SuppressWarnings("rawtypes")
+  protected Map<M, List<Pair>> inputs = new HashMap<M, List<Pair>>();
+
+  /**
+   * Add an input to send to the specified mapper
+   * 
+   * @param mapper
+   *          The mapper to add the input to
+   * @param key
+   *          The key to add
+   * @param val
+   *          The value to add
+   * @param <K>
+   *          The key type
+   * @param <V>
+   *          The value type
+   */
+  protected <K, V> void addInput(final M mapper, final K key, final V val) {
+    if (!inputs.containsKey(mapper)) {
+      inputs.put(mapper, new ArrayList<Pair>());
+    }
+    inputs.get(mapper).add(copyPair(key, val));
+  }
+
+  /**
+   * Add an input to the specified mapper
+   * 
+   * @param mapper
+   *          The mapper to add the input to
+   * @param input
+   *          The (k, v) pair
+   * @param <K>
+   *          The key type
+   * @param <V>
+   *          The value type
+   */
+  protected <K, V> void addInput(final M mapper, final Pair<K, V> input) {
+    addInput(mapper, input.getFirst(), input.getSecond());
+  }
+
+  /**
+   * Add inputs to the specified mapper
+   * 
+   * @param mapper
+   *          The mapper to add the input to
+   * @param inputs
+   *          The (k, v) pairs
+   * @param <K>
+   *          The key type
+   * @param <V>
+   *          The value type
+   */
+  protected <K, V> void addAll(final M mapper, final List<Pair<K, V>> inputs) {
+    for (Pair<K, V> input : inputs) {
+      addInput(mapper, input);
+    }
+  }
+
+  /**
+   * Identical to addInput but supports a fluent programming style
+   * 
+   * @param mapper
+   *          The mapper to add the input to
+   * @param key
+   *          The key to add
+   * @param val
+   *          The value to add
+   * @param <K>
+   *          The key type
+   * @param <V>
+   *          The value type
+   * @return this
+   */
+  protected <K, V> T withInput(final M mapper, final K key, final V val) {
+    addInput(mapper, key, val);
+    return thisAsMapReduceDriver();
+  }
+
+  /**
+   * Identical to addInput but supports a fluent programming style
+   * 
+   * @param mapper
+   *          The mapper to add the input to
+   * @param input
+   *          The (k, v) pair to add
+   * @param <K>
+   *          The key type
+   * @param <V>
+   *          The value type
+   * @return this
+   */
+  protected <K, V> T withInput(final M mapper, final Pair<K, V> input) {
+    addInput(mapper, input);
+    return thisAsMapReduceDriver();
+  }
+
+  /**
+   * Identical to addAll but supports a fluent programming style
+   * 
+   * @param mapper
+   *          The mapper to add the input to
+   * @param inputs
+   *          The (k, v) pairs to add
+   * @param <K>
+   *          The key type
+   * @param <V>
+   *          The value type
+   * @return this
+   */
+  protected <K, V> T withAll(final M mapper, final List<Pair<K, V>> inputs) {
+    addAll(mapper, inputs);
+    return thisAsMapReduceDriver();
+  }
+
+  protected void preRunChecks(Set<M> mappers, Object reducer) {
+    for (M mapper : mappers) {
+      if (inputs.get(mapper) == null || inputs.get(mapper).isEmpty()) {
+        throw new IllegalStateException(String.format(
+            "No input was provided for mapper %s", mapper));
+      }
+    }
+
+    if (reducer == null) {
+      throw new IllegalStateException("No reducer class was provided");
+    }
+    if (driverReused()) {
+      throw new IllegalStateException("Driver reuse not allowed");
+    } else {
+      setUsedOnceStatus();
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  private T thisAsMapReduceDriver() {
+    return (T) this;
+  }
+
+}
diff --git a/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapReduceDriver.java b/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapReduceDriver.java
index 4bf05fd..6aa8b5a 100644
--- a/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapReduceDriver.java
+++ b/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapReduceDriver.java
@@ -17,25 +17,21 @@
  */
 package org.apache.hadoop.mrunit.mapreduce;
 
-import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapreduce.*;
+import org.apache.hadoop.mrunit.MapReduceDriverBase;
+import org.apache.hadoop.mrunit.internal.counters.CounterWrapper;
+import org.apache.hadoop.mrunit.types.KeyValueReuseList;
+import org.apache.hadoop.mrunit.types.Pair;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapreduce.Counters;
-import org.apache.hadoop.mapreduce.InputFormat;
-import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.hadoop.mapreduce.OutputFormat;
-import org.apache.hadoop.mapreduce.Reducer;
-import org.apache.hadoop.mrunit.MapReduceDriverBase;
-import org.apache.hadoop.mrunit.internal.counters.CounterWrapper;
-import org.apache.hadoop.mrunit.types.KeyValueReuseList;
-import org.apache.hadoop.mrunit.types.Pair;
+import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
 
 /**
  * Harness that allows you to test a Mapper and a Reducer instance together You
@@ -48,7 +44,8 @@
  * pair, representing a single unit test.
  */
 
-public class MapReduceDriver<K1, V1, K2, V2, K3, V3> extends
+public class MapReduceDriver<K1, V1, K2, V2, K3, V3>
+    extends
     MapReduceDriverBase<K1, V1, K2, V2, K3, V3, MapReduceDriver<K1, V1, K2, V2, K3, V3>> {
 
   public static final Log LOG = LogFactory.getLog(MapReduceDriver.class);
@@ -82,7 +79,7 @@
 
   /**
    * Set the Mapper instance to use with this test driver
-   *
+   * 
    * @param m
    *          the Mapper instance to use
    */
@@ -106,7 +103,7 @@
 
   /**
    * Sets the reducer object to use for this test
-   *
+   * 
    * @param r
    *          The reducer object to use
    */
@@ -116,7 +113,7 @@
 
   /**
    * Identical to setReducer(), but with fluent programming style
-   *
+   * 
    * @param r
    *          The Reducer to use
    * @return this
@@ -136,7 +133,7 @@
 
   /**
    * Sets the reducer object to use as a combiner for this test
-   *
+   * 
    * @param c
    *          The combiner object to use
    */
@@ -146,7 +143,7 @@
 
   /**
    * Identical to setCombiner(), but with fluent programming style
-   *
+   * 
    * @param c
    *          The Combiner to use
    * @return this
@@ -171,7 +168,7 @@
 
   /**
    * Sets the counters object to use for this test.
-   *
+   * 
    * @param ctrs
    *          The counters object to use.
    */
@@ -190,7 +187,7 @@
   /**
    * Configure {@link Reducer} to output with a real {@link OutputFormat}. Set
    * {@link InputFormat} to read output back in for use with run* methods
-   *
+   * 
    * @param outputFormatClass
    * @param inputFormatClass
    * @return this for fluent style
@@ -204,64 +201,26 @@
     return this;
   }
 
-  /**
-   * The private class to manage starting the reduce phase is used for type
-   * genericity reasons. This class is used in the run() method.
-   */
-  private class ReducePhaseRunner<OUTKEY, OUTVAL> {
-    private List<Pair<OUTKEY, OUTVAL>> runReduce(
-        final List<KeyValueReuseList<K2, V2>> inputs,
-        final Reducer<K2, V2, OUTKEY, OUTVAL> reducer) throws IOException {
-
-      final List<Pair<OUTKEY, OUTVAL>> reduceOutputs = new ArrayList<Pair<OUTKEY, OUTVAL>>();
-
-      if (!inputs.isEmpty()) {
-        if (LOG.isDebugEnabled()) {
-          final StringBuilder sb = new StringBuilder();
-          for (List<Pair<K2, V2>> input : inputs) {
-            formatPairList(input, sb);
-            LOG.debug("Reducing input " + sb);
-            sb.delete(0, sb.length());
-          }
-        }
-
-        final ReduceDriver<K2, V2, OUTKEY, OUTVAL> reduceDriver = ReduceDriver
-            .newReduceDriver(reducer).withCounters(getCounters())
-            .withConfiguration(getConfiguration()).withAllElements(inputs);
-
-        if (getOutputSerializationConfiguration() != null) {
-          reduceDriver
-              .withOutputSerializationConfiguration(getOutputSerializationConfiguration());
-        }
-
-        if (outputFormatClass != null) {
-          reduceDriver.withOutputFormat(outputFormatClass, inputFormatClass);
-        }
-
-        reduceOutputs.addAll(reduceDriver.run());
-      }
-
-      return reduceOutputs;
-    }
-  }
-
-  protected List<KeyValueReuseList<K2,V2>> sortAndGroup(final List<Pair<K2, V2>> mapOutputs){
-    if(mapOutputs.isEmpty()) {
+  protected List<KeyValueReuseList<K2, V2>> sortAndGroup(
+      final List<Pair<K2, V2>> mapOutputs) {
+    if (mapOutputs.isEmpty()) {
       return Collections.emptyList();
     }
 
-    if (keyValueOrderComparator == null || keyGroupComparator == null){
+    if (keyValueOrderComparator == null || keyGroupComparator == null) {
       JobConf conf = new JobConf(getConfiguration());
       conf.setMapOutputKeyClass(mapOutputs.get(0).getFirst().getClass());
-      if (keyGroupComparator == null){
+      if (keyGroupComparator == null) {
         keyGroupComparator = conf.getOutputValueGroupingComparator();
       }
       if (keyValueOrderComparator == null) {
         keyValueOrderComparator = conf.getOutputKeyComparator();
       }
     }
-    ReduceFeeder<K2,V2> reduceFeeder = new ReduceFeeder<K2,V2>(getConfiguration());
-    return reduceFeeder.sortAndGroup(mapOutputs, keyValueOrderComparator, keyGroupComparator);
+    ReduceFeeder<K2, V2> reduceFeeder = new ReduceFeeder<K2, V2>(
+        getConfiguration());
+    return reduceFeeder.sortAndGroup(mapOutputs, keyValueOrderComparator,
+        keyGroupComparator);
   }
 
   @Override
@@ -276,16 +235,20 @@
           .withCounters(getCounters()).withConfiguration(getConfiguration())
           .withAll(inputList).withMapInputPath(getMapInputPath()).run());
       if (myCombiner != null) {
-        // User has specified a combiner. Run this and replace the mapper outputs
+        // User has specified a combiner. Run this and replace the mapper
+        // outputs
         // with the result of the combiner.
         LOG.debug("Starting combine phase with combiner: " + myCombiner);
-        mapOutputs = new ReducePhaseRunner<K2, V2>().runReduce(
-            sortAndGroup(mapOutputs), myCombiner);
+        mapOutputs = new ReducePhaseRunner<K2, V2, K2, V2>(inputFormatClass,
+            getConfiguration(), counters,
+            getOutputSerializationConfiguration(), outputFormatClass)
+            .runReduce(sortAndGroup(mapOutputs), myCombiner);
       }
       // Run the reduce phase.
       LOG.debug("Starting reduce phase with reducer: " + myReducer);
-      return new ReducePhaseRunner<K3, V3>().runReduce(sortAndGroup(mapOutputs),
-          myReducer);
+      return new ReducePhaseRunner<K2, V2, K3, V3>(inputFormatClass,
+          getConfiguration(), counters, getOutputSerializationConfiguration(),
+          outputFormatClass).runReduce(sortAndGroup(mapOutputs), myReducer);
     } finally {
       cleanupDistributedCache();
     }
@@ -299,7 +262,7 @@
   /**
    * Returns a new MapReduceDriver without having to specify the generic types
    * on the right hand side of the object create statement.
-   *
+   * 
    * @return new MapReduceDriver
    */
   public static <K1, V1, K2, V2, K3, V3> MapReduceDriver<K1, V1, K2, V2, K3, V3> newMapReduceDriver() {
@@ -309,7 +272,7 @@
   /**
    * Returns a new MapReduceDriver without having to specify the generic types
    * on the right hand side of the object create statement.
-   *
+   * 
    * @param mapper
    *          passed to MapReduceDriver constructor
    * @param reducer
@@ -324,7 +287,7 @@
   /**
    * Returns a new MapReduceDriver without having to specify the generic types
    * on the right hand side of the object create statement.
-   *
+   * 
    * @param mapper
    *          passed to MapReduceDriver constructor
    * @param reducer
diff --git a/src/main/java/org/apache/hadoop/mrunit/mapreduce/MultipleInputsMapReduceDriver.java b/src/main/java/org/apache/hadoop/mrunit/mapreduce/MultipleInputsMapReduceDriver.java
new file mode 100644
index 0000000..a77ad08
--- /dev/null
+++ b/src/main/java/org/apache/hadoop/mrunit/mapreduce/MultipleInputsMapReduceDriver.java
@@ -0,0 +1,503 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mrunit.mapreduce;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapreduce.*;
+import org.apache.hadoop.mrunit.internal.counters.CounterWrapper;
+import org.apache.hadoop.mrunit.internal.driver.MultipleInputsMapReduceDriverBase;
+import org.apache.hadoop.mrunit.types.KeyValueReuseList;
+import org.apache.hadoop.mrunit.types.Pair;
+
+import java.io.IOException;
+import java.util.*;
+
+import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
+
+/**
+ * Harness that allows you to test multiple Mappers and a Reducer instance
+ * together (along with an optional combiner). You provide the input keys and
+ * values that should be sent to each Mapper, and outputs you expect to be sent
+ * by the Reducer to the collector for those inputs. By calling runTest(), the
+ * harness will deliver the inputs to the respective Mappers, feed the
+ * intermediate results to the Reducer (without checking them), and will check
+ * the Reducer's outputs against the expected results.
+ * 
+ * If a combiner is specified, it will run exactly once after all the Mappers
+ * and before the Reducer.
+ * 
+ * @param <K1>
+ *          The common map output key type
+ * @param <V1>
+ *          The common map output value type
+ * @param <K2>
+ *          The reduce output key type
+ * @param <V2>
+ *          The reduce output value type
+ */
+public class MultipleInputsMapReduceDriver<K1, V1, K2, V2>
+    extends
+        MultipleInputsMapReduceDriverBase<Mapper, K1, V1, K2, V2, MultipleInputsMapReduceDriver<K1, V1, K2, V2>> {
+  public static final Log LOG = LogFactory
+      .getLog(MultipleInputsMapReduceDriver.class);
+
+  private Set<Mapper> mappers = new HashSet<Mapper>();
+
+  /**
+   * Add a mapper to use with this test driver
+   * 
+   * @param mapper
+   *          The mapper instance to add
+   * @param <K>
+   *          The input key type to the mapper
+   * @param <V>
+   *          The input value type to the mapper
+   */
+  public <K, V> void addMapper(final Mapper<K, V, K1, V1> mapper) {
+    this.mappers.add(returnNonNull(mapper));
+  }
+
+  /**
+   * Identical to addMapper but supports a fluent programming style
+   * 
+   * @param mapper
+   *          The mapper instance to add
+   * @param <K>
+   *          The input key type to the mapper
+   * @param <V>
+   *          The input value type to the mapper
+   * @return this
+   */
+  public <K, V> MultipleInputsMapReduceDriver<K1, V1, K2, V2> withMapper(
+      final Mapper<K, V, K1, V1> mapper) {
+    addMapper(mapper);
+    return this;
+  }
+
+  /**
+   * @return The Mapper instances being used by this test
+   */
+  public Collection<Mapper> getMappers() {
+    return mappers;
+  }
+
+  private Reducer<K1, V1, K1, V1> combiner;
+
+  /**
+   * Set the combiner to use with this test driver
+   * 
+   * @param combiner
+   *          The combiner instance to use
+   */
+  public void setCombiner(final Reducer<K1, V1, K1, V1> combiner) {
+    this.combiner = returnNonNull(combiner);
+  }
+
+  /**
+   * Identical to setCombiner but supports a fluent programming style
+   * 
+   * @param combiner
+   *          The combiner instance to use
+   * @return this
+   */
+  public MultipleInputsMapReduceDriver<K1, V1, K2, V2> withCombiner(
+      final Reducer<K1, V1, K1, V1> combiner) {
+    setCombiner(combiner);
+    return this;
+  }
+
+  /**
+   * @return The combiner instance being used by this test
+   */
+  public Reducer<K1, V1, K1, V1> getCombiner() {
+    return combiner;
+  }
+
+  private Reducer<K1, V1, K2, V2> reducer;
+
+  /**
+   * Set the reducer to use with this test driver
+   * 
+   * @param reducer
+   *          The reducer instance to use
+   */
+  public void setReducer(final Reducer<K1, V1, K2, V2> reducer) {
+    this.reducer = returnNonNull(reducer);
+  }
+
+  /**
+   * Identical to setReducer but supports a fluent programming style
+   * 
+   * @param reducer
+   *          The reducer instance to use
+   * @return this
+   */
+  public MultipleInputsMapReduceDriver<K1, V1, K2, V2> withReducer(
+      final Reducer<K1, V1, K2, V2> reducer) {
+    setReducer(reducer);
+    return this;
+  }
+
+  /**
+   * @return Get the reducer instance being used by this test
+   */
+  public Reducer<K1, V1, K2, V2> getReducer() {
+    return reducer;
+  }
+
+  private Counters counters;
+
+  /**
+   * @return The counters used in this test
+   */
+  public Counters getCounters() {
+    return counters;
+  }
+
+  /**
+   * Sets the counters object to use for this test
+   * 
+   * @param counters
+   *          The counters object to use
+   */
+  public void setCounters(Counters counters) {
+    this.counters = counters;
+    counterWrapper = new CounterWrapper(counters);
+  }
+
+  /**
+   * Identical to setCounters but supports a fluent programming style
+   * 
+   * @param counters
+   *          The counters object to use
+   * @return this
+   */
+  public MultipleInputsMapReduceDriver<K1, V1, K2, V2> withCounter(
+      Counters counters) {
+    setCounters(counters);
+    return this;
+  }
+
+  private Class<? extends OutputFormat> outputFormatClass;
+
+  /**
+   * Configure {@link Reducer} to output with a real {@link OutputFormat}.
+   * 
+   * @param outputFormatClass
+   *          The OutputFormat class
+   * @return this
+   */
+  public MultipleInputsMapReduceDriver<K1, V1, K2, V2> withOutputFormat(
+      final Class<? extends OutputFormat> outputFormatClass) {
+    this.outputFormatClass = returnNonNull(outputFormatClass);
+    return this;
+  }
+
+  private Class<? extends InputFormat> inputFormatClass;
+
+  /**
+   * Set the InputFormat
+   * 
+   * @param inputFormatClass
+   *          The InputFormat class
+   * @return this
+   */
+  public MultipleInputsMapReduceDriver<K1, V1, K2, V2> withInputFormat(
+      final Class<? extends InputFormat> inputFormatClass) {
+    this.inputFormatClass = returnNonNull(inputFormatClass);
+    return this;
+  }
+
+  /**
+   * Construct a driver with the specified Reducer. Note that a Combiner can be
+   * set separately.
+   * 
+   * @param reducer
+   *          The reducer to use
+   */
+  public MultipleInputsMapReduceDriver(Reducer<K1, V1, K2, V2> reducer) {
+    this();
+    this.reducer = reducer;
+  }
+
+  /**
+   * Construct a driver with the specified Combiner and Reducers
+   * 
+   * @param combiner
+   *          The combiner to use
+   * @param reducer
+   *          The reducer to use
+   */
+  public MultipleInputsMapReduceDriver(Reducer<K1, V1, K1, V1> combiner,
+                                       Reducer<K1, V1, K2, V2> reducer) {
+    this(reducer);
+    this.combiner = combiner;
+  }
+
+  /**
+   * Construct a driver without specifying a Combiner nor a Reducer. Note that
+   * these can be set with the appropriate set methods and that at least the
+   * Reducer must be set.
+   */
+  public MultipleInputsMapReduceDriver() {
+    setCounters(new Counters());
+  }
+
+  /**
+   * Static factory-style method to construct a driver instance with the
+   * specified Combiner and Reducer
+   * 
+   * @param combiner
+   *          The combiner to use
+   * @param reducer
+   *          The reducer to use
+   * @param <K1>
+   *          The common output key type of the mappers
+   * @param <V1>
+   *          The common output value type of the mappers
+   * @param <K2>
+   *          The output key type of the reducer
+   * @param <V2>
+   *          The output value type of the reducer
+   * @return a new driver instance to support fluent programming style
+   */
+  public static <K1, V1, K2, V2> MultipleInputsMapReduceDriver<K1, V1, K2, V2> newMultipleInputMapReduceDriver(
+      final Reducer<K1, V1, K1, V1> combiner,
+      final Reducer<K1, V1, K2, V2> reducer) {
+    return new MultipleInputsMapReduceDriver<K1, V1, K2, V2>(combiner, reducer);
+  }
+
+  /**
+   * Static factory-style method to construct a driver instance with the
+   * specified Reducer
+   * 
+   * @param reducer
+   *          The reducer to use
+   * @param <K1>
+   *          The common output key type of the mappers
+   * @param <V1>
+   *          The common output value type of the mappers
+   * @param <K2>
+   *          The output key type of the reducer
+   * @param <V2>
+   *          The output value type of the reducer
+   * @return a new driver instance to support fluent programming style
+   */
+  public static <K1, V1, K2, V2> MultipleInputsMapReduceDriver<K1, V1, K2, V2> newMultipleInputMapReduceDriver(
+      final Reducer<K1, V1, K2, V2> reducer) {
+    return new MultipleInputsMapReduceDriver<K1, V1, K2, V2>(reducer);
+  }
+
+  /**
+   * Static factory-style method to construct a driver instance without
+   * specifying a Combiner or a Reducer. Note that these can be set separately
+   * by using the appropriate set (or with) methods and that at least a Reducer
+   * must be set
+   * 
+   * @param <K1>
+   *          The common output key type of the mappers
+   * @param <V1>
+   *          The common output value type of the mappers
+   * @param <K2>
+   *          The output key type of the reducer
+   * @param <V2>
+   *          The output value type of the reducer
+   * @return a new driver instance to support fluent programming style
+   */
+  public static <K1, V1, K2, V2> MultipleInputsMapReduceDriver<K1, V1, K2, V2> newMultipleInputMapReduceDriver() {
+    return new MultipleInputsMapReduceDriver<K1, V1, K2, V2>();
+  }
+
+  /**
+   * Add the specified (key, val) pair to the specified mapper
+   * 
+   * @param mapper
+   *          The mapper to add the input pair to
+   * @param key
+   *          The key
+   * @param val
+   *          The value
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   */
+  public <K, V> void addInput(final Mapper<K, V, K1, V1> mapper, final K key,
+      final V val) {
+    super.addInput(mapper, key, val);
+  }
+
+  /**
+   * Add the specified input pair to the specified mapper
+   * 
+   * @param mapper
+   *          The mapper to add the input pair to
+   * @param input
+   *          The (k,v) pair to add
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   */
+  public <K, V> void addInput(final Mapper<K, V, K1, V1> mapper,
+      final Pair<K, V> input) {
+    super.addInput(mapper, input);
+  }
+
+  /**
+   * Add the specified input pairs to the specified mapper
+   * 
+   * @param mapper
+   *          The mapper to add the input pairs to
+   * @param inputs
+   *          The (k, v) pairs to add
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   */
+  public <K, V> void addAll(final Mapper<K, V, K1, V1> mapper,
+      final List<Pair<K, V>> inputs) {
+    super.addAll(mapper, inputs);
+  }
+
+  /**
+   * Identical to addInput but supports fluent programming style
+   * 
+   * @param mapper
+   *          The mapper to add the input pair to
+   * @param key
+   *          The key
+   * @param val
+   *          The value
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   * @return this
+   */
+  public <K, V> MultipleInputsMapReduceDriver<K1, V1, K2, V2> withInput(
+      final Mapper<K, V, K1, V1> mapper, final K key, final V val) {
+    return super.withInput(mapper, key, val);
+  }
+
+  /**
+   * Identical to addInput but supports fluent programming style
+   * 
+   * @param mapper
+   *          The mapper to add the input pair to
+   * @param input
+   *          The (k, v) pair to add
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   * @return this
+   */
+  public <K, V> MultipleInputsMapReduceDriver<K1, V1, K2, V2> withInput(
+      final Mapper<K, V, K1, V1> mapper, final Pair<K, V> input) {
+    return super.withInput(mapper, input);
+  }
+
+  /**
+   * Identical to addAll but supports fluent programming style
+   * 
+   * @param mapper
+   *          The mapper to add the input pairs to
+   * @param inputs
+   *          The (k, v) pairs to add
+   * @param <K>
+   *          The type of the key
+   * @param <V>
+   *          The type of the value
+   * @return this
+   */
+  public <K, V> MultipleInputsMapReduceDriver<K1, V1, K2, V2> withAll(
+      final Mapper<K, V, K1, V1> mapper, final List<Pair<K, V>> inputs) {
+    return super.withAll(mapper, inputs);
+  }
+
+  @Override
+  protected void preRunChecks(Set<Mapper> mappers, Object reducer) {
+    if (mappers.isEmpty()) {
+      throw new IllegalStateException("No mappers were provided");
+    }
+    super.preRunChecks(mappers, reducer);
+  }
+
+  protected List<KeyValueReuseList<K1, V1>> sortAndGroup(
+      final List<Pair<K1, V1>> mapOutputs) {
+    if (mapOutputs.isEmpty()) {
+      return Collections.emptyList();
+    }
+
+    if (keyValueOrderComparator == null || keyGroupComparator == null) {
+      JobConf conf = new JobConf(getConfiguration());
+      conf.setMapOutputKeyClass(mapOutputs.get(0).getFirst().getClass());
+      if (keyGroupComparator == null) {
+        keyGroupComparator = conf.getOutputValueGroupingComparator();
+      }
+      if (keyValueOrderComparator == null) {
+        keyValueOrderComparator = conf.getOutputKeyComparator();
+      }
+    }
+    ReduceFeeder<K1, V1> reduceFeeder = new ReduceFeeder<K1, V1>(
+        getConfiguration());
+    return reduceFeeder.sortAndGroup(mapOutputs, keyValueOrderComparator,
+        keyGroupComparator);
+  }
+
+  @SuppressWarnings("unchecked")
+  @Override
+  public List<Pair<K2, V2>> run() throws IOException {
+    try {
+      preRunChecks(mappers, reducer);
+      initDistributedCache();
+
+      List<Pair<K1, V1>> outputs = new ArrayList<Pair<K1, V1>>();
+
+      for (Mapper mapper : mappers) {
+        MapDriver mapDriver = MapDriver.newMapDriver(mapper);
+        mapDriver.setCounters(counters);
+        mapDriver.setConfiguration(getConfiguration());
+        mapDriver.addAll(inputs.get(mapper));
+        mapDriver.withMapInputPath(getMapInputPath(mapper));
+        outputs.addAll(mapDriver.run());
+      }
+
+      if (combiner != null) {
+        LOG.debug("Starting combine phase with combiner: " + combiner);
+        outputs = new ReducePhaseRunner<K1, V1, K1, V1>(inputFormatClass,
+            getConfiguration(), counters,
+            getOutputSerializationConfiguration(), outputFormatClass)
+            .runReduce(sortAndGroup(outputs), combiner);
+      }
+
+      LOG.debug("Starting reduce phase with reducer: " + reducer);
+
+      return new ReducePhaseRunner<K1, V1, K2, V2>(inputFormatClass,
+          getConfiguration(), counters, getOutputSerializationConfiguration(),
+          outputFormatClass).runReduce(sortAndGroup(outputs), reducer);
+    } finally {
+      cleanupDistributedCache();
+    }
+  }
+}
diff --git a/src/main/java/org/apache/hadoop/mrunit/mapreduce/ReducePhaseRunner.java b/src/main/java/org/apache/hadoop/mrunit/mapreduce/ReducePhaseRunner.java
new file mode 100644
index 0000000..73e7152
--- /dev/null
+++ b/src/main/java/org/apache/hadoop/mrunit/mapreduce/ReducePhaseRunner.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mrunit.mapreduce;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.Counters;
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.OutputFormat;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mrunit.TestDriver;
+import org.apache.hadoop.mrunit.types.KeyValueReuseList;
+import org.apache.hadoop.mrunit.types.Pair;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+class ReducePhaseRunner<INKEY, INVAL, OUTKEY, OUTVAL> {
+  public static final Log LOG = LogFactory.getLog(ReducePhaseRunner.class);
+
+  private final Configuration configuration;
+  private final Counters counters;
+  private Configuration outputSerializationConfiguration;
+  private Class<? extends OutputFormat> outputFormatClass;
+  private Class<? extends InputFormat> inputFormatClass;
+
+  ReducePhaseRunner(Class<? extends InputFormat> inputFormatClass,
+      Configuration configuration, Counters counters,
+      Configuration outputSerializationConfiguration,
+      Class<? extends OutputFormat> outputFormatClass) {
+    this.inputFormatClass = inputFormatClass;
+    this.configuration = configuration;
+    this.counters = counters;
+    this.outputSerializationConfiguration = outputSerializationConfiguration;
+    this.outputFormatClass = outputFormatClass;
+  }
+
+  public List<Pair<OUTKEY, OUTVAL>> runReduce(
+      final List<KeyValueReuseList<INKEY, INVAL>> inputs,
+      final Reducer<INKEY, INVAL, OUTKEY, OUTVAL> reducer) throws IOException {
+
+    final List<Pair<OUTKEY, OUTVAL>> reduceOutputs = new ArrayList<Pair<OUTKEY, OUTVAL>>();
+
+    if (!inputs.isEmpty()) {
+      if (LOG.isDebugEnabled()) {
+        final StringBuilder sb = new StringBuilder();
+        for (List<Pair<INKEY, INVAL>> input : inputs) {
+          TestDriver.formatValueList(input, sb);
+          LOG.debug("Reducing input " + sb);
+          sb.delete(0, sb.length());
+        }
+      }
+
+      final ReduceDriver<INKEY, INVAL, OUTKEY, OUTVAL> reduceDriver = ReduceDriver
+          .newReduceDriver(reducer).withCounters(counters)
+          .withConfiguration(configuration).withAllElements(inputs);
+
+      if (outputSerializationConfiguration != null) {
+        reduceDriver
+            .withOutputSerializationConfiguration(outputSerializationConfiguration);
+      }
+
+      if (outputFormatClass != null) {
+        reduceDriver.withOutputFormat(outputFormatClass, inputFormatClass);
+      }
+
+      reduceOutputs.addAll(reduceDriver.run());
+    }
+
+    return reduceOutputs;
+  }
+}
\ No newline at end of file
diff --git a/src/test/java/org/apache/hadoop/mrunit/TestMapOutputShuffler.java b/src/test/java/org/apache/hadoop/mrunit/TestMapOutputShuffler.java
new file mode 100644
index 0000000..1157c74
--- /dev/null
+++ b/src/test/java/org/apache/hadoop/mrunit/TestMapOutputShuffler.java
@@ -0,0 +1,130 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mrunit;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mrunit.types.Pair;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.hadoop.mrunit.ExtendedAssert.assertListEquals;
+import static org.junit.Assert.assertEquals;
+
+public class TestMapOutputShuffler {
+  private MapOutputShuffler<Text, Text> shuffler;
+
+  @Before
+  public void setUp() {
+    shuffler = new MapOutputShuffler<Text, Text>(null, null, null);
+  }
+
+  @Test
+  public void testEmptyShuffle() {
+    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
+    final List<Pair<Text, List<Text>>> outputs = shuffler.shuffle(inputs);
+    assertEquals(0, outputs.size());
+  }
+
+  // just shuffle a single (k, v) pair
+  @Test
+  public void testSingleShuffle() {
+    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
+    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("b")));
+
+    final List<Pair<Text, List<Text>>> outputs = shuffler.shuffle(inputs);
+
+    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
+    final List<Text> sublist = new ArrayList<Text>();
+    sublist.add(new Text("b"));
+    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist));
+
+    assertListEquals(expected, outputs);
+  }
+
+  // shuffle multiple values from the same key.
+  @Test
+  public void testShuffleOneKey() {
+    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
+    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("b")));
+    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("c")));
+
+    final List<Pair<Text, List<Text>>> outputs = shuffler.shuffle(inputs);
+
+    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
+    final List<Text> sublist = new ArrayList<Text>();
+    sublist.add(new Text("b"));
+    sublist.add(new Text("c"));
+    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist));
+
+    assertListEquals(expected, outputs);
+  }
+
+  // shuffle multiple keys
+  @Test
+  public void testMultiShuffle1() {
+    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
+    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("x")));
+    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("z")));
+    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("w")));
+    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("y")));
+
+    final List<Pair<Text, List<Text>>> outputs = shuffler.shuffle(inputs);
+
+    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
+    final List<Text> sublist1 = new ArrayList<Text>();
+    sublist1.add(new Text("x"));
+    sublist1.add(new Text("y"));
+    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist1));
+
+    final List<Text> sublist2 = new ArrayList<Text>();
+    sublist2.add(new Text("z"));
+    sublist2.add(new Text("w"));
+    expected.add(new Pair<Text, List<Text>>(new Text("b"), sublist2));
+
+    assertListEquals(expected, outputs);
+  }
+
+  // shuffle multiple keys that are out-of-order to start.
+  @Test
+  public void testMultiShuffle2() {
+    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
+    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("z")));
+    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("x")));
+    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("w")));
+    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("y")));
+
+    final List<Pair<Text, List<Text>>> outputs = shuffler.shuffle(inputs);
+
+    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
+    final List<Text> sublist1 = new ArrayList<Text>();
+    sublist1.add(new Text("x"));
+    sublist1.add(new Text("y"));
+    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist1));
+
+    final List<Text> sublist2 = new ArrayList<Text>();
+    sublist2.add(new Text("z"));
+    sublist2.add(new Text("w"));
+    expected.add(new Pair<Text, List<Text>>(new Text("b"), sublist2));
+
+    assertListEquals(expected, outputs);
+  }
+}
diff --git a/src/test/java/org/apache/hadoop/mrunit/TestMapReduceDriver.java b/src/test/java/org/apache/hadoop/mrunit/TestMapReduceDriver.java
index 50da602..3a60f73 100644
--- a/src/test/java/org/apache/hadoop/mrunit/TestMapReduceDriver.java
+++ b/src/test/java/org/apache/hadoop/mrunit/TestMapReduceDriver.java
@@ -17,14 +17,6 @@
  */
 package org.apache.hadoop.mrunit;
 
-import static org.apache.hadoop.mrunit.ExtendedAssert.*;
-import static org.junit.Assert.*;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
@@ -32,17 +24,7 @@
 import org.apache.hadoop.io.RawComparator;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.serializer.JavaSerializationComparator;
-import org.apache.hadoop.mapred.FileSplit;
-import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.MapReduceBase;
-import org.apache.hadoop.mapred.Mapper;
-import org.apache.hadoop.mapred.OutputCollector;
-import org.apache.hadoop.mapred.Reducer;
-import org.apache.hadoop.mapred.Reporter;
-import org.apache.hadoop.mapred.SequenceFileInputFormat;
-import org.apache.hadoop.mapred.SequenceFileOutputFormat;
-import org.apache.hadoop.mapred.TextInputFormat;
-import org.apache.hadoop.mapred.TextOutputFormat;
+import org.apache.hadoop.mapred.*;
 import org.apache.hadoop.mapred.lib.IdentityMapper;
 import org.apache.hadoop.mapred.lib.IdentityReducer;
 import org.apache.hadoop.mapred.lib.LongSumReducer;
@@ -53,7 +35,14 @@
 import org.junit.Rule;
 import org.junit.Test;
 
-import com.google.common.collect.Lists;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import static org.apache.hadoop.mrunit.ExtendedAssert.assertListEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 
 public class TestMapReduceDriver {
 
@@ -100,10 +89,10 @@
   public void testUncomparable() throws IOException {
     Text k = new Text("test");
     Object v = new UncomparableWritable(2);
-    MapReduceDriver.newMapReduceDriver(
-        new IdentityMapper<Text, Object>(),
-        new IdentityReducer<Text, Object>())
-        .withInput(k, v).withOutput(k, v).runTest();
+    MapReduceDriver
+        .newMapReduceDriver(new IdentityMapper<Text, Object>(),
+            new IdentityReducer<Text, Object>()).withInput(k, v)
+        .withOutput(k, v).runTest();
   }
 
   @Test
@@ -127,8 +116,10 @@
   @Test
   public void testTestRun3() throws IOException {
     thrown.expectAssertionErrorMessage("2 Error(s)");
-    thrown.expectAssertionErrorMessage("Missing expected output (foo, 52) at position 0, got (bar, 12).");
-    thrown.expectAssertionErrorMessage("Missing expected output (bar, 12) at position 1, got (foo, 52).");
+    thrown
+        .expectAssertionErrorMessage("Missing expected output (foo, 52) at position 0, got (bar, 12).");
+    thrown
+        .expectAssertionErrorMessage("Missing expected output (bar, 12) at position 1, got (foo, 52).");
     driver.withInput(new Text("foo"), new LongWritable(FOO_IN_A))
         .withInput(new Text("bar"), new LongWritable(BAR_IN))
         .withInput(new Text("foo"), new LongWritable(FOO_IN_B))
@@ -139,13 +130,18 @@
   @Test
   public void testAddAll() throws IOException {
     final List<Pair<Text, LongWritable>> inputs = new ArrayList<Pair<Text, LongWritable>>();
-    inputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(FOO_IN_A)));
-    inputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(FOO_IN_B)));
-    inputs.add(new Pair<Text, LongWritable>(new Text("bar"), new LongWritable(BAR_IN)));
+    inputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(
+        FOO_IN_A)));
+    inputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(
+        FOO_IN_B)));
+    inputs.add(new Pair<Text, LongWritable>(new Text("bar"), new LongWritable(
+        BAR_IN)));
 
     final List<Pair<Text, LongWritable>> outputs = new ArrayList<Pair<Text, LongWritable>>();
-    outputs.add(new Pair<Text, LongWritable>(new Text("bar"), new LongWritable(BAR_IN)));
-    outputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(FOO_OUT)));
+    outputs.add(new Pair<Text, LongWritable>(new Text("bar"), new LongWritable(
+        BAR_IN)));
+    outputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(
+        FOO_OUT)));
 
     driver.withAll(inputs).withAllOutput(outputs).runTest();
   }
@@ -177,97 +173,6 @@
     driver.runTest();
   }
 
-  @Test
-  public void testEmptyShuffle() {
-    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    final List<Pair<Text, List<Text>>> outputs = driver2.shuffle(inputs);
-    assertEquals(0, outputs.size());
-  }
-
-  // just shuffle a single (k, v) pair
-  @Test
-  public void testSingleShuffle() {
-    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("b")));
-
-    final List<Pair<Text, List<Text>>> outputs = driver2.shuffle(inputs);
-
-    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
-    final List<Text> sublist = new ArrayList<Text>();
-    sublist.add(new Text("b"));
-    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist));
-
-    assertListEquals(expected, outputs);
-  }
-
-  // shuffle multiple values from the same key.
-  @Test
-  public void testShuffleOneKey() {
-    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("b")));
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("c")));
-
-    final List<Pair<Text, List<Text>>> outputs = driver2.shuffle(inputs);
-
-    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
-    final List<Text> sublist = new ArrayList<Text>();
-    sublist.add(new Text("b"));
-    sublist.add(new Text("c"));
-    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist));
-
-    assertListEquals(expected, outputs);
-  }
-
-  // shuffle multiple keys
-  @Test
-  public void testMultiShuffle1() {
-    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("x")));
-    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("z")));
-    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("w")));
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("y")));
-
-    final List<Pair<Text, List<Text>>> outputs = driver2.shuffle(inputs);
-
-    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
-    final List<Text> sublist1 = new ArrayList<Text>();
-    sublist1.add(new Text("x"));
-    sublist1.add(new Text("y"));
-    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist1));
-
-    final List<Text> sublist2 = new ArrayList<Text>();
-    sublist2.add(new Text("z"));
-    sublist2.add(new Text("w"));
-    expected.add(new Pair<Text, List<Text>>(new Text("b"), sublist2));
-
-    assertListEquals(expected, outputs);
-  }
-
-  // shuffle multiple keys that are out-of-order to start.
-  @Test
-  public void testMultiShuffle2() {
-    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("z")));
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("x")));
-    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("w")));
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("y")));
-
-    final List<Pair<Text, List<Text>>> outputs = driver2.shuffle(inputs);
-
-    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
-    final List<Text> sublist1 = new ArrayList<Text>();
-    sublist1.add(new Text("x"));
-    sublist1.add(new Text("y"));
-    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist1));
-
-    final List<Text> sublist2 = new ArrayList<Text>();
-    sublist2.add(new Text("z"));
-    sublist2.add(new Text("w"));
-    expected.add(new Pair<Text, List<Text>>(new Text("b"), sublist2));
-
-    assertListEquals(expected, outputs);
-  }
-
   // Test "combining" with an IdentityReducer. Result should be the same.
   @Test
   public void testIdentityCombiner() throws IOException {
@@ -306,12 +211,13 @@
   @Test
   public void testRepeatRun() throws IOException {
     driver.withCombiner(new IdentityReducer<Text, LongWritable>())
-            .withInput(new Text("foo"), new LongWritable(FOO_IN_A))
-            .withInput(new Text("foo"), new LongWritable(FOO_IN_B))
-            .withInput(new Text("bar"), new LongWritable(BAR_IN))
-            .withOutput(new Text("bar"), new LongWritable(BAR_IN))
-            .withOutput(new Text("foo"), new LongWritable(FOO_OUT)).runTest();
-    thrown.expectMessage(IllegalStateException.class, "Driver reuse not allowed");
+        .withInput(new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(new Text("bar"), new LongWritable(BAR_IN))
+        .withOutput(new Text("bar"), new LongWritable(BAR_IN))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT)).runTest();
+    thrown.expectMessage(IllegalStateException.class,
+        "Driver reuse not allowed");
     driver.runTest();
   }
 
@@ -592,8 +498,7 @@
 
   @Test
   public void testMapInputFile() throws IOException {
-    InputPathStoringMapper<LongWritable,LongWritable> mapper =
-        new InputPathStoringMapper<LongWritable,LongWritable>();
+    InputPathStoringMapper<LongWritable, LongWritable> mapper = new InputPathStoringMapper<LongWritable, LongWritable>();
     Path mapInputPath = new Path("myfile");
     driver = MapReduceDriver.newMapReduceDriver(mapper, reducer);
     driver.setMapInputPath(mapInputPath);
@@ -606,16 +511,15 @@
 
   @Test
   public void testGroupingComparatorBehaviour1() throws IOException {
-    driver.withInput(new Text("A1"),new LongWritable(1L))
-      .withInput(new Text("A2"),new LongWritable(1L))
-      .withInput(new Text("B1"),new LongWritable(1L))
-      .withInput(new Text("B2"),new LongWritable(1L))
-      .withInput(new Text("C1"),new LongWritable(1L))
-      .withOutput(new Text("A1"),new LongWritable(2L))
-      .withOutput(new Text("B1"),new LongWritable(2L))
-      .withOutput(new Text("C1"),new LongWritable(1L))
-      .withKeyGroupingComparator(new FirstCharComparator())
-      .runTest(false);
+    driver.withInput(new Text("A1"), new LongWritable(1L))
+        .withInput(new Text("A2"), new LongWritable(1L))
+        .withInput(new Text("B1"), new LongWritable(1L))
+        .withInput(new Text("B2"), new LongWritable(1L))
+        .withInput(new Text("C1"), new LongWritable(1L))
+        .withOutput(new Text("A1"), new LongWritable(2L))
+        .withOutput(new Text("B1"), new LongWritable(2L))
+        .withOutput(new Text("C1"), new LongWritable(1L))
+        .withKeyGroupingComparator(new FirstCharComparator()).runTest(false);
   }
 
   @Test
@@ -624,34 +528,32 @@
     // grouping of reduce keys in "shuffle".
     // MapReduce doesn't group keys which aren't in a contiguous
     // range when sorted by their sorting comparator.
-    driver.withInput(new Text("1A"),new LongWritable(1L))
-      .withInput(new Text("2A"),new LongWritable(1L))
-      .withInput(new Text("1B"),new LongWritable(1L))
-      .withInput(new Text("2B"),new LongWritable(1L))
-      .withInput(new Text("1C"),new LongWritable(1L))
-      .withOutput(new Text("1A"),new LongWritable(1L))
-      .withOutput(new Text("2A"),new LongWritable(1L))
-      .withOutput(new Text("1B"),new LongWritable(1L))
-      .withOutput(new Text("2B"),new LongWritable(1L))
-      .withOutput(new Text("1C"),new LongWritable(1L))
-      .withKeyGroupingComparator(new SecondCharComparator())
-      .runTest(false);
+    driver.withInput(new Text("1A"), new LongWritable(1L))
+        .withInput(new Text("2A"), new LongWritable(1L))
+        .withInput(new Text("1B"), new LongWritable(1L))
+        .withInput(new Text("2B"), new LongWritable(1L))
+        .withInput(new Text("1C"), new LongWritable(1L))
+        .withOutput(new Text("1A"), new LongWritable(1L))
+        .withOutput(new Text("2A"), new LongWritable(1L))
+        .withOutput(new Text("1B"), new LongWritable(1L))
+        .withOutput(new Text("2B"), new LongWritable(1L))
+        .withOutput(new Text("1C"), new LongWritable(1L))
+        .withKeyGroupingComparator(new SecondCharComparator()).runTest(false);
   }
 
   @Test
   public void testGroupingComparatorSpecifiedByConf() throws IOException {
     JobConf conf = new JobConf(new Configuration());
     conf.setOutputValueGroupingComparator(FirstCharComparator.class);
-    driver.withInput(new Text("A1"),new LongWritable(1L))
-      .withInput(new Text("A2"),new LongWritable(1L))
-      .withInput(new Text("B1"),new LongWritable(1L))
-      .withInput(new Text("B2"),new LongWritable(1L))
-      .withInput(new Text("C1"),new LongWritable(1L))
-      .withOutput(new Text("A1"),new LongWritable(2L))
-      .withOutput(new Text("B1"),new LongWritable(2L))
-      .withOutput(new Text("C1"),new LongWritable(1L))
-      .withConfiguration(conf)
-      .runTest(false);
+    driver.withInput(new Text("A1"), new LongWritable(1L))
+        .withInput(new Text("A2"), new LongWritable(1L))
+        .withInput(new Text("B1"), new LongWritable(1L))
+        .withInput(new Text("B2"), new LongWritable(1L))
+        .withInput(new Text("C1"), new LongWritable(1L))
+        .withOutput(new Text("A1"), new LongWritable(2L))
+        .withOutput(new Text("B1"), new LongWritable(2L))
+        .withOutput(new Text("C1"), new LongWritable(1L))
+        .withConfiguration(conf).runTest(false);
   }
 
   @Test
@@ -660,13 +562,14 @@
         .newMapReduceDriver(new IdentityMapper<TestWritable, Text>(),
             new IdentityReducer<TestWritable, Text>());
     driver.withInput(new TestWritable("A1"), new Text("A1"))
-      .withInput(new TestWritable("A2"), new Text("A2"))
-      .withInput(new TestWritable("A3"), new Text("A3"))
-      .withKeyGroupingComparator(new TestWritable.SingleGroupComparator())
-      .withOutput(new TestWritable("A3"), new Text("A3"))
-      .withOutput(new TestWritable("A3"), new Text("A2"))
-      .withOutput(new TestWritable("A3"), new Text("A1"))
-      .runTest(true); //ordering is important
+        .withInput(new TestWritable("A2"), new Text("A2"))
+        .withInput(new TestWritable("A3"), new Text("A3"))
+        .withKeyGroupingComparator(new TestWritable.SingleGroupComparator())
+        .withOutput(new TestWritable("A3"), new Text("A3"))
+        .withOutput(new TestWritable("A3"), new Text("A2"))
+        .withOutput(new TestWritable("A3"), new Text("A1")).runTest(true); // ordering
+                                                                           // is
+                                                                           // important
   }
 
 }
diff --git a/src/test/java/org/apache/hadoop/mrunit/TestMultipleInputsMapReduceDriver.java b/src/test/java/org/apache/hadoop/mrunit/TestMultipleInputsMapReduceDriver.java
new file mode 100644
index 0000000..0a36ce8
--- /dev/null
+++ b/src/test/java/org/apache/hadoop/mrunit/TestMultipleInputsMapReduceDriver.java
@@ -0,0 +1,675 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mrunit;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.serializer.JavaSerializationComparator;
+import org.apache.hadoop.mapred.*;
+import org.apache.hadoop.mapred.lib.IdentityMapper;
+import org.apache.hadoop.mapred.lib.IdentityReducer;
+import org.apache.hadoop.mapred.lib.LongSumReducer;
+import org.apache.hadoop.mrunit.types.Pair;
+import org.apache.hadoop.mrunit.types.TestWritable;
+import org.apache.hadoop.mrunit.types.UncomparableWritable;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import static org.apache.hadoop.mrunit.ExtendedAssert.assertListEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+public class TestMultipleInputsMapReduceDriver {
+  @Rule
+  public final ExpectedSuppliedException thrown = ExpectedSuppliedException
+      .none();
+
+  private static final int FOO_IN_A = 42;
+  private static final int FOO_IN_B = 10;
+  private static final int TOKEN_IN_A = 1;
+  private static final int TOKEN_IN_B = 2;
+  private static final int BAR_IN = 12;
+  private static final int BAR_OUT = BAR_IN + TOKEN_IN_A + TOKEN_IN_B;
+  private static final int FOO_OUT = FOO_IN_A + FOO_IN_B + TOKEN_IN_A + 2
+      * TOKEN_IN_B;
+  private static final String TOKEN_A = "foo bar";
+  private static final String TOKEN_B = "foo foo bar";
+
+  private Mapper<Text, LongWritable, Text, LongWritable> mapper;
+  private Reducer<Text, LongWritable, Text, LongWritable> reducer;
+  private TokenMapper tokenMapper;
+  private MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable> driver;
+
+  @Before
+  public void setUp() {
+    mapper = new IdentityMapper<Text, LongWritable>();
+    reducer = new LongSumReducer<Text>();
+    tokenMapper = new TokenMapper();
+    driver = new MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable>(
+        reducer);
+    driver.addMapper(mapper);
+    driver.addMapper(tokenMapper);
+  }
+
+  @Test
+  public void testRun() throws IOException {
+    final List<Pair<Text, LongWritable>> out = driver
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .run();
+
+    final List<Pair<Text, LongWritable>> expected = new ArrayList<Pair<Text, LongWritable>>();
+    expected.add(new Pair<Text, LongWritable>(new Text("bar"),
+        new LongWritable(BAR_OUT)));
+    expected.add(new Pair<Text, LongWritable>(new Text("foo"),
+        new LongWritable(FOO_OUT)));
+
+    assertListEquals(expected, out);
+  }
+
+  @Test
+  public void testUncomparable() throws IOException {
+    MultipleInputsMapReduceDriver<Text, Object, Text, Object> testDriver = MultipleInputsMapReduceDriver
+        .newMultipleInputMapReduceDriver(new IdentityReducer<Text, Object>());
+
+    Mapper<Text, Object, Text, Object> identity = new IdentityMapper<Text, Object>();
+    testDriver.addMapper(identity);
+    Text k1 = new Text("foo");
+    Object v1 = new UncomparableWritable(1);
+    testDriver.withInput(identity, k1, v1);
+
+    ReverseIdentityMapper<Object, Text> reverse = new ReverseIdentityMapper<Object, Text>();
+    testDriver.addMapper(reverse);
+    Text k2 = new Text("bar");
+    Object v2 = new UncomparableWritable(2);
+    testDriver.withInput(reverse, v2, k2);
+
+    testDriver.withOutput(k1, v1).withOutput(k2, v2);
+
+    testDriver.runTest(false);
+  }
+
+  @Test
+  public void testTestRun() throws IOException {
+    driver
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .withOutput(new Text("bar"), new LongWritable(BAR_OUT))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT)).runTest(false);
+  }
+
+  @Test
+  public void testAddAll() throws IOException {
+    final List<Pair<Text, LongWritable>> mapperInputs = new ArrayList<Pair<Text, LongWritable>>();
+    mapperInputs.add(new Pair<Text, LongWritable>(new Text("foo"),
+        new LongWritable(FOO_IN_A)));
+    mapperInputs.add(new Pair<Text, LongWritable>(new Text("foo"),
+        new LongWritable(FOO_IN_B)));
+    mapperInputs.add(new Pair<Text, LongWritable>(new Text("bar"),
+        new LongWritable(BAR_IN)));
+
+    final List<Pair<LongWritable, Text>> tokenMapperInputs = new ArrayList<Pair<LongWritable, Text>>();
+    tokenMapperInputs.add(new Pair<LongWritable, Text>(new LongWritable(
+        TOKEN_IN_A), new Text(TOKEN_A)));
+    tokenMapperInputs.add(new Pair<LongWritable, Text>(new LongWritable(
+        TOKEN_IN_B), new Text(TOKEN_B)));
+
+    final List<Pair<Text, LongWritable>> outputs = new ArrayList<Pair<Text, LongWritable>>();
+    outputs.add(new Pair<Text, LongWritable>(new Text("bar"), new LongWritable(
+        BAR_OUT)));
+    outputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(
+        FOO_OUT)));
+
+    driver.withAll(mapper, mapperInputs)
+        .withAll(tokenMapper, tokenMapperInputs).withAllOutput(outputs)
+        .runTest(false);
+  }
+
+  @Test
+  public void testNoInput() throws IOException {
+    thrown.expectMessage(IllegalStateException.class,
+        "No input was provided for mapper");
+    driver.runTest(false);
+  }
+
+  @Test
+  public void testNoInputForMapper() throws IOException {
+    MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable> testDriver = new MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable>();
+    testDriver.addMapper(mapper);
+    testDriver.addMapper(tokenMapper);
+    testDriver.withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A));
+    thrown.expectMessage(IllegalStateException.class,
+        String.format("No input was provided for mapper %s", tokenMapper));
+    testDriver.runTest(false);
+  }
+
+  @Test
+  public void testNoReducer() throws IOException {
+    MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable> testDriver = new MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable>();
+    testDriver.addMapper(mapper);
+    testDriver.withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A));
+    thrown.expectMessage(IllegalStateException.class,
+        "No reducer class was provided");
+    testDriver.runTest(false);
+  }
+
+  @Test
+  public void testIdentityCombiner() throws IOException {
+    driver
+        .withCombiner(new IdentityReducer<Text, LongWritable>())
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT))
+        .withOutput(new Text("bar"), new LongWritable(BAR_OUT)).runTest(false);
+  }
+
+  @Test
+  public void testLongSumCombiner() throws IOException {
+    driver
+        .withCombiner(new LongSumReducer<Text>())
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT))
+        .withOutput(new Text("bar"), new LongWritable(BAR_OUT)).runTest(false);
+  }
+
+  @Test
+  public void testLongSumCombinerAndIdentityReducer() throws IOException {
+    driver
+        .withCombiner(new LongSumReducer<Text>())
+        .withReducer(new IdentityReducer<Text, LongWritable>())
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT))
+        .withOutput(new Text("bar"), new LongWritable(BAR_OUT)).runTest(false);
+  }
+
+  @Test
+  public void testRepeatRun() throws IOException {
+    driver
+        .withCombiner(new IdentityReducer<Text, LongWritable>())
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT))
+        .withOutput(new Text("bar"), new LongWritable(BAR_OUT)).runTest(false);
+    thrown.expectMessage(IllegalStateException.class,
+        "Driver reuse not allowed");
+    driver.runTest(false);
+  }
+
+  // Test the key grouping and value ordering comparators
+  @Test
+  public void testComparators() throws IOException {
+    // reducer to track the order of the input values using bit shifting
+    driver.withReducer(new Reducer<Text, LongWritable, Text, LongWritable>() {
+      @Override
+      public void reduce(final Text key, final Iterator<LongWritable> values,
+          final OutputCollector<Text, LongWritable> output,
+          final Reporter reporter) throws IOException {
+        long outputValue = 0;
+        int count = 0;
+        while (values.hasNext()) {
+          outputValue |= (values.next().get() << (count++ * 8));
+        }
+
+        output.collect(key, new LongWritable(outputValue));
+      }
+
+      @Override
+      public void configure(final JobConf job) {
+      }
+
+      @Override
+      public void close() throws IOException {
+      }
+    });
+
+    driver
+        .withKeyGroupingComparator(new TestMapReduceDriver.FirstCharComparator());
+    driver
+        .withKeyOrderComparator(new TestMapReduceDriver.SecondCharComparator());
+
+    driver.addInput(mapper, new Text("a1"), new LongWritable(1));
+    driver.addInput(mapper, new Text("b1"), new LongWritable(1));
+    driver.addInput(mapper, new Text("a3"), new LongWritable(3));
+    driver.addInput(mapper, new Text("a2"), new LongWritable(2));
+
+    driver.addInput(tokenMapper, new LongWritable(1), new Text("c1 d1"));
+
+    driver.addOutput(new Text("a1"), new LongWritable(0x1));
+    driver.addOutput(new Text("b1"), new LongWritable(0x1));
+    driver.addOutput(new Text("a2"), new LongWritable(0x2 | (0x3 << 8)));
+    driver.addOutput(new Text("c1"), new LongWritable(0x1));
+    driver.addOutput(new Text("d1"), new LongWritable(0x1));
+
+    driver.runTest(false);
+  }
+
+  @Test
+  public void testNoMapper() throws IOException {
+    MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable> testDriver = new MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable>();
+    testDriver.withReducer(reducer);
+    thrown.expectMessage(IllegalStateException.class,
+        "No mappers were provided");
+    testDriver.runTest(false);
+  }
+
+  @Test
+  public void testWithCounter() throws IOException {
+    MultipleInputsMapReduceDriver<Text, Text, Text, Text> testDriver = new MultipleInputsMapReduceDriver<Text, Text, Text, Text>();
+    Mapper<Text, Text, Text, Text> mapperWithCounters = new TestMapDriver.MapperWithCounters<Text, Text, Text, Text>();
+    Mapper<Text, Text, Text, Text> tokenMapperWithCounters = new TokenMapperWithCounters();
+    testDriver
+        .withMapper(mapperWithCounters)
+        .withInput(mapperWithCounters, new Text("hie"), new Text("Hi"))
+        .withMapper(tokenMapperWithCounters)
+        .withInput(tokenMapperWithCounters, new Text("bie"),
+            new Text("Goodbye Bye"))
+        .withOutput(new Text("hie"), new Text("Hi"))
+        .withOutput(new Text("bie"), new Text("Goodbye"))
+        .withOutput(new Text("bie"), new Text("Bye"))
+        .withCounter(TestMapDriver.MapperWithCounters.Counters.X, 1)
+        .withCounter(TokenMapperWithCounters.Counters.Y, 2)
+        .withCounter("category", "name", 3)
+        .withReducer(
+            new TestReduceDriver.ReducerWithCounters<Text, Text, Text, Text>())
+        .withCounter(TestReduceDriver.ReducerWithCounters.Counters.COUNT, 2)
+        .withCounter(TestReduceDriver.ReducerWithCounters.Counters.SUM, 3)
+        .withCounter("category", "count", 2).withCounter("category", "sum", 3)
+        .runTest(false);
+  }
+
+  @Test
+  public void testWithCounterAndEnumCounterMissing() throws IOException {
+    MultipleInputsMapReduceDriver<Text, Text, Text, Text> testDriver = new MultipleInputsMapReduceDriver<Text, Text, Text, Text>();
+
+    thrown
+        .expectAssertionErrorMessage("2 Error(s): (Actual counter ("
+            + "\"org.apache.hadoop.mrunit.TestMapDriver$MapperWithCounters$Counters\",\"X\")"
+            + " was not found in expected counters, Actual counter ("
+            + "\"org.apache.hadoop.mrunit.TestMultipleInputsMapReduceDriver$TokenMapperWithCounters$Counters\",\"Y\")"
+            + " was not found in expected counters");
+
+    Mapper<Text, Text, Text, Text> mapperWithCounters = new TestMapDriver.MapperWithCounters<Text, Text, Text, Text>();
+    Mapper<Text, Text, Text, Text> tokenMapperWithCounters = new TokenMapperWithCounters();
+
+    testDriver
+        .withMapper(mapperWithCounters)
+        .withInput(mapperWithCounters, new Text("hie"), new Text("Hi"))
+        .withMapper(tokenMapperWithCounters)
+        .withInput(tokenMapperWithCounters, new Text("bie"),
+            new Text("Goodbye Bye"))
+        .withOutput(new Text("hie"), new Text("Hi"))
+        .withOutput(new Text("bie"), new Text("Goodbye"))
+        .withOutput(new Text("bie"), new Text("Bye"))
+        .withStrictCounterChecking()
+        .withCounter("category", "name", 3)
+        .withReducer(
+            new TestReduceDriver.ReducerWithCounters<Text, Text, Text, Text>())
+        .withCounter(TestReduceDriver.ReducerWithCounters.Counters.COUNT, 2)
+        .withCounter(TestReduceDriver.ReducerWithCounters.Counters.SUM, 3)
+        .withCounter("category", "count", 2).withCounter("category", "sum", 3)
+        .runTest(false);
+  }
+
+  @Test
+  public void testWithCounterAndStringCounterMissing() throws IOException {
+    MultipleInputsMapReduceDriver<Text, Text, Text, Text> testDriver = new MultipleInputsMapReduceDriver<Text, Text, Text, Text>();
+
+    thrown.expectAssertionErrorMessage("1 Error(s): (Actual counter ("
+        + "\"category\",\"name\")" + " was not found in expected counters");
+
+    Mapper<Text, Text, Text, Text> mapperWithCounters = new TestMapDriver.MapperWithCounters<Text, Text, Text, Text>();
+    Mapper<Text, Text, Text, Text> tokenMapperWithCounters = new TokenMapperWithCounters();
+
+    testDriver
+        .withMapper(mapperWithCounters)
+        .withInput(mapperWithCounters, new Text("hie"), new Text("Hi"))
+        .withMapper(tokenMapperWithCounters)
+        .withInput(tokenMapperWithCounters, new Text("bie"),
+            new Text("Goodbye Bye"))
+        .withOutput(new Text("hie"), new Text("Hi"))
+        .withOutput(new Text("bie"), new Text("Goodbye"))
+        .withOutput(new Text("bie"), new Text("Bye"))
+        .withStrictCounterChecking()
+        .withCounter(TestMapDriver.MapperWithCounters.Counters.X, 1)
+        .withCounter(TokenMapperWithCounters.Counters.Y, 2)
+        .withReducer(
+            new TestReduceDriver.ReducerWithCounters<Text, Text, Text, Text>())
+        .withCounter(TestReduceDriver.ReducerWithCounters.Counters.COUNT, 2)
+        .withCounter(TestReduceDriver.ReducerWithCounters.Counters.SUM, 3)
+        .withCounter("category", "count", 2).withCounter("category", "sum", 3)
+        .runTest(false);
+  }
+
+  @Test
+  public void testWithFailedCounter() throws IOException {
+    MultipleInputsMapReduceDriver<Text, Text, Text, Text> testDriver = new MultipleInputsMapReduceDriver<Text, Text, Text, Text>();
+
+    thrown
+        .expectAssertionErrorMessage("3 Error(s): ("
+            + "Counter org.apache.hadoop.mrunit.TestMapDriver.MapperWithCounters.Counters.X has value 1 instead of expected 20, "
+            + "Counter org.apache.hadoop.mrunit.TestMultipleInputsMapReduceDriver.TokenMapperWithCounters.Counters.Y has value 2 instead of expected 30, "
+            + "Counter with category category and name name has value 3 instead of expected 20)");
+
+    Mapper<Text, Text, Text, Text> mapperWithCounters = new TestMapDriver.MapperWithCounters<Text, Text, Text, Text>();
+    Mapper<Text, Text, Text, Text> tokenMapperWithCounters = new TokenMapperWithCounters();
+
+    testDriver
+        .withMapper(mapperWithCounters)
+        .withInput(mapperWithCounters, new Text("hie"), new Text("Hi"))
+        .withMapper(tokenMapperWithCounters)
+        .withInput(tokenMapperWithCounters, new Text("bie"),
+            new Text("Goodbye Bye"))
+        .withOutput(new Text("hie"), new Text("Hi"))
+        .withOutput(new Text("bie"), new Text("Goodbye"))
+        .withOutput(new Text("bie"), new Text("Bye"))
+        .withCounter(TestMapDriver.MapperWithCounters.Counters.X, 20)
+        .withCounter(TokenMapperWithCounters.Counters.Y, 30)
+        .withReducer(
+            new TestReduceDriver.ReducerWithCounters<Text, Text, Text, Text>())
+        .withCounter("category", "name", 20).runTest(false);
+  }
+
+  @Test
+  public void testJavaSerialization() throws IOException {
+    final Configuration conf = new Configuration();
+    conf.setStrings("io.serializations", conf.get("io.serializations"),
+        "org.apache.hadoop.io.serializer.JavaSerialization");
+    final MultipleInputsMapReduceDriver<Integer, IntWritable, Integer, IntWritable> testDriver = MultipleInputsMapReduceDriver
+        .newMultipleInputMapReduceDriver(
+                new IdentityReducer<Integer, IntWritable>())
+        .withConfiguration(conf);
+    Mapper<Integer, IntWritable, Integer, IntWritable> identityMapper = new IdentityMapper<Integer, IntWritable>();
+    Mapper<Integer, IntWritable, Integer, IntWritable> anotherIdentityMapper = new IdentityMapper<Integer, IntWritable>();
+    testDriver.addMapper(identityMapper);
+    testDriver.withInput(identityMapper, 1, new IntWritable(2)).withInput(
+        identityMapper, 2, new IntWritable(3));
+    testDriver.addMapper(anotherIdentityMapper);
+    testDriver.withInput(anotherIdentityMapper, 3, new IntWritable(4))
+        .withInput(anotherIdentityMapper, 4, new IntWritable(5));
+    testDriver
+        .withKeyOrderComparator(new JavaSerializationComparator<Integer>());
+    testDriver
+        .withKeyGroupingComparator(TestMapReduceDriver.INTEGER_COMPARATOR);
+
+    testDriver.withOutput(1, new IntWritable(2))
+        .withOutput(2, new IntWritable(3)).withOutput(3, new IntWritable(4))
+        .withOutput(4, new IntWritable(5));
+    testDriver.runTest(false);
+  }
+
+  @Test
+  public void testCopy() throws IOException {
+    final Text key = new Text("a");
+    final LongWritable value = new LongWritable(1);
+    driver.addInput(mapper, key, value);
+    key.set("b");
+    value.set(2);
+    driver.addInput(mapper, key, value);
+
+    key.set("a");
+    value.set(1);
+    driver.addOutput(key, value);
+    key.set("b");
+    value.set(2);
+    driver.addOutput(key, value);
+
+    final LongWritable longKey = new LongWritable(3);
+    final Text textValue = new Text("c d");
+    driver.addInput(tokenMapper, longKey, textValue);
+    longKey.set(4);
+    textValue.set("e f g");
+    driver.addInput(tokenMapper, longKey, textValue);
+
+    key.set("c");
+    value.set(3);
+    driver.addOutput(key, value);
+    key.set("d");
+    value.set(3);
+    driver.addOutput(key, value);
+    key.set("e");
+    value.set(4);
+    driver.addOutput(key, value);
+    key.set("f");
+    value.set(4);
+    driver.addOutput(key, value);
+    key.set("g");
+    value.set(4);
+    driver.addOutput(key, value);
+
+    driver.runTest(false);
+  }
+
+  @Test
+  public void testOutputFormat() throws IOException {
+    driver.withInputFormat(SequenceFileInputFormat.class);
+    driver.withOutputFormat(SequenceFileOutputFormat.class);
+    driver.withInput(mapper, new Text("a"), new LongWritable(1));
+    driver.withInput(mapper, new Text("a"), new LongWritable(2));
+    driver.withInput(tokenMapper, new LongWritable(3), new Text("a b"));
+    driver.withOutput(new Text("a"), new LongWritable(6));
+    driver.withOutput(new Text("b"), new LongWritable(3));
+    driver.runTest(false);
+  }
+
+  @SuppressWarnings({ "unchecked", "rawtypes" })
+  @Test
+  public void testOutputFormatWithMismatchInOutputClasses() throws IOException {
+    final MultipleInputsMapReduceDriver testDriver = this.driver;
+    testDriver.withInputFormat(TextInputFormat.class);
+    testDriver.withOutputFormat(TextOutputFormat.class);
+    testDriver.withInput(mapper, new Text("a"), new LongWritable(1));
+    testDriver.withInput(mapper, new Text("a"), new LongWritable(2));
+    testDriver.withInput(tokenMapper, new LongWritable(3), new Text("a b"));
+    testDriver.withOutput(new LongWritable(0), new Text("a\t6"));
+    testDriver.withOutput(new LongWritable(4), new Text("b\t3"));
+    testDriver.runTest(false);
+  }
+
+  @Test
+  public void testMapInputFile() throws IOException {
+    MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable> testDriver = new MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable>(
+        reducer);
+
+    InputPathStoringMapper<LongWritable, LongWritable> inputPathStoringMapper = new InputPathStoringMapper<LongWritable, LongWritable>();
+    Path mapInputPath = new Path("myfile");
+    testDriver.addMapper(inputPathStoringMapper);
+    testDriver.setMapInputPath(inputPathStoringMapper, mapInputPath);
+    assertEquals(mapInputPath.getName(),
+        testDriver.getMapInputPath(inputPathStoringMapper).getName());
+    testDriver.withInput(inputPathStoringMapper, new Text("a"),
+        new LongWritable(1));
+
+    InputPathStoringMapper<LongWritable, LongWritable> anotherInputPathStoringMapper = new InputPathStoringMapper<LongWritable, LongWritable>();
+    Path anotherMapInputPath = new Path("myotherfile");
+    testDriver.addMapper(anotherInputPathStoringMapper);
+    testDriver.setMapInputPath(anotherInputPathStoringMapper,
+        anotherMapInputPath);
+    assertEquals(anotherMapInputPath.getName(),
+        testDriver.getMapInputPath(anotherInputPathStoringMapper).getName());
+    testDriver.withInput(anotherInputPathStoringMapper, new Text("b"),
+        new LongWritable(2));
+
+    testDriver.runTest(false);
+    assertNotNull(inputPathStoringMapper.getMapInputPath());
+    assertEquals(mapInputPath.getName(), inputPathStoringMapper
+        .getMapInputPath().getName());
+  }
+
+  @Test
+  public void testGroupComparatorBehaviorFirst() throws IOException {
+    driver
+        .withInput(mapper, new Text("A1"), new LongWritable(1L))
+        .withInput(mapper, new Text("A2"), new LongWritable(1L))
+        .withInput(mapper, new Text("B1"), new LongWritable(1L))
+        .withInput(mapper, new Text("B2"), new LongWritable(1L))
+        .withInput(mapper, new Text("C1"), new LongWritable(1L))
+        .withInput(tokenMapper, new LongWritable(3L), new Text("D1 D2 E1"))
+        .withOutput(new Text("A1"), new LongWritable(2L))
+        .withOutput(new Text("B1"), new LongWritable(2L))
+        .withOutput(new Text("C1"), new LongWritable(1L))
+        .withOutput(new Text("D1"), new LongWritable(6L))
+        .withOutput(new Text("E1"), new LongWritable(3L))
+        .withKeyGroupingComparator(
+            new TestMapReduceDriver.FirstCharComparator()).runTest(false);
+  }
+
+  @Test
+  public void testGroupComparatorBehaviorSecond() throws IOException {
+    driver
+        .withInput(mapper, new Text("1A"), new LongWritable(1L))
+        .withInput(mapper, new Text("2A"), new LongWritable(1L))
+        .withInput(mapper, new Text("1B"), new LongWritable(1L))
+        .withInput(mapper, new Text("2B"), new LongWritable(1L))
+        .withInput(mapper, new Text("1C"), new LongWritable(1L))
+        .withInput(tokenMapper, new LongWritable(2L), new Text("1D 2D 1E"))
+        .withOutput(new Text("1A"), new LongWritable(1L))
+        .withOutput(new Text("2A"), new LongWritable(1L))
+        .withOutput(new Text("1B"), new LongWritable(1L))
+        .withOutput(new Text("2B"), new LongWritable(1L))
+        .withOutput(new Text("1C"), new LongWritable(1L))
+        .withOutput(new Text("1D"), new LongWritable(2L))
+        .withOutput(new Text("2D"), new LongWritable(2L))
+        .withOutput(new Text("1E"), new LongWritable(2L))
+        .withKeyGroupingComparator(
+            new TestMapReduceDriver.SecondCharComparator()).runTest(false);
+  }
+
+  @Test
+  public void testGroupingComparatorSpecifiedByConf() throws IOException {
+    JobConf conf = new JobConf(new Configuration());
+    conf.setOutputValueGroupingComparator(TestMapReduceDriver.FirstCharComparator.class);
+    driver.withInput(mapper, new Text("A1"), new LongWritable(1L))
+        .withInput(mapper, new Text("A2"), new LongWritable(1L))
+        .withInput(mapper, new Text("B1"), new LongWritable(1L))
+        .withInput(mapper, new Text("B2"), new LongWritable(1L))
+        .withInput(mapper, new Text("C1"), new LongWritable(1L))
+        .withInput(tokenMapper, new LongWritable(3L), new Text("D1 D2 E1"))
+        .withOutput(new Text("A1"), new LongWritable(2L))
+        .withOutput(new Text("B1"), new LongWritable(2L))
+        .withOutput(new Text("C1"), new LongWritable(1L))
+        .withOutput(new Text("D1"), new LongWritable(6L))
+        .withOutput(new Text("E1"), new LongWritable(3L))
+        .withConfiguration(conf).runTest(false);
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void testUseOfWritableRegisteredComparator() throws IOException {
+    MultipleInputsMapReduceDriver<TestWritable, Text, TestWritable, Text> testDriver = new MultipleInputsMapReduceDriver<TestWritable, Text, TestWritable, Text>(
+        new IdentityReducer<TestWritable, Text>());
+
+    IdentityMapper<TestWritable, Text> identityMapper = new IdentityMapper<TestWritable, Text>();
+    IdentityMapper<TestWritable, Text> anotherIdentityMapper = new IdentityMapper<TestWritable, Text>();
+    testDriver.addMapper(identityMapper);
+    testDriver.addMapper(anotherIdentityMapper);
+    testDriver
+        .withInput(identityMapper, new TestWritable("A1"), new Text("A1"))
+        .withInput(identityMapper, new TestWritable("A2"), new Text("A2"))
+        .withInput(identityMapper, new TestWritable("A3"), new Text("A3"))
+        .withInput(anotherIdentityMapper, new TestWritable("B1"),
+            new Text("B1"))
+        .withInput(anotherIdentityMapper, new TestWritable("B2"),
+            new Text("B2"))
+        .withKeyGroupingComparator(new TestWritable.SingleGroupComparator())
+        .withOutput(new TestWritable("B2"), new Text("B2"))
+        .withOutput(new TestWritable("B2"), new Text("B1"))
+        .withOutput(new TestWritable("B2"), new Text("A3"))
+        .withOutput(new TestWritable("B2"), new Text("A2"))
+        .withOutput(new TestWritable("B2"), new Text("A1")).runTest(true); // ordering
+                                                                           // is
+                                                                           // important
+  }
+
+  static class TokenMapperWithCounters extends MapReduceBase implements
+      Mapper<Text, Text, Text, Text> {
+    private final Text output = new Text();
+
+    @Override
+    public void map(Text key, Text value,
+        OutputCollector<Text, Text> collector, Reporter reporter)
+        throws IOException {
+      String[] tokens = value.toString().split("\\s");
+      for (String token : tokens) {
+        output.set(token);
+        collector.collect(key, output);
+        reporter.getCounter(Counters.Y).increment(1);
+        reporter.getCounter("category", "name").increment(1);
+      }
+    }
+
+    public static enum Counters {
+      Y
+    }
+  }
+
+  static class TokenMapper extends MapReduceBase implements
+      Mapper<LongWritable, Text, Text, LongWritable> {
+    private final Text output = new Text();
+
+    @Override
+    public void map(LongWritable longWritable, Text text,
+        OutputCollector<Text, LongWritable> textLongWritableOutputCollector,
+        Reporter reporter) throws IOException {
+      String[] tokens = text.toString().split("\\s");
+      for (String token : tokens) {
+        output.set(token);
+        textLongWritableOutputCollector.collect(output, longWritable);
+      }
+    }
+  }
+
+  static class ReverseIdentityMapper<KEYIN, VALUEIN> extends MapReduceBase
+      implements Mapper<KEYIN, VALUEIN, VALUEIN, KEYIN> {
+    @Override
+    public void map(KEYIN key, VALUEIN value,
+        OutputCollector<VALUEIN, KEYIN> vkOutputCollector, Reporter reporter)
+        throws IOException {
+      vkOutputCollector.collect(value, key);
+    }
+  }
+}
diff --git a/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMapReduceDriver.java b/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMapReduceDriver.java
index 0f6a521..bbd0c82 100644
--- a/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMapReduceDriver.java
+++ b/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMapReduceDriver.java
@@ -17,13 +17,6 @@
  */
 package org.apache.hadoop.mrunit.mapreduce;
 
-import static org.apache.hadoop.mrunit.ExtendedAssert.assertListEquals;
-import static org.junit.Assert.*;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
@@ -31,8 +24,6 @@
 import org.apache.hadoop.io.NullWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.serializer.JavaSerializationComparator;
-import org.apache.hadoop.mapred.lib.IdentityMapper;
-import org.apache.hadoop.mapred.lib.IdentityReducer;
 import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.hadoop.mapreduce.Reducer;
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
@@ -43,7 +34,6 @@
 import org.apache.hadoop.mapreduce.lib.reduce.IntSumReducer;
 import org.apache.hadoop.mapreduce.lib.reduce.LongSumReducer;
 import org.apache.hadoop.mrunit.ExpectedSuppliedException;
-import org.apache.hadoop.mrunit.mapreduce.MapReduceDriver;
 import org.apache.hadoop.mrunit.TestMapReduceDriver.FirstCharComparator;
 import org.apache.hadoop.mrunit.TestMapReduceDriver.SecondCharComparator;
 import org.apache.hadoop.mrunit.mapreduce.TestMapDriver.ConfigurationMapper;
@@ -56,6 +46,14 @@
 import org.junit.Rule;
 import org.junit.Test;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.hadoop.mrunit.ExtendedAssert.assertListEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
 public class TestMapReduceDriver {
 
   private static final int FOO_IN_A = 42;
@@ -103,10 +101,10 @@
   public void testUncomparable() throws IOException {
     Text k = new Text("test");
     Object v = new UncomparableWritable(2);
-    MapReduceDriver.newMapReduceDriver(
-        new Mapper<Text, Object,Text, Object>(),
-        new Reducer<Text, Object,Text, Object>())
-        .withInput(k, v).withOutput(k, v).runTest();
+    MapReduceDriver
+        .newMapReduceDriver(new Mapper<Text, Object, Text, Object>(),
+            new Reducer<Text, Object, Text, Object>()).withInput(k, v)
+        .withOutput(k, v).runTest();
   }
 
   @Test
@@ -130,8 +128,10 @@
   @Test
   public void testTestRun3() throws IOException {
     thrown.expectAssertionErrorMessage("2 Error(s)");
-    thrown.expectAssertionErrorMessage("Missing expected output (foo, 52) at position 0, got (bar, 12).");
-    thrown.expectAssertionErrorMessage("Missing expected output (bar, 12) at position 1, got (foo, 52).");
+    thrown
+        .expectAssertionErrorMessage("Missing expected output (foo, 52) at position 0, got (bar, 12).");
+    thrown
+        .expectAssertionErrorMessage("Missing expected output (bar, 12) at position 1, got (foo, 52).");
     driver.withInput(new Text("foo"), new LongWritable(FOO_IN_A))
         .withInput(new Text("bar"), new LongWritable(BAR_IN))
         .withInput(new Text("foo"), new LongWritable(FOO_IN_B))
@@ -142,13 +142,18 @@
   @Test
   public void testAddAll() throws IOException {
     final List<Pair<Text, LongWritable>> inputs = new ArrayList<Pair<Text, LongWritable>>();
-    inputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(FOO_IN_A)));
-    inputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(FOO_IN_B)));
-    inputs.add(new Pair<Text, LongWritable>(new Text("bar"), new LongWritable(BAR_IN)));
+    inputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(
+        FOO_IN_A)));
+    inputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(
+        FOO_IN_B)));
+    inputs.add(new Pair<Text, LongWritable>(new Text("bar"), new LongWritable(
+        BAR_IN)));
 
     final List<Pair<Text, LongWritable>> outputs = new ArrayList<Pair<Text, LongWritable>>();
-    outputs.add(new Pair<Text, LongWritable>(new Text("bar"), new LongWritable(BAR_IN)));
-    outputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(FOO_OUT)));
+    outputs.add(new Pair<Text, LongWritable>(new Text("bar"), new LongWritable(
+        BAR_IN)));
+    outputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(
+        FOO_OUT)));
 
     driver.withAll(inputs).withAllOutput(outputs).runTest();
   }
@@ -170,100 +175,10 @@
   }
 
   @Test
-  public void testEmptyShuffle() {
-    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    final List<Pair<Text, List<Text>>> outputs = driver2.shuffle(inputs);
-    assertEquals(0, outputs.size());
-  }
-
-  // just shuffle a single (k, v) pair
-  @Test
-  public void testSingleShuffle() {
-    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("b")));
-
-    final List<Pair<Text, List<Text>>> outputs = driver2.shuffle(inputs);
-
-    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
-    final List<Text> sublist = new ArrayList<Text>();
-    sublist.add(new Text("b"));
-    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist));
-
-    assertListEquals(expected, outputs);
-  }
-
-  // shuffle multiple values from the same key.
-  @Test
-  public void testShuffleOneKey() {
-    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("b")));
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("c")));
-
-    final List<Pair<Text, List<Text>>> outputs = driver2.shuffle(inputs);
-
-    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
-    final List<Text> sublist = new ArrayList<Text>();
-    sublist.add(new Text("b"));
-    sublist.add(new Text("c"));
-    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist));
-
-    assertListEquals(expected, outputs);
-  }
-
-  // shuffle multiple keys
-  @Test
-  public void testMultiShuffle1() {
-    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("x")));
-    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("z")));
-    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("w")));
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("y")));
-
-    final List<Pair<Text, List<Text>>> outputs = driver2.shuffle(inputs);
-
-    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
-    final List<Text> sublist1 = new ArrayList<Text>();
-    sublist1.add(new Text("x"));
-    sublist1.add(new Text("y"));
-    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist1));
-
-    final List<Text> sublist2 = new ArrayList<Text>();
-    sublist2.add(new Text("z"));
-    sublist2.add(new Text("w"));
-    expected.add(new Pair<Text, List<Text>>(new Text("b"), sublist2));
-
-    assertListEquals(expected, outputs);
-  }
-
-  // shuffle multiple keys that are out-of-order to start.
-  @Test
-  public void testMultiShuffle2() {
-    final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("z")));
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("x")));
-    inputs.add(new Pair<Text, Text>(new Text("b"), new Text("w")));
-    inputs.add(new Pair<Text, Text>(new Text("a"), new Text("y")));
-
-    final List<Pair<Text, List<Text>>> outputs = driver2.shuffle(inputs);
-
-    final List<Pair<Text, List<Text>>> expected = new ArrayList<Pair<Text, List<Text>>>();
-    final List<Text> sublist1 = new ArrayList<Text>();
-    sublist1.add(new Text("x"));
-    sublist1.add(new Text("y"));
-    expected.add(new Pair<Text, List<Text>>(new Text("a"), sublist1));
-
-    final List<Text> sublist2 = new ArrayList<Text>();
-    sublist2.add(new Text("z"));
-    sublist2.add(new Text("w"));
-    expected.add(new Pair<Text, List<Text>>(new Text("b"), sublist2));
-
-    assertListEquals(expected, outputs);
-  }
-
-  @Test
   public void testEmptySortAndGroup() {
     final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
-    final List<KeyValueReuseList<Text, Text>> outputs = driver2.sortAndGroup(inputs);
+    final List<KeyValueReuseList<Text, Text>> outputs = driver2
+        .sortAndGroup(inputs);
     assertEquals(0, outputs.size());
   }
 
@@ -273,10 +188,12 @@
     final List<Pair<Text, Text>> inputs = new ArrayList<Pair<Text, Text>>();
     inputs.add(new Pair<Text, Text>(new Text("a"), new Text("b")));
 
-    final List<KeyValueReuseList<Text, Text>> outputs = driver2.sortAndGroup(inputs);
+    final List<KeyValueReuseList<Text, Text>> outputs = driver2
+        .sortAndGroup(inputs);
 
     final List<KeyValueReuseList<Text, Text>> expected = new ArrayList<KeyValueReuseList<Text, Text>>();
-    final KeyValueReuseList<Text, Text> sublist = new KeyValueReuseList<Text, Text>(new Text(), new Text(), driver2.getConfiguration());
+    final KeyValueReuseList<Text, Text> sublist = new KeyValueReuseList<Text, Text>(
+        new Text(), new Text(), driver2.getConfiguration());
     sublist.add(new Pair<Text, Text>(new Text("a"), new Text("b")));
     expected.add(sublist);
 
@@ -290,10 +207,12 @@
     inputs.add(new Pair<Text, Text>(new Text("a"), new Text("b")));
     inputs.add(new Pair<Text, Text>(new Text("a"), new Text("c")));
 
-    final List<KeyValueReuseList<Text, Text>> outputs = driver2.sortAndGroup(inputs);
+    final List<KeyValueReuseList<Text, Text>> outputs = driver2
+        .sortAndGroup(inputs);
 
     final List<KeyValueReuseList<Text, Text>> expected = new ArrayList<KeyValueReuseList<Text, Text>>();
-    final KeyValueReuseList<Text, Text> sublist = new KeyValueReuseList<Text, Text>(new Text(), new Text(), driver2.getConfiguration());
+    final KeyValueReuseList<Text, Text> sublist = new KeyValueReuseList<Text, Text>(
+        new Text(), new Text(), driver2.getConfiguration());
     sublist.add(new Pair<Text, Text>(new Text("a"), new Text("b")));
     sublist.add(new Pair<Text, Text>(new Text("a"), new Text("c")));
     expected.add(sublist);
@@ -310,15 +229,18 @@
     inputs.add(new Pair<Text, Text>(new Text("b"), new Text("w")));
     inputs.add(new Pair<Text, Text>(new Text("a"), new Text("y")));
 
-    final List<KeyValueReuseList<Text, Text>> outputs = driver2.sortAndGroup(inputs);
+    final List<KeyValueReuseList<Text, Text>> outputs = driver2
+        .sortAndGroup(inputs);
 
     final List<KeyValueReuseList<Text, Text>> expected = new ArrayList<KeyValueReuseList<Text, Text>>();
-    final KeyValueReuseList<Text, Text> sublist1 = new KeyValueReuseList<Text, Text>(new Text(), new Text(), driver2.getConfiguration());
+    final KeyValueReuseList<Text, Text> sublist1 = new KeyValueReuseList<Text, Text>(
+        new Text(), new Text(), driver2.getConfiguration());
     sublist1.add(new Pair<Text, Text>(new Text("a"), new Text("x")));
     sublist1.add(new Pair<Text, Text>(new Text("a"), new Text("y")));
     expected.add(sublist1);
 
-    final KeyValueReuseList<Text, Text> sublist2 = new KeyValueReuseList<Text, Text>(new Text(), new Text(), driver2.getConfiguration());
+    final KeyValueReuseList<Text, Text> sublist2 = new KeyValueReuseList<Text, Text>(
+        new Text(), new Text(), driver2.getConfiguration());
     sublist2.add(new Pair<Text, Text>(new Text("b"), new Text("z")));
     sublist2.add(new Pair<Text, Text>(new Text("b"), new Text("w")));
     expected.add(sublist2);
@@ -335,15 +257,18 @@
     inputs.add(new Pair<Text, Text>(new Text("b"), new Text("w")));
     inputs.add(new Pair<Text, Text>(new Text("a"), new Text("y")));
 
-    final List<KeyValueReuseList<Text, Text>> outputs = driver2.sortAndGroup(inputs);
+    final List<KeyValueReuseList<Text, Text>> outputs = driver2
+        .sortAndGroup(inputs);
 
     final List<KeyValueReuseList<Text, Text>> expected = new ArrayList<KeyValueReuseList<Text, Text>>();
-    final KeyValueReuseList<Text, Text> sublist1 = new KeyValueReuseList<Text, Text>(new Text(), new Text(), driver2.getConfiguration());
+    final KeyValueReuseList<Text, Text> sublist1 = new KeyValueReuseList<Text, Text>(
+        new Text(), new Text(), driver2.getConfiguration());
     sublist1.add(new Pair<Text, Text>(new Text("a"), new Text("x")));
     sublist1.add(new Pair<Text, Text>(new Text("a"), new Text("y")));
     expected.add(sublist1);
 
-    final KeyValueReuseList<Text, Text> sublist2 = new KeyValueReuseList<Text, Text>(new Text(), new Text(), driver2.getConfiguration());
+    final KeyValueReuseList<Text, Text> sublist2 = new KeyValueReuseList<Text, Text>(
+        new Text(), new Text(), driver2.getConfiguration());
     sublist2.add(new Pair<Text, Text>(new Text("b"), new Text("z")));
     sublist2.add(new Pair<Text, Text>(new Text("b"), new Text("w")));
     expected.add(sublist2);
@@ -614,16 +539,15 @@
 
   @Test
   public void testGroupingComparatorBehaviour1() throws IOException {
-    driver.withInput(new Text("A1"),new LongWritable(1L))
-      .withInput(new Text("A2"),new LongWritable(1L))
-      .withInput(new Text("B1"),new LongWritable(1L))
-      .withInput(new Text("B2"),new LongWritable(1L))
-      .withInput(new Text("C1"),new LongWritable(1L))
-      .withOutput(new Text("A2"),new LongWritable(2L))
-      .withOutput(new Text("B2"),new LongWritable(2L))
-      .withOutput(new Text("C1"),new LongWritable(1L))
-      .withKeyGroupingComparator(new FirstCharComparator())
-      .runTest(false);
+    driver.withInput(new Text("A1"), new LongWritable(1L))
+        .withInput(new Text("A2"), new LongWritable(1L))
+        .withInput(new Text("B1"), new LongWritable(1L))
+        .withInput(new Text("B2"), new LongWritable(1L))
+        .withInput(new Text("C1"), new LongWritable(1L))
+        .withOutput(new Text("A2"), new LongWritable(2L))
+        .withOutput(new Text("B2"), new LongWritable(2L))
+        .withOutput(new Text("C1"), new LongWritable(1L))
+        .withKeyGroupingComparator(new FirstCharComparator()).runTest(false);
   }
 
   @Test
@@ -632,18 +556,17 @@
     // grouping of reduce keys in "shuffle".
     // MapReduce doesn't group keys which aren't in a contiguous
     // range when sorted by their sorting comparator.
-    driver.withInput(new Text("1A"),new LongWritable(1L))
-      .withInput(new Text("2A"),new LongWritable(1L))
-      .withInput(new Text("1B"),new LongWritable(1L))
-      .withInput(new Text("2B"),new LongWritable(1L))
-      .withInput(new Text("1C"),new LongWritable(1L))
-      .withOutput(new Text("1A"),new LongWritable(1L))
-      .withOutput(new Text("2A"),new LongWritable(1L))
-      .withOutput(new Text("1B"),new LongWritable(1L))
-      .withOutput(new Text("2B"),new LongWritable(1L))
-      .withOutput(new Text("1C"),new LongWritable(1L))
-      .withKeyGroupingComparator(new SecondCharComparator())
-      .runTest(false);
+    driver.withInput(new Text("1A"), new LongWritable(1L))
+        .withInput(new Text("2A"), new LongWritable(1L))
+        .withInput(new Text("1B"), new LongWritable(1L))
+        .withInput(new Text("2B"), new LongWritable(1L))
+        .withInput(new Text("1C"), new LongWritable(1L))
+        .withOutput(new Text("1A"), new LongWritable(1L))
+        .withOutput(new Text("2A"), new LongWritable(1L))
+        .withOutput(new Text("1B"), new LongWritable(1L))
+        .withOutput(new Text("2B"), new LongWritable(1L))
+        .withOutput(new Text("1C"), new LongWritable(1L))
+        .withKeyGroupingComparator(new SecondCharComparator()).runTest(false);
   }
 
   @Test
@@ -651,33 +574,36 @@
 
     // this test should use the comparator registered inside TestWritable
     // to output the keys in reverse order
-    MapReduceDriver<TestWritable,Text,TestWritable,Text,TestWritable,Text> driver
-      = MapReduceDriver.newMapReduceDriver(new Mapper(), new Reducer());
+    MapReduceDriver<TestWritable, Text, TestWritable, Text, TestWritable, Text> driver = MapReduceDriver
+        .newMapReduceDriver(new Mapper(), new Reducer());
 
-    driver.withInput(new TestWritable("A1"), new Text("A1"))
-      .withInput(new TestWritable("A2"), new Text("A2"))
-      .withInput(new TestWritable("A3"), new Text("A3"))
-      .withKeyGroupingComparator(new TestWritable.SingleGroupComparator())
-      // these output keys are incorrect because of MRUNIT-129
-      //.withOutput(new TestWritable("A3"), new Text("A3"))
-      //.withOutput(new TestWritable("A3"), new Text("A2"))
-      //.withOutput(new TestWritable("A3"), new Text("A1"))
-      //the following are the actual correct outputs
-      .withOutput(new TestWritable("A3"), new Text("A3"))
-      .withOutput(new TestWritable("A2"), new Text("A2"))
-      .withOutput(new TestWritable("A1"), new Text("A1"))
-      .runTest(true); //ordering is important
+    driver
+        .withInput(new TestWritable("A1"), new Text("A1"))
+        .withInput(new TestWritable("A2"), new Text("A2"))
+        .withInput(new TestWritable("A3"), new Text("A3"))
+        .withKeyGroupingComparator(new TestWritable.SingleGroupComparator())
+        // these output keys are incorrect because of MRUNIT-129
+        // .withOutput(new TestWritable("A3"), new Text("A3"))
+        // .withOutput(new TestWritable("A3"), new Text("A2"))
+        // .withOutput(new TestWritable("A3"), new Text("A1"))
+        // the following are the actual correct outputs
+        .withOutput(new TestWritable("A3"), new Text("A3"))
+        .withOutput(new TestWritable("A2"), new Text("A2"))
+        .withOutput(new TestWritable("A1"), new Text("A1")).runTest(true); // ordering
+                                                                           // is
+                                                                           // important
   }
 
   @Test
   public void testRepeatRun() throws IOException {
     driver.withCombiner(new Reducer<Text, LongWritable, Text, LongWritable>())
-            .withInput(new Text("foo"), new LongWritable(FOO_IN_A))
-            .withInput(new Text("foo"), new LongWritable(FOO_IN_B))
-            .withInput(new Text("bar"), new LongWritable(BAR_IN))
-            .withOutput(new Text("bar"), new LongWritable(BAR_IN))
-            .withOutput(new Text("foo"), new LongWritable(FOO_OUT)).runTest();
-    thrown.expectMessage(IllegalStateException.class, "Driver reuse not allowed");
+        .withInput(new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(new Text("bar"), new LongWritable(BAR_IN))
+        .withOutput(new Text("bar"), new LongWritable(BAR_IN))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT)).runTest();
+    thrown.expectMessage(IllegalStateException.class,
+        "Driver reuse not allowed");
     driver.runTest();
   }
 
diff --git a/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMultipleInputsMapReduceDriver.java b/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMultipleInputsMapReduceDriver.java
new file mode 100644
index 0000000..b6a7b86
--- /dev/null
+++ b/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMultipleInputsMapReduceDriver.java
@@ -0,0 +1,658 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mrunit.mapreduce;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.serializer.JavaSerializationComparator;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.mapreduce.lib.reduce.LongSumReducer;
+import org.apache.hadoop.mrunit.ExpectedSuppliedException;
+import org.apache.hadoop.mrunit.TestMapReduceDriver;
+import org.apache.hadoop.mrunit.types.Pair;
+import org.apache.hadoop.mrunit.types.TestWritable;
+import org.apache.hadoop.mrunit.types.UncomparableWritable;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.hadoop.mrunit.ExtendedAssert.assertListEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+public class TestMultipleInputsMapReduceDriver {
+  @Rule
+  public final ExpectedSuppliedException thrown = ExpectedSuppliedException
+      .none();
+
+  private static final int FOO_IN_A = 42;
+  private static final int FOO_IN_B = 10;
+  private static final int TOKEN_IN_A = 1;
+  private static final int TOKEN_IN_B = 2;
+  private static final int BAR_IN = 12;
+  private static final int BAR_OUT = BAR_IN + TOKEN_IN_A + TOKEN_IN_B;
+  private static final int FOO_OUT = FOO_IN_A + FOO_IN_B + TOKEN_IN_A + 2
+      * TOKEN_IN_B;
+  private static final String TOKEN_A = "foo bar";
+  private static final String TOKEN_B = "foo foo bar";
+
+  private Mapper<Text, LongWritable, Text, LongWritable> mapper;
+  private Reducer<Text, LongWritable, Text, LongWritable> reducer;
+  private TokenMapper tokenMapper;
+  private MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable> driver;
+
+  @Before
+  public void setUp() {
+    mapper = new Mapper<Text, LongWritable, Text, LongWritable>();
+    reducer = new LongSumReducer<Text>();
+    tokenMapper = new TokenMapper();
+    driver = new MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable>(
+        reducer);
+    driver.addMapper(mapper);
+    driver.addMapper(tokenMapper);
+  }
+
+  @Test
+  public void testRun() throws IOException {
+    final List<Pair<Text, LongWritable>> out = driver
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .run();
+
+    final List<Pair<Text, LongWritable>> expected = new ArrayList<Pair<Text, LongWritable>>();
+    expected.add(new Pair<Text, LongWritable>(new Text("bar"),
+        new LongWritable(BAR_OUT)));
+    expected.add(new Pair<Text, LongWritable>(new Text("foo"),
+        new LongWritable(FOO_OUT)));
+
+    assertListEquals(expected, out);
+  }
+
+  @Test
+  public void testUncomparable() throws IOException {
+    MultipleInputsMapReduceDriver<Text, Object, Text, Object> testDriver = MultipleInputsMapReduceDriver
+        .newMultipleInputMapReduceDriver(new Reducer<Text, Object, Text, Object>());
+
+    Mapper<Text, Object, Text, Object> identity = new Mapper<Text, Object, Text, Object>();
+    testDriver.addMapper(identity);
+    Text k1 = new Text("foo");
+    Object v1 = new UncomparableWritable(1);
+    testDriver.withInput(identity, k1, v1);
+
+    ReverseMapper<Object, Text> reverse = new ReverseMapper<Object, Text>();
+    testDriver.addMapper(reverse);
+    Text k2 = new Text("bar");
+    Object v2 = new UncomparableWritable(2);
+    testDriver.withInput(reverse, v2, k2);
+
+    testDriver.withOutput(k1, v1).withOutput(k2, v2);
+
+    testDriver.runTest(false);
+  }
+
+  @Test
+  public void testTestRun() throws IOException {
+    driver
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .withOutput(new Text("bar"), new LongWritable(BAR_OUT))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT)).runTest(false);
+  }
+
+  @Test
+  public void testAddAll() throws IOException {
+    final List<Pair<Text, LongWritable>> mapperInputs = new ArrayList<Pair<Text, LongWritable>>();
+    mapperInputs.add(new Pair<Text, LongWritable>(new Text("foo"),
+        new LongWritable(FOO_IN_A)));
+    mapperInputs.add(new Pair<Text, LongWritable>(new Text("foo"),
+        new LongWritable(FOO_IN_B)));
+    mapperInputs.add(new Pair<Text, LongWritable>(new Text("bar"),
+        new LongWritable(BAR_IN)));
+
+    final List<Pair<LongWritable, Text>> tokenMapperInputs = new ArrayList<Pair<LongWritable, Text>>();
+    tokenMapperInputs.add(new Pair<LongWritable, Text>(new LongWritable(
+        TOKEN_IN_A), new Text(TOKEN_A)));
+    tokenMapperInputs.add(new Pair<LongWritable, Text>(new LongWritable(
+        TOKEN_IN_B), new Text(TOKEN_B)));
+
+    final List<Pair<Text, LongWritable>> outputs = new ArrayList<Pair<Text, LongWritable>>();
+    outputs.add(new Pair<Text, LongWritable>(new Text("bar"), new LongWritable(
+        BAR_OUT)));
+    outputs.add(new Pair<Text, LongWritable>(new Text("foo"), new LongWritable(
+        FOO_OUT)));
+
+    driver.withAll(mapper, mapperInputs)
+        .withAll(tokenMapper, tokenMapperInputs).withAllOutput(outputs)
+        .runTest(false);
+  }
+
+  @Test
+  public void testNoInput() throws IOException {
+    thrown.expectMessage(IllegalStateException.class,
+        "No input was provided for mapper");
+    driver.runTest(false);
+  }
+
+  @Test
+  public void testNoInputForMapper() throws IOException {
+    MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable> testDriver = new MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable>();
+    testDriver.addMapper(mapper);
+    testDriver.addMapper(tokenMapper);
+    testDriver.withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A));
+    thrown.expectMessage(IllegalStateException.class,
+        String.format("No input was provided for mapper %s", tokenMapper));
+    testDriver.runTest(false);
+  }
+
+  @Test
+  public void testNoReducer() throws IOException {
+    MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable> testDriver = new MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable>();
+    testDriver.addMapper(mapper);
+    testDriver.withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A));
+    thrown.expectMessage(IllegalStateException.class,
+        "No reducer class was provided");
+    testDriver.runTest(false);
+  }
+
+  @Test
+  public void testIdentityCombiner() throws IOException {
+    driver
+        .withCombiner(new Reducer<Text, LongWritable, Text, LongWritable>())
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT))
+        .withOutput(new Text("bar"), new LongWritable(BAR_OUT)).runTest(false);
+  }
+
+  @Test
+  public void testLongSumCombiner() throws IOException {
+    driver
+        .withCombiner(new LongSumReducer<Text>())
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT))
+        .withOutput(new Text("bar"), new LongWritable(BAR_OUT)).runTest(false);
+  }
+
+  @Test
+  public void testLongSumCombinerAndIdentityReducer() throws IOException {
+    driver
+        .withCombiner(new LongSumReducer<Text>())
+        .withReducer(new Reducer<Text, LongWritable, Text, LongWritable>())
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT))
+        .withOutput(new Text("bar"), new LongWritable(BAR_OUT)).runTest(false);
+  }
+
+  @Test
+  public void testRepeatRun() throws IOException {
+    driver
+        .withCombiner(new Reducer<Text, LongWritable, Text, LongWritable>())
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_A))
+        .withInput(mapper, new Text("foo"), new LongWritable(FOO_IN_B))
+        .withInput(mapper, new Text("bar"), new LongWritable(BAR_IN))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_A), new Text(TOKEN_A))
+        .withInput(tokenMapper, new LongWritable(TOKEN_IN_B), new Text(TOKEN_B))
+        .withOutput(new Text("foo"), new LongWritable(FOO_OUT))
+        .withOutput(new Text("bar"), new LongWritable(BAR_OUT)).runTest(false);
+    thrown.expectMessage(IllegalStateException.class,
+        "Driver reuse not allowed");
+    driver.runTest(false);
+  }
+
+  // Test the key grouping and value ordering comparators
+  @Test
+  public void testComparators() throws IOException {
+    // reducer to track the order of the input values using bit shifting
+    driver.withReducer(new Reducer<Text, LongWritable, Text, LongWritable>() {
+      @Override
+      protected void reduce(Text key, Iterable<LongWritable> values,
+          Context context) throws IOException, InterruptedException {
+        Text outKey = new Text(key);
+        long outputValue = 0;
+        int count = 0;
+        // pack each value into its own byte of the output long, in arrival
+        // order, so the expected outputs pin the exact value ordering per group
+        for (LongWritable value : values) {
+          outputValue |= (value.get() << (count++ * 8));
+        }
+
+        context.write(outKey, new LongWritable(outputValue));
+      }
+    });
+
+    // group keys by their first character, sort them by their second character
+    driver
+        .withKeyGroupingComparator(new org.apache.hadoop.mrunit.TestMapReduceDriver.FirstCharComparator());
+    driver
+        .withKeyOrderComparator(new org.apache.hadoop.mrunit.TestMapReduceDriver.SecondCharComparator());
+
+    driver.addInput(mapper, new Text("a1"), new LongWritable(1));
+    driver.addInput(mapper, new Text("b1"), new LongWritable(1));
+    driver.addInput(mapper, new Text("a3"), new LongWritable(3));
+    driver.addInput(mapper, new Text("a2"), new LongWritable(2));
+
+    // second input path contributes keys c1 and d1 via the token mapper
+    driver.addInput(tokenMapper, new LongWritable(1), new Text("c1 d1"));
+
+    driver.addOutput(new Text("a1"), new LongWritable(0x1));
+    driver.addOutput(new Text("b1"), new LongWritable(0x1));
+    // a2 and a3 are adjacent after sorting and share a first char, so they
+    // form one group: value 0x2 then 0x3 packed into bytes 0 and 1
+    driver.addOutput(new Text("a2"), new LongWritable(0x2 | (0x3 << 8)));
+    driver.addOutput(new Text("c1"), new LongWritable(0x1));
+    driver.addOutput(new Text("d1"), new LongWritable(0x1));
+
+    driver.runTest(false);
+  }
+
+  /**
+   * Running a driver that has a reducer but no registered mappers must fail
+   * fast with an IllegalStateException.
+   */
+  @Test
+  public void testNoMapper() throws IOException {
+    final MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable> reducerOnlyDriver =
+        new MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable>();
+    reducerOnlyDriver.withReducer(reducer);
+    thrown.expectMessage(IllegalStateException.class, "No mappers were provided");
+    reducerOnlyDriver.runTest(false);
+  }
+
+  /**
+   * Verifies that counters incremented by both mappers and by the reducer are
+   * matched against the expected enum and string counter values.
+   */
+  @Test
+  public void testWithCounter() throws IOException {
+    final MultipleInputsMapReduceDriver<Text, Text, Text, Text> counterDriver = new MultipleInputsMapReduceDriver<Text, Text, Text, Text>();
+    final Mapper<Text, Text, Text, Text> countingMapper = new TestMapDriver.MapperWithCounters<Text, Text, Text, Text>();
+    final Mapper<Text, Text, Text, Text> countingTokenMapper = new TokenMapperWithCounters();
+
+    counterDriver.withMapper(countingMapper);
+    counterDriver.withInput(countingMapper, new Text("hie"), new Text("Hi"));
+    counterDriver.withMapper(countingTokenMapper);
+    counterDriver.withInput(countingTokenMapper, new Text("bie"),
+        new Text("Goodbye Bye"));
+
+    // expected mapper-side counters
+    counterDriver.withCounter(TestMapDriver.MapperWithCounters.Counters.X, 1);
+    counterDriver.withCounter(TokenMapperWithCounters.Counters.Y, 2);
+    counterDriver.withCounter("category", "name", 3);
+
+    counterDriver.withReducer(
+        new TestReduceDriver.ReducerWithCounters<Text, Text, Text, Text>());
+
+    // expected reducer-side counters
+    counterDriver.withCounter(TestReduceDriver.ReducerWithCounters.Counters.COUNT, 2);
+    counterDriver.withCounter(TestReduceDriver.ReducerWithCounters.Counters.SUM, 3);
+    counterDriver.withCounter("category", "count", 2);
+    counterDriver.withCounter("category", "sum", 3);
+
+    counterDriver.runTest(false);
+  }
+
+  /**
+   * With strict counter checking enabled, enum counters that were incremented
+   * during the run but not declared as expected (X from the generic mapper and
+   * Y from the token mapper) must each be reported as an error.
+   */
+  @Test
+  public void testWithCounterAndEnumCounterMissing() throws IOException {
+    MultipleInputsMapReduceDriver<Text, Text, Text, Text> testDriver = new MultipleInputsMapReduceDriver<Text, Text, Text, Text>();
+
+    // one error per undeclared enum counter — two mappers, two errors
+    thrown
+        .expectAssertionErrorMessage("2 Error(s): (Actual counter ("
+            + "\"org.apache.hadoop.mrunit.mapreduce.TestMapDriver$MapperWithCounters$Counters\",\"X\")"
+            + " was not found in expected counters, Actual counter ("
+            + "\"org.apache.hadoop.mrunit.mapreduce.TestMultipleInputsMapReduceDriver$TokenMapperWithCounters$Counters\",\"Y\")"
+            + " was not found in expected counters");
+
+    Mapper<Text, Text, Text, Text> mapperWithCounters = new TestMapDriver.MapperWithCounters<Text, Text, Text, Text>();
+    Mapper<Text, Text, Text, Text> tokenMapperWithCounters = new TokenMapperWithCounters();
+
+    // note: the enum counters X and Y are deliberately NOT expected here
+    testDriver
+        .withMapper(mapperWithCounters)
+        .withInput(mapperWithCounters, new Text("hie"), new Text("Hi"))
+        .withMapper(tokenMapperWithCounters)
+        .withInput(tokenMapperWithCounters, new Text("bie"),
+            new Text("Goodbye Bye"))
+        .withStrictCounterChecking()
+        .withCounter("category", "name", 3)
+        .withReducer(
+            new TestReduceDriver.ReducerWithCounters<Text, Text, Text, Text>())
+        .withCounter(TestReduceDriver.ReducerWithCounters.Counters.COUNT, 2)
+        .withCounter(TestReduceDriver.ReducerWithCounters.Counters.SUM, 3)
+        .withCounter("category", "count", 2).withCounter("category", "sum", 3)
+        .runTest(false);
+  }
+
+  /**
+   * With strict counter checking enabled, a string counter ("category","name")
+   * that was incremented during the run but not declared as expected must be
+   * reported as an error.
+   */
+  @Test
+  public void testWithCounterAndStringCounterMissing() throws IOException {
+    MultipleInputsMapReduceDriver<Text, Text, Text, Text> testDriver = new MultipleInputsMapReduceDriver<Text, Text, Text, Text>();
+
+    thrown.expectAssertionErrorMessage("1 Error(s): (Actual counter ("
+        + "\"category\",\"name\")" + " was not found in expected counters");
+
+    Mapper<Text, Text, Text, Text> mapperWithCounters = new TestMapDriver.MapperWithCounters<Text, Text, Text, Text>();
+    Mapper<Text, Text, Text, Text> tokenMapperWithCounters = new TokenMapperWithCounters();
+
+    // note: the string counter ("category","name") is deliberately NOT expected
+    testDriver
+        .withMapper(mapperWithCounters)
+        .withInput(mapperWithCounters, new Text("hie"), new Text("Hi"))
+        .withMapper(tokenMapperWithCounters)
+        .withInput(tokenMapperWithCounters, new Text("bie"),
+            new Text("Goodbye Bye"))
+        .withStrictCounterChecking()
+        .withCounter(TestMapDriver.MapperWithCounters.Counters.X, 1)
+        .withCounter(TokenMapperWithCounters.Counters.Y, 2)
+        .withReducer(
+            new TestReduceDriver.ReducerWithCounters<Text, Text, Text, Text>())
+        .withCounter(TestReduceDriver.ReducerWithCounters.Counters.COUNT, 2)
+        .withCounter(TestReduceDriver.ReducerWithCounters.Counters.SUM, 3)
+        .withCounter("category", "count", 2).withCounter("category", "sum", 3)
+        .runTest(false);
+  }
+
+  /**
+   * Counters expected with deliberately wrong values (X=20 vs actual 1,
+   * Y=30 vs actual 2, "category"/"name"=20 vs actual 3) must each produce a
+   * value-mismatch error.
+   */
+  @Test
+  public void testWithFailedCounter() throws IOException {
+    MultipleInputsMapReduceDriver<Text, Text, Text, Text> testDriver = new MultipleInputsMapReduceDriver<Text, Text, Text, Text>();
+
+    thrown
+        .expectAssertionErrorMessage("3 Error(s): ("
+            + "Counter org.apache.hadoop.mrunit.mapreduce.TestMapDriver.MapperWithCounters.Counters.X has value 1 instead of expected 20, "
+            + "Counter org.apache.hadoop.mrunit.mapreduce.TestMultipleInputsMapReduceDriver.TokenMapperWithCounters.Counters.Y has value 2 instead of expected 30, "
+            + "Counter with category category and name name has value 3 instead of expected 20)");
+
+    Mapper<Text, Text, Text, Text> mapperWithCounters = new TestMapDriver.MapperWithCounters<Text, Text, Text, Text>();
+    Mapper<Text, Text, Text, Text> tokenMapperWithCounters = new TokenMapperWithCounters();
+
+    testDriver
+        .withMapper(mapperWithCounters)
+        .withInput(mapperWithCounters, new Text("hie"), new Text("Hi"))
+        .withMapper(tokenMapperWithCounters)
+        .withInput(tokenMapperWithCounters, new Text("bie"),
+            new Text("Goodbye Bye"))
+        .withCounter(TestMapDriver.MapperWithCounters.Counters.X, 20)
+        .withCounter(TokenMapperWithCounters.Counters.Y, 30)
+        .withReducer(
+            new TestReduceDriver.ReducerWithCounters<Text, Text, Text, Text>())
+        .withCounter("category", "name", 20).runTest(false);
+  }
+
+  /**
+   * Exercises the driver with Java-serialized (non-Writable) Integer keys by
+   * registering JavaSerialization alongside the configured serializations.
+   */
+  @Test
+  public void testJavaSerialization() throws IOException {
+    final Configuration conf = new Configuration();
+    conf.setStrings("io.serializations", conf.get("io.serializations"),
+        "org.apache.hadoop.io.serializer.JavaSerialization");
+
+    final MultipleInputsMapReduceDriver<Integer, IntWritable, Integer, IntWritable> serializationDriver = MultipleInputsMapReduceDriver
+        .newMultipleInputMapReduceDriver(
+                new Reducer<Integer, IntWritable, Integer, IntWritable>())
+        .withConfiguration(conf);
+
+    final Mapper<Integer, IntWritable, Integer, IntWritable> firstIdentityMapper = new Mapper<Integer, IntWritable, Integer, IntWritable>();
+    final Mapper<Integer, IntWritable, Integer, IntWritable> secondIdentityMapper = new Mapper<Integer, IntWritable, Integer, IntWritable>();
+
+    serializationDriver.addMapper(firstIdentityMapper);
+    serializationDriver.withInput(firstIdentityMapper, 1, new IntWritable(2));
+    serializationDriver.withInput(firstIdentityMapper, 2, new IntWritable(3));
+    serializationDriver.addMapper(secondIdentityMapper);
+    serializationDriver.withInput(secondIdentityMapper, 3, new IntWritable(4));
+    serializationDriver.withInput(secondIdentityMapper, 4, new IntWritable(5));
+
+    // Integer keys have no registered Writable comparator, so both the sort
+    // and the grouping comparator must be supplied explicitly
+    serializationDriver
+        .withKeyOrderComparator(new JavaSerializationComparator<Integer>());
+    serializationDriver
+        .withKeyGroupingComparator(org.apache.hadoop.mrunit.TestMapReduceDriver.INTEGER_COMPARATOR);
+
+    serializationDriver.withOutput(1, new IntWritable(2));
+    serializationDriver.withOutput(2, new IntWritable(3));
+    serializationDriver.withOutput(3, new IntWritable(4));
+    serializationDriver.withOutput(4, new IntWritable(5));
+    serializationDriver.runTest(false);
+  }
+
+  /**
+   * Mutates and reuses the same key/value instances across successive
+   * addInput/addOutput calls; the test only passes if the driver defensively
+   * copies what it is given instead of holding references.
+   */
+  @Test
+  public void testCopy() throws IOException {
+    final Text key = new Text("a");
+    final LongWritable value = new LongWritable(1);
+    driver.addInput(mapper, key, value);
+    // mutate the very same objects before re-adding them
+    key.set("b");
+    value.set(2);
+    driver.addInput(mapper, key, value);
+
+    key.set("a");
+    value.set(1);
+    driver.addOutput(key, value);
+    key.set("b");
+    value.set(2);
+    driver.addOutput(key, value);
+
+    // same reuse pattern for the second (token) input path
+    final LongWritable longKey = new LongWritable(3);
+    final Text textValue = new Text("c d");
+    driver.addInput(tokenMapper, longKey, textValue);
+    longKey.set(4);
+    textValue.set("e f g");
+    driver.addInput(tokenMapper, longKey, textValue);
+
+    // expected outputs: one (token, inputKey) pair per whitespace-split token
+    key.set("c");
+    value.set(3);
+    driver.addOutput(key, value);
+    key.set("d");
+    value.set(3);
+    driver.addOutput(key, value);
+    key.set("e");
+    value.set(4);
+    driver.addOutput(key, value);
+    key.set("f");
+    value.set(4);
+    driver.addOutput(key, value);
+    key.set("g");
+    value.set(4);
+    driver.addOutput(key, value);
+
+    driver.runTest(false);
+  }
+
+  /**
+   * Runs the full pipeline through an explicitly configured sequence-file
+   * input/output format pair.
+   */
+  @Test
+  public void testOutputFormat() throws IOException {
+    driver.withInputFormat(SequenceFileInputFormat.class)
+        .withOutputFormat(SequenceFileOutputFormat.class)
+        .withInput(mapper, new Text("a"), new LongWritable(1))
+        .withInput(mapper, new Text("a"), new LongWritable(2))
+        .withInput(tokenMapper, new LongWritable(3), new Text("a b"))
+        .withOutput(new Text("a"), new LongWritable(6))
+        .withOutput(new Text("b"), new LongWritable(3))
+        .runTest(false);
+  }
+
+  /**
+   * Uses a raw-typed view of the driver so that TextOutputFormat's
+   * LongWritable/Text output records (offset, "key\tvalue") can be asserted
+   * even though they differ from the driver's declared output classes.
+   */
+  @SuppressWarnings({ "unchecked", "rawtypes" })
+  @Test
+  public void testOutputFormatWithMismatchInOutputClasses() throws IOException {
+    final MultipleInputsMapReduceDriver rawDriver = this.driver;
+    rawDriver.withInputFormat(TextInputFormat.class);
+    rawDriver.withOutputFormat(TextOutputFormat.class);
+    rawDriver.withInput(mapper, new Text("a"), new LongWritable(1));
+    rawDriver.withInput(mapper, new Text("a"), new LongWritable(2));
+    rawDriver.withInput(tokenMapper, new LongWritable(3), new Text("a b"));
+    rawDriver.withOutput(new LongWritable(0), new Text("a\t6"));
+    rawDriver.withOutput(new LongWritable(4), new Text("b\t3"));
+    rawDriver.runTest(false);
+  }
+
+  /**
+   * Each registered mapper can be given its own map input path; the
+   * InputPathStoringMapper records the path it observed during the run so it
+   * can be asserted afterwards.
+   */
+  @Test
+  public void testMapInputFile() throws IOException {
+    MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable> testDriver = new MultipleInputsMapReduceDriver<Text, LongWritable, Text, LongWritable>(
+        reducer);
+
+    InputPathStoringMapper<LongWritable, LongWritable> inputPathStoringMapper = new InputPathStoringMapper<LongWritable, LongWritable>();
+    Path mapInputPath = new Path("myfile");
+    testDriver.addMapper(inputPathStoringMapper);
+    testDriver.setMapInputPath(inputPathStoringMapper, mapInputPath);
+    // the driver must return the path registered for this specific mapper
+    assertEquals(mapInputPath.getName(),
+        testDriver.getMapInputPath(inputPathStoringMapper).getName());
+    testDriver.withInput(inputPathStoringMapper, new Text("a"),
+        new LongWritable(1));
+
+    // a second mapper instance gets a different path of its own
+    InputPathStoringMapper<LongWritable, LongWritable> anotherInputPathStoringMapper = new InputPathStoringMapper<LongWritable, LongWritable>();
+    Path anotherMapInputPath = new Path("myotherfile");
+    testDriver.addMapper(anotherInputPathStoringMapper);
+    testDriver.setMapInputPath(anotherInputPathStoringMapper,
+        anotherMapInputPath);
+    assertEquals(anotherMapInputPath.getName(),
+        testDriver.getMapInputPath(anotherInputPathStoringMapper).getName());
+    testDriver.withInput(anotherInputPathStoringMapper, new Text("b"),
+        new LongWritable(2));
+
+    // after the run the first mapper must have seen its own path
+    // (the second mapper's path is not re-checked here)
+    testDriver.runTest(false);
+    assertNotNull(inputPathStoringMapper.getMapInputPath());
+    assertEquals(mapInputPath.getName(), inputPathStoringMapper
+        .getMapInputPath().getName());
+  }
+
+  /**
+   * Groups keys on their first character only: the expected outputs show
+   * A1/A2 and B1/B2 each collapsing into a single summed group, and the token
+   * mapper's D1/D2 likewise merging.
+   */
+  @Test
+  public void testGroupComparatorBehaviorFirst() throws IOException {
+    driver.withKeyGroupingComparator(
+        new org.apache.hadoop.mrunit.TestMapReduceDriver.FirstCharComparator());
+
+    driver.withInput(mapper, new Text("A1"), new LongWritable(1L));
+    driver.withInput(mapper, new Text("A2"), new LongWritable(1L));
+    driver.withInput(mapper, new Text("B1"), new LongWritable(1L));
+    driver.withInput(mapper, new Text("B2"), new LongWritable(1L));
+    driver.withInput(mapper, new Text("C1"), new LongWritable(1L));
+    driver.withInput(tokenMapper, new LongWritable(3L), new Text("D1 D2 E1"));
+
+    driver.withOutput(new Text("A2"), new LongWritable(2L));
+    driver.withOutput(new Text("B2"), new LongWritable(2L));
+    driver.withOutput(new Text("C1"), new LongWritable(1L));
+    driver.withOutput(new Text("D2"), new LongWritable(6L));
+    driver.withOutput(new Text("E1"), new LongWritable(3L));
+
+    driver.runTest(false);
+  }
+
+  /**
+   * Groups keys on their second character: with the default full-key sort
+   * order, the expected outputs show that no merging occurs here — every
+   * input record comes back out as its own group.
+   */
+  @Test
+  public void testGroupComparatorBehaviorSecond() throws IOException {
+    driver
+        .withInput(mapper, new Text("1A"), new LongWritable(1L))
+        .withInput(mapper, new Text("2A"), new LongWritable(1L))
+        .withInput(mapper, new Text("1B"), new LongWritable(1L))
+        .withInput(mapper, new Text("2B"), new LongWritable(1L))
+        .withInput(mapper, new Text("1C"), new LongWritable(1L))
+        .withInput(tokenMapper, new LongWritable(2L), new Text("1D 2D 1E"))
+        .withOutput(new Text("1A"), new LongWritable(1L))
+        .withOutput(new Text("2A"), new LongWritable(1L))
+        .withOutput(new Text("1B"), new LongWritable(1L))
+        .withOutput(new Text("2B"), new LongWritable(1L))
+        .withOutput(new Text("1C"), new LongWritable(1L))
+        .withOutput(new Text("1D"), new LongWritable(2L))
+        .withOutput(new Text("2D"), new LongWritable(2L))
+        .withOutput(new Text("1E"), new LongWritable(2L))
+        .withKeyGroupingComparator(
+            new org.apache.hadoop.mrunit.TestMapReduceDriver.SecondCharComparator())
+        .runTest(false);
+  }
+
+  /**
+   * The grouping comparator may also be supplied through the job
+   * configuration instead of the driver API; the expected groupings are the
+   * same first-character merges as in testGroupComparatorBehaviorFirst.
+   */
+  @Test
+  public void testGroupingComparatorSpecifiedByConf() throws IOException {
+    final JobConf groupingConf = new JobConf(new Configuration());
+    groupingConf
+        .setOutputValueGroupingComparator(TestMapReduceDriver.FirstCharComparator.class);
+    driver.withConfiguration(groupingConf);
+
+    driver.withInput(mapper, new Text("A1"), new LongWritable(1L));
+    driver.withInput(mapper, new Text("A2"), new LongWritable(1L));
+    driver.withInput(mapper, new Text("B1"), new LongWritable(1L));
+    driver.withInput(mapper, new Text("B2"), new LongWritable(1L));
+    driver.withInput(mapper, new Text("C1"), new LongWritable(1L));
+    driver.withInput(tokenMapper, new LongWritable(3L), new Text("D1 D2 E1"));
+
+    driver.withOutput(new Text("A2"), new LongWritable(2L));
+    driver.withOutput(new Text("B2"), new LongWritable(2L));
+    driver.withOutput(new Text("C1"), new LongWritable(1L));
+    driver.withOutput(new Text("D2"), new LongWritable(6L));
+    driver.withOutput(new Text("E1"), new LongWritable(3L));
+
+    driver.runTest(false);
+  }
+
+  /**
+   * Verifies the driver honors a comparator registered for the key's Writable
+   * type. runTest(true) makes output order significant.
+   * NOTE(review): the expected outputs are in reverse input order, which
+   * suggests TestWritable's registered comparator sorts descending — confirm
+   * against the TestWritable source.
+   */
+  @SuppressWarnings("unchecked")
+  @Test
+  public void testUseOfWritableRegisteredComparator() throws IOException {
+    MultipleInputsMapReduceDriver<TestWritable, Text, TestWritable, Text> testDriver = new MultipleInputsMapReduceDriver<TestWritable, Text, TestWritable, Text>(
+        new Reducer<TestWritable, Text, TestWritable, Text>());
+
+    Mapper<TestWritable, Text, TestWritable, Text> identityMapper = new Mapper<TestWritable, Text, TestWritable, Text>();
+    Mapper<TestWritable, Text, TestWritable, Text> anotherIdentityMapper = new Mapper<TestWritable, Text, TestWritable, Text>();
+    testDriver.addMapper(identityMapper);
+    testDriver.addMapper(anotherIdentityMapper);
+    testDriver
+        .withInput(identityMapper, new TestWritable("A1"), new Text("A1"))
+        .withInput(identityMapper, new TestWritable("A2"), new Text("A2"))
+        .withInput(identityMapper, new TestWritable("A3"), new Text("A3"))
+        .withInput(anotherIdentityMapper, new TestWritable("B1"),
+            new Text("B1"))
+        .withInput(anotherIdentityMapper, new TestWritable("B2"),
+            new Text("B2"))
+        .withKeyGroupingComparator(new TestWritable.SingleGroupComparator())
+        .withOutput(new TestWritable("B2"), new Text("B2"))
+        .withOutput(new TestWritable("B1"), new Text("B1"))
+        .withOutput(new TestWritable("A3"), new Text("A3"))
+        .withOutput(new TestWritable("A2"), new Text("A2"))
+        .withOutput(new TestWritable("A1"), new Text("A1")).runTest(true); // ordering
+                                                                           // is
+                                                                           // important
+  }
+
+  /**
+   * Whitespace-tokenizing mapper that emits (key, token) for every token in
+   * the value and, per token, increments both the enum counter
+   * {@link Counters#Y} and the string counter ("category","name").
+   */
+  static class TokenMapperWithCounters extends Mapper<Text, Text, Text, Text> {
+    // reused output instance; reset for every token
+    private final Text output = new Text();
+
+    @Override
+    protected void map(Text key, Text value, Context context)
+        throws IOException, InterruptedException {
+      String[] tokens = value.toString().split("\\s");
+      for (String token : tokens) {
+        output.set(token);
+        context.write(key, output);
+        context.getCounter(Counters.Y).increment(1);
+        context.getCounter("category", "name").increment(1);
+      }
+    }
+
+    public static enum Counters {
+      Y
+    }
+  }
+
+  /**
+   * Inverting tokenizer: splits the incoming line on whitespace and emits
+   * each token as the output key, paired with the record's LongWritable key
+   * as the output value.
+   */
+  static class TokenMapper extends
+      Mapper<LongWritable, Text, Text, LongWritable> {
+    // single reusable output instance, reset per token
+    private final Text output = new Text();
+
+    @Override
+    protected void map(LongWritable key, Text value, Context context)
+        throws IOException, InterruptedException {
+      final String[] pieces = value.toString().split("\\s");
+      for (int i = 0; i < pieces.length; i++) {
+        output.set(pieces[i]);
+        context.write(output, key);
+      }
+    }
+  }
+
+  /**
+   * Mapper that swaps each record's key and value, emitting (value, key).
+   */
+  static class ReverseMapper<KEYIN, VALUEIN> extends
+      Mapper<KEYIN, VALUEIN, VALUEIN, KEYIN> {
+    @Override
+    protected void map(KEYIN key, VALUEIN value, Context context)
+        throws IOException, InterruptedException {
+      context.write(value, key);
+    }
+  }
+}