PIG-5333: LoadCaster sometimes not set for complex type (knoguchi)


git-svn-id: https://svn.apache.org/repos/asf/pig/trunk@1828532 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/CHANGES.txt b/CHANGES.txt
index f7053ad..651fb7c 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -63,6 +63,8 @@
 OPTIMIZATIONS
  
 BUG FIXES
+PIG-5333: LoadCaster sometimes not set for complex type (knoguchi)
+
 PIG-5328: expressionOperator Divide.equalsZero(DataType.BIGDECIMAL) is invalid (michaelthoward via knoguchi)
 
 PIG-5320: TestCubeOperator#testRollupBasic is flaky on Spark 2.2 (nkollar via szita)
diff --git a/src/org/apache/pig/newplan/logical/visitor/LineageFindRelVisitor.java b/src/org/apache/pig/newplan/logical/visitor/LineageFindRelVisitor.java
index 116e681..166a1e7 100644
--- a/src/org/apache/pig/newplan/logical/visitor/LineageFindRelVisitor.java
+++ b/src/org/apache/pig/newplan/logical/visitor/LineageFindRelVisitor.java
@@ -139,7 +139,7 @@
             ){
                 long inpUid = predSchema.getField(innerLoad.getProjection().getColNum()).uid;
                 if(uid2LoadFuncMap.get(inpUid) != null){
-                    addUidLoadFuncToMap(innerLoad.getSchema().getField(0).uid, uid2LoadFuncMap.get(inpUid));
+                    addUidLoadFuncToMap(innerLoad.getSchema().getField(0), uid2LoadFuncMap.get(inpUid));
                 }
                 return;
             }
@@ -165,8 +165,7 @@
             return;
         }
         for(LogicalFieldSchema fs : schema.getFields()){
-            addUidLoadFuncToMap((Long) fs.uid, funcSpec);
-            setLoadFuncForUids(fs.schema, funcSpec);
+            addUidLoadFuncToMap(fs, funcSpec);
         }
         
     }
@@ -291,9 +290,7 @@
         //if the group plans are associated with same load function , associate
         //same load fucntion with group column schema
         if (getAssociatedLoadFunc(group)!=null) {
-            addUidLoadFuncToMap(sch.getField(0).uid, rel2InputFuncMap.get(group));
-            if (sch.getField(0).schema!=null)
-                setLoadFuncForUids(sch.getField(0).schema, rel2InputFuncMap.get(group));
+            addUidLoadFuncToMap(sch.getField(0), rel2InputFuncMap.get(group));
         }
         else
             mapMatchLoadFuncToUid(sch.getField(0), groupPlanSchemas);
@@ -311,10 +308,9 @@
         
         
         for(int i=1; i < sch.size(); i++){
-            long uid = sch.getField(i).uid;
             LogicalRelationalOperator input = (LogicalRelationalOperator) inputs.get(i-1);
             if(getAssociatedLoadFunc(input) != null){
-                addUidLoadFuncToMap(uid, rel2InputFuncMap.get(input));
+                addUidLoadFuncToMap(sch.getField(i), rel2InputFuncMap.get(input));
             }
         }
         
@@ -370,7 +366,7 @@
                 //get its funcspec and associate it with uid of all fields in the schema
                 FuncSpec funcSpec = uid2LoadFuncMap.get(exp.getFieldSchema().uid);
                 for(LogicalFieldSchema fs : sch.getFields()){
-                    addUidLoadFuncToMap(fs.uid, funcSpec);
+                    addUidLoadFuncToMap(fs, funcSpec);
                 }
             }
         }
@@ -461,7 +457,7 @@
                   prevLoadFuncSpec  = curLoadFuncSpec;
                 }
                 if( allSameLoader ) {
-                    addUidLoadFuncToMap(logicalFieldSchema.uid,curLoadFuncSpec);
+                    addUidLoadFuncToMap(logicalFieldSchema,curLoadFuncSpec);
                 }
             }
         }
@@ -502,7 +498,7 @@
             LogicalFieldSchema inField = inputSch.getField(i);
             LogicalFieldSchema outField = outSchema.getField(i);
             if(uid2LoadFuncMap.get(inField.uid) != null){
-                addUidLoadFuncToMap(outField.uid, uid2LoadFuncMap.get(inField.uid));
+                addUidLoadFuncToMap(outField, uid2LoadFuncMap.get(inField.uid));
             }
         }
         
@@ -515,15 +511,17 @@
      * @param loadFuncSpec
      * @throws VisitorException 
      */
-    private void addUidLoadFuncToMap(long uid, FuncSpec loadFuncSpec)
+    private void addUidLoadFuncToMap(LogicalFieldSchema fs, FuncSpec loadFuncSpec)
     throws VisitorException{
         if(loadFuncSpec == null){
             return;
         }
         //ensure that uid always matches to same load func
-        FuncSpec curFuncSpec = uid2LoadFuncMap.get(uid);
+        FuncSpec curFuncSpec = uid2LoadFuncMap.get(fs.uid);
         if(curFuncSpec == null){
-            uid2LoadFuncMap.put(uid, loadFuncSpec);
+            uid2LoadFuncMap.put(fs.uid, loadFuncSpec);
+            // if field is tuple/map/bag, also set the internal fields
+            setLoadFuncForUids(fs.schema, loadFuncSpec);
         }else if(! haveIdenticalCasters(curFuncSpec,loadFuncSpec)){
             String msg = "Bug: uid mapped to two different load functions : " +
             curFuncSpec + " and " + loadFuncSpec;
@@ -574,7 +572,7 @@
                 }
             }
             if(allMatch){
-                addUidLoadFuncToMap(outFS.uid, funcSpec1);
+                addUidLoadFuncToMap(outFS, funcSpec1);
             }
         }
         
