PIG-4883: MapKeyType of splitter was set wrongly in specific multiquery case (kellyzly via rohini)

git-svn-id: https://svn.apache.org/repos/asf/pig/trunk@1743708 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/CHANGES.txt b/CHANGES.txt
index 8467c41..7af0f23 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -123,6 +123,8 @@
 
 BUG FIXES
 
+PIG-4883: MapKeyType of splitter was set wrongly in specific multiquery case (kellyzly via rohini)
+
 PIG-4887: Parameter substitution skipped with glob on register (knoguchi)
 
 PIG-4889: Replacing backslash fails as lexical error (knoguchi)
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/mapReduceLayer/MapReduceOper.java b/src/org/apache/pig/backend/hadoop/executionengine/mapReduceLayer/MapReduceOper.java
index b5611db..b0a714f 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/mapReduceLayer/MapReduceOper.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/mapReduceLayer/MapReduceOper.java
@@ -65,7 +65,10 @@
     // this is needed when the key is null to create
     // an appropriate NullableXXXWritable object
     public byte mapKeyType;
-    
+
+    // Records the map key type of each splittee that was merged into this splitter
+    public byte[] mapKeyTypeOfSplittees;
+
     //Indicates that the map plan creation
     //is complete
     boolean mapDone = false;
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/mapReduceLayer/MultiQueryOptimizer.java b/src/org/apache/pig/backend/hadoop/executionengine/mapReduceLayer/MultiQueryOptimizer.java
index 5331766..58e9da4 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/mapReduceLayer/MultiQueryOptimizer.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/mapReduceLayer/MultiQueryOptimizer.java
@@ -18,6 +18,7 @@
 package org.apache.pig.backend.hadoop.executionengine.mapReduceLayer;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -580,18 +581,17 @@
     }
 
     private boolean hasSameMapKeyType(List<MapReduceOper> splittees) {
-        boolean sameKeyType = true;
-        for (MapReduceOper outer : splittees) {
-            for (MapReduceOper inner : splittees) {
-                if (inner.mapKeyType != outer.mapKeyType) {
-                    sameKeyType = false;
-                    break;
+        Set<Byte> keyTypes = new HashSet<Byte>();
+        for (MapReduceOper splittee : splittees) {
+            keyTypes.add(splittee.mapKeyType);
+            if (splittee.mapKeyTypeOfSplittees != null) {
+                for (int i = 0; i < splittee.mapKeyTypeOfSplittees.length; i++) {
+                    keyTypes.add(splittee.mapKeyTypeOfSplittees[i]);
                 }
             }
-            if (!sameKeyType) break;
-        }
 
-        return sameKeyType;
+        }
+        return keyTypes.size() == 1;
     }
 
     private int setIndexOnLRInSplit(int initial, POSplit splitOp, boolean sameKeyType)
@@ -1035,10 +1035,20 @@
         splitter.mapKeyType = sameKeyType ?
                 mergeList.get(0).mapKeyType : DataType.TUPLE;
 
+
+        setMapKeyTypeForSplitter(splitter,mergeList);
+
         log.info("Requested parallelism of splitter: "
                 + splitter.getRequestedParallelism());
     }
 
+    private void setMapKeyTypeForSplitter(MapReduceOper splitter, List<MapReduceOper> mergeList) {
+        splitter.mapKeyTypeOfSplittees = new byte[mergeList.size()];
+        for (int i = 0; i < mergeList.size(); i++) {
+            splitter.mapKeyTypeOfSplittees[i] = mergeList.get(i).mapKeyType;
+        }
+    }
+
     private void mergeSingleMapReduceSplittee(MapReduceOper mapReduce,
             MapReduceOper splitter, POSplit splitOp) throws VisitorException {
 
diff --git a/test/org/apache/pig/test/TestMultiQuery.java b/test/org/apache/pig/test/TestMultiQuery.java
index 40684b4..c32eab7 100644
--- a/test/org/apache/pig/test/TestMultiQuery.java
+++ b/test/org/apache/pig/test/TestMultiQuery.java
@@ -907,6 +907,49 @@
         Util.checkQueryOutputsAfterSort(actualResults.iterator(), expectedResults);
     }
 
+    @Test
+    public void testMultiQueryJiraPig4883() throws Exception {
+        Storage.Data data = Storage.resetData(myPig);
+        data.set("inputLocation",
+                Storage.tuple("c", "12"), Storage.tuple("d", "-12"));
+        myPig.setBatchOn();
+        myPig.registerQuery("A = load 'inputLocation' using mock.Storage();");
+        myPig.registerQuery("A = foreach A generate (chararray)$0 as id, (long)$1 as val;");
+        myPig.registerQuery("B = filter A by val > 0;");
+        myPig.registerQuery("B1 = group B by val;");
+        myPig.registerQuery("B1 = foreach B1 generate group as name, COUNT(B) as value;");
+        myPig.registerQuery("B1 = foreach B1 generate (chararray)name,value;");
+        myPig.registerQuery("store B1 into 'output1' using mock.Storage();");
+        myPig.registerQuery("B2 = group B by id;");
+        myPig.registerQuery("B2 = foreach B2 generate group as name, COUNT(B) as value;");
+        myPig.registerQuery("store B2 into 'output2' using mock.Storage();");
+        myPig.registerQuery("C = filter A by val < 0;");
+        myPig.registerQuery("C1 = group C by val;");
+        myPig.registerQuery("C1 = foreach C1 generate group as name, COUNT(C) as value;");
+        myPig.registerQuery("store C1 into 'output3' using mock.Storage();");
+        myPig.registerQuery("C2 = group C by id;");
+        myPig.registerQuery("C2 = foreach C2 generate group as name, COUNT(C) as value;");
+        myPig.registerQuery("store C2 into 'output4' using mock.Storage();");
+        myPig.executeBatch();
+
+        List<Tuple> actualResults = data.get("output1");
+        String[] expectedResults = new String[]{"(12, 1)"};
+        Util.checkQueryOutputsAfterSortRecursive(actualResults.iterator(), expectedResults, org.apache.pig.newplan.logical.Util.translateSchema(myPig.dumpSchema("B1")));
+
+
+        actualResults = data.get("output2");
+        expectedResults = new String[]{"(c,1)"};
+        Util.checkQueryOutputsAfterSortRecursive(actualResults.iterator(), expectedResults, org.apache.pig.newplan.logical.Util.translateSchema(myPig.dumpSchema("B2")));
+
+        actualResults = data.get("output3");
+        expectedResults = new String[]{"(-12, 1)"};
+        Util.checkQueryOutputsAfterSortRecursive(actualResults.iterator(), expectedResults, org.apache.pig.newplan.logical.Util.translateSchema(myPig.dumpSchema("C1")));
+
+        actualResults = data.get("output4");
+        expectedResults = new String[]{"(d,1)"};
+        Util.checkQueryOutputsAfterSortRecursive(actualResults.iterator(), expectedResults, org.apache.pig.newplan.logical.Util.translateSchema(myPig.dumpSchema("C2")));
+    }
+
     // --------------------------------------------------------------------------
     // Helper methods