@@ -624,7 +622,7 @@
             if(uid2LoadFuncMap.get(uid) == null && (inputRel.getSchema() == null || inputRel instanceof LOInnerLoad)){
                 FuncSpec funcSpec = rel2InputFuncMap.get(inputRel);
                 if(funcSpec != null){
-                    addUidLoadFuncToMap(uid, funcSpec);
+                    addUidLoadFuncToMap(proj.getFieldSchema(), funcSpec);
                 }
             }
         }
@@ -645,7 +643,7 @@
             //find input uid and corresponding load FuncSpec
             long inpUid = inp.getFieldSchema().uid;
             FuncSpec inpLoadFuncSpec = uid2LoadFuncMap.get(inpUid);
-            addUidLoadFuncToMap(exp.getFieldSchema().uid, inpLoadFuncSpec);
+            addUidLoadFuncToMap(exp.getFieldSchema(), inpLoadFuncSpec);
 
         }
 
@@ -689,7 +687,7 @@
             Integer inputColNum = (Integer)((ConstantExpression) scalarExp.getArguments().get(0)).getValue();
             String inputFile = (String)((ConstantExpression) scalarExp.getArguments().get(1)).getValue();
             
-            long outputUid = scalarExp.getFieldSchema().uid;
+            LogicalFieldSchema outputFS = scalarExp.getFieldSchema();
             boolean foundInput = false; // a variable to do sanity check on num of input relations
 
             //find the input relation, and use it to get lineage
@@ -711,12 +709,12 @@
                     if(sch == null){
                         //see if there is a load function associated with the store
                         FuncSpec funcSpec = rel2InputFuncMap.get(inputStore);
-                        addUidLoadFuncToMap(outputUid, funcSpec);
+                        addUidLoadFuncToMap(outputFS, funcSpec);
                     }else{
                         //find input uid and corresponding load func
                         LogicalFieldSchema fs = sch.getField(inputColNum);
                         FuncSpec funcSpec = uid2LoadFuncMap.get(fs.uid);
-                        addUidLoadFuncToMap(outputUid, funcSpec);
+                        addUidLoadFuncToMap(outputFS, funcSpec);
                     }
                 }
             }
@@ -762,9 +760,7 @@
             }
 
             if( funcSpec != null ) {
-                addUidLoadFuncToMap(op.getFieldSchema().uid, funcSpec);
-                // in case schema is nested, set funcSpec for all
-                setLoadFuncForUids(op.getFieldSchema().schema, funcSpec);
+                addUidLoadFuncToMap(op.getFieldSchema(), funcSpec);
             }
         }
 
diff --git a/test/org/apache/pig/test/TestTypeCheckingValidatorNewLP.java b/test/org/apache/pig/test/TestTypeCheckingValidatorNewLP.java
index 78371c3..61d2ba6 100644
--- a/test/org/apache/pig/test/TestTypeCheckingValidatorNewLP.java
+++ b/test/org/apache/pig/test/TestTypeCheckingValidatorNewLP.java
@@ -4174,4 +4174,39 @@
                 + "c = foreach b generate (gpa>3? b1 : b2);";
                 createAndProcessLPlan(query);
         }
+        @Test
+        public void testSplitLineageWithInnerFields() throws Throwable {
+            String query = "A = load 'a' as (field1, field2);"
+            + " B = group A by (field1,field2);"
+            + "split B into C if SIZE(A) > 2, Z otherwise;"
+            + "D = FOREACH C { "
+            + "      D1 = FOREACH A generate (chararray) field1;"
+            + "      GENERATE D1;"
+            + "}";
+            LOForEach outerForeach = getForeachFromPlan(query);
+            LogicalPlan innerPlan = outerForeach.getInnerPlan();
+            LOForEach innerForeach = (LOForEach)innerPlan.getPredecessors(innerPlan.getSinks().get(0)).get(0); // first predecessor of the outer LOGenerate sink, i.e. the nested D1 foreach
+            LogicalExpressionPlan innerForeachPlan = ((LOGenerate)innerForeach.getInnerPlan().getSinks().get(0)).getOutputPlans().get(0);
+            CastExpression cast = getCastFromExpPlan(innerForeachPlan);
+            checkCastLoadFunc(cast, "org.apache.pig.builtin.PigStorage"); // presumably asserts the cast's load FuncSpec is PigStorage (helper not shown here)
+            /* Plan shape navigated above (PIG-5333): the Cast marked below, reached through group + split + nested foreach, is the one checked against PigStorage.
+              D: (Name: LOForEach Schema: D1#755:bag{#754:tuple(field1#750:chararray)})           =====> [outerForeach]
+              |   |
+              |   (Name: LOGenerate[false] Schema: D1#755:bag{#754:tuple(field1#750:chararray)})    ====> innerPlan.getSinks().get(0)
+              |   |   |
+              |   |   D1:(Name: Project Type: bag Uid: 755 Input: 0 Column: (*))
+              |   |
+              |   |---D1: (Name: LOForEach Schema: field1#750:chararray)                               =====> [innerForeach]
+              |       |   |
+              |       |   (Name: LOGenerate[false] Schema: field1#750:chararray)
+              |       |   |   |
+              |       |   |   (Name: Cast Type: chararray Uid: 750)                      <========CHECKING HERE
+              |       |   |   |
+              |       |   |   |---field1:(Name: Project Type: bytearray Uid: 750 Input: 0 Column: (*))
+              |       |   |
+              |       |   |---(Name: LOInnerLoad[field1] Schema: field1#750:bytearray)
+              |       |
+              |       |---A: (Name: LOInnerLoad[A] Schema: field1#750:bytearray)
+            */
+        }
 }