PIG-3938: Add LoadCaster to EvalFunc (knoguchi)


git-svn-id: https://svn.apache.org/repos/asf/pig/trunk@1779397 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/CHANGES.txt b/CHANGES.txt
index 55d9fc5..dcef62e 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -34,6 +34,8 @@
  
 IMPROVEMENTS
 
+PIG-3938: Add LoadCaster to EvalFunc (knoguchi)
+
 PIG-5105: Tez unit tests failing with "Argument list too long" (rohini)
 
 PIG-4901: To use Multistorage for each Group (szita via daijy)
diff --git a/src/docs/src/documentation/content/xdocs/basic.xml b/src/docs/src/documentation/content/xdocs/basic.xml
index f4066a6..b3a12c6 100644
--- a/src/docs/src/documentation/content/xdocs/basic.xml
+++ b/src/docs/src/documentation/content/xdocs/basic.xml
@@ -8771,8 +8771,8 @@
    </ul>
    
    
-   <section>
-   <title>About Input and Output</title>
+   <section id="pig-streaming-input-output">
+   <title>About Input and Output for Streaming</title>
    <p>Serialization is needed to convert data from tuples to a format that can be processed by the streaming application. Deserialization is needed to convert the output from the streaming application back into tuples. PigStreaming is the default serialization/deserialization function.</p>
    
 <p>Streaming uses the same default format as PigStorage to serialize/deserialize the data. If you want to explicitly specify a format, you can do it as show below (see more examples in the Examples: Input/Output section).  </p> 
@@ -8803,7 +8803,7 @@
     public Tuple deserialize(byte[]) throws IOException;
 
     /**
-     * This will be called on the front end during planning and not on the back
+     * This will be called on both the front end and the back
      * end during execution.
      *
      * @return the {@link LoadCaster} associated with this object.
diff --git a/src/docs/src/documentation/content/xdocs/udf.xml b/src/docs/src/documentation/content/xdocs/udf.xml
index fb7f836..790ac00 100644
--- a/src/docs/src/documentation/content/xdocs/udf.xml
+++ b/src/docs/src/documentation/content/xdocs/udf.xml
@@ -976,6 +976,14 @@
 }
 </source>
 </section>
+<section id="udf-loadcaster">
+<title>Typecasting from bytearrays</title>
+<p>Just like <a href="#load-functions">Load Function</a> and <a href="basic.html#pig-streaming-input-output">Streaming</a>,
+Java UDF has a getLoadCaster() method that returns
+<a href="http://svn.apache.org/viewvc/pig/trunk/src/org/apache/pig/LoadCaster.java?view=markup">LoadCaster</a>
+to convert byte arrays to specific types. A UDF implementation should override this method if casts (implicit or explicit) from DataByteArray fields to other types need to be supported. The default implementation returns null, in which case Pig checks whether all parameters passed to the UDF have an identical LoadCaster and, if so, uses that LoadCaster. </p>
+</section>
+
 <section id="tez-jvm-reuse">
         <title>Clean up static variable in Tez</title>
         <p>In Tez, jvm could reuse for other tasks. It is important to cleanup static variable to make sure there is no side effect. Here is one example:</p>
diff --git a/src/org/apache/pig/EvalFunc.java b/src/org/apache/pig/EvalFunc.java
index 58640b0..fd139a8 100644
--- a/src/org/apache/pig/EvalFunc.java
+++ b/src/org/apache/pig/EvalFunc.java
@@ -369,4 +369,17 @@
 
     public void setEndOfAllInput(boolean endOfAllInput) {
     }
+
+    /**
+     * This will be called on both the front end and the back
+     * end during execution.
+     * @return the {@link LoadCaster} associated with this eval. Returning null
+     * indicates that casts from bytearray will pick the one associated with the
+     * parameters when they all come from the same loadcaster type.
+     * @throws IOException if there is an exception while obtaining the LoadCaster
+     */
+    public LoadCaster getLoadCaster() throws IOException {
+        return null;
+    }
+
 }
diff --git a/src/org/apache/pig/LoadFunc.java b/src/org/apache/pig/LoadFunc.java
index db429ba..c262bad 100644
--- a/src/org/apache/pig/LoadFunc.java
+++ b/src/org/apache/pig/LoadFunc.java
@@ -108,7 +108,7 @@
     public abstract InputFormat getInputFormat() throws IOException;
 
     /**
-     * This will be called on the front end during planning and not on the back 
+     * This will be called on both the front end and the back
      * end during execution.
      * @return the {@link LoadCaster} associated with this loader. Returning null 
      * indicates that casts from byte array are not supported for this loader. 
diff --git a/src/org/apache/pig/StreamToPig.java b/src/org/apache/pig/StreamToPig.java
index f328d27..4366c0c 100644
--- a/src/org/apache/pig/StreamToPig.java
+++ b/src/org/apache/pig/StreamToPig.java
@@ -57,7 +57,7 @@
     public Tuple deserialize(byte[] bytes) throws IOException;
 
     /**
-     * This will be called on the front end during planning and not on the back
+     * This will be called on both the front end and the back
      * end during execution.
      *
      * @return the {@link LoadCaster} associated with this object, or null if
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/physicalLayer/expressionOperators/POCast.java b/src/org/apache/pig/backend/hadoop/executionengine/physicalLayer/expressionOperators/POCast.java
index 02ecf79..a4abd6e 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/physicalLayer/expressionOperators/POCast.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/physicalLayer/expressionOperators/POCast.java
@@ -28,6 +28,7 @@
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.pig.EvalFunc;
 import org.apache.pig.FuncSpec;
 import org.apache.pig.LoadCaster;
 import org.apache.pig.LoadFunc;
@@ -89,6 +90,8 @@
                 caster = ((LoadFunc)obj).getLoadCaster();
             } else if (obj instanceof StreamToPig) {
                 caster = ((StreamToPig)obj).getLoadCaster();
+            } else if (obj instanceof EvalFunc) {
+                caster = ((EvalFunc)obj).getLoadCaster();
             } else {
                 throw new IOException("Invalid class type "
                         + funcSpec.getClassName());
diff --git a/src/org/apache/pig/newplan/logical/visitor/LineageFindRelVisitor.java b/src/org/apache/pig/newplan/logical/visitor/LineageFindRelVisitor.java
index ec86129..116e681 100644
--- a/src/org/apache/pig/newplan/logical/visitor/LineageFindRelVisitor.java
+++ b/src/org/apache/pig/newplan/logical/visitor/LineageFindRelVisitor.java
@@ -24,6 +24,7 @@
 import java.util.Map;
 import java.util.Set;
 
+import org.apache.pig.EvalFunc;
 import org.apache.pig.FuncSpec;
 import org.apache.pig.LoadCaster;
 import org.apache.pig.LoadFunc;
@@ -731,10 +732,39 @@
 
         @Override
         public void visit(UserFuncExpression op) throws FrontendException {
-            if(op.getFuncSpec().getClassName().equals(IdentityColumn.class.getName())) {
-                // IdentityColumn only expects one arg
-                FuncSpec funcSpec = uid2LoadFuncMap.get(op.getArguments().get(0).getFieldSchema().uid);
+
+            if( op.getFieldSchema() == null ) {
+                return;
+            }
+
+            FuncSpec funcSpec = null;
+            Class loader = instantiateCaster(op.getFuncSpec());
+            List<LogicalExpression> arguments = op.getArguments();
+            if ( loader != null ) {
+                // if evalFunc.getLoadCaster() returns a non-null caster, simply use that.
+                funcSpec = op.getFuncSpec();
+            } else if (arguments.size() != 0 ) {
+                FuncSpec baseFuncSpec = null;
+                LogicalFieldSchema fs = arguments.get(0).getFieldSchema();
+                if ( fs != null ) {
+                    baseFuncSpec = uid2LoadFuncMap.get(fs.uid);
+                    if( baseFuncSpec != null ) {
+                        funcSpec = baseFuncSpec;
+                        for(int i = 1; i < arguments.size(); i++) {
+                            fs = arguments.get(i).getFieldSchema();
+                            if( fs == null || !haveIdenticalCasters(baseFuncSpec, uid2LoadFuncMap.get(fs.uid)) ) {
+                                funcSpec = null;
+                                break;
+                            }
+                        }
+                    }
+                }
+            }
+
+            if( funcSpec != null ) {
                 addUidLoadFuncToMap(op.getFieldSchema().uid, funcSpec);
+                // in case schema is nested, set funcSpec for all
+                setLoadFuncForUids(op.getFieldSchema().schema, funcSpec);
             }
         }
 
@@ -779,6 +809,8 @@
                 caster = ((LoadFunc)obj).getLoadCaster();
             } else if (obj instanceof StreamToPig) {
                 caster = ((StreamToPig)obj).getLoadCaster();
+            } else if (obj instanceof EvalFunc) {
+                caster = ((EvalFunc)obj).getLoadCaster();
             } else {
                 throw new VisitorException("Invalid class type " + funcSpec.getClassName(),
                                            2270, PigException.BUG );
diff --git a/test/e2e/pig/tests/negative.conf b/test/e2e/pig/tests/negative.conf
index d64a00e..9a92528 100644
--- a/test/e2e/pig/tests/negative.conf
+++ b/test/e2e/pig/tests/negative.conf
@@ -568,24 +568,7 @@
                         'expected_err_regex' => "Could not resolve StringStoreBad using imports",
 			},
 		]
-		},
-		{
-		'name' => 'LineageErrors',
-		'tests' => [
-			{
-			# UDF returns a bytearray that is cast to an integer
-                'num' => 1,
-                'pig' => q\register :FUNCPATH:/testudf.jar;
-a = load ':INPATH:/singlefile/studenttab10k' as (name, age, gpa);
-b = filter a by name lt 'b';
-c = foreach b generate org.apache.pig.test.udf.evalfunc.CreateMap((chararray)name, age);
-d = foreach c generate $0#'alice young' as field_bytearray;
-split d into e if field_bytearray < 42, f if field_bytearray >= 42;
-store e into ':OUTPATH:';\,
-                'expected_err_regex' => "Received a bytearray from the UDF or Union from two different Loaders. Cannot determine how to convert the bytearray to int for \\[field_bytearray\\[6,",
-            },
-        ]
-        }
+		}
     ]
 }
 ;
diff --git a/test/org/apache/pig/test/TestEvalPipeline.java b/test/org/apache/pig/test/TestEvalPipeline.java
index 9efde13..48ece69 100644
--- a/test/org/apache/pig/test/TestEvalPipeline.java
+++ b/test/org/apache/pig/test/TestEvalPipeline.java
@@ -290,7 +290,7 @@
             myMap.put("long", new Long(1));
             myMap.put("float", new Float(1.0));
             myMap.put("double", new Double(1.0));
-            myMap.put("dba", new DataByteArray(new String("bytes").getBytes()));
+            myMap.put("dba", new DataByteArray(new String("1234").getBytes()));
             myMap.put("map", mapInMap);
             myMap.put("tuple", tuple);
             myMap.put("bag", bag);
@@ -771,32 +771,31 @@
     }
 
     @Test
-    public void testMapUDFfail() throws Exception{
+    public void testMapUDFWithImplicitTypeCast() throws Exception{
         int LOOP_COUNT = 2;
         File tmpFile = Util.createTempFileDelOnExit("test", "txt");
         PrintStream ps = new PrintStream(new FileOutputStream(tmpFile));
         for(int i = 0; i < LOOP_COUNT; i++) {
-            for(int j=0;j<LOOP_COUNT;j+=2){
-                ps.println(i+"\t"+j);
-                ps.println(i+"\t"+j);
-            }
+            ps.println(i);
         }
         ps.close();
 
         pigServer.registerQuery("A = LOAD '"
                 + Util.generateURI(tmpFile.toString(), pigContext) + "';");
         pigServer.registerQuery("B = foreach A generate " + MapUDF.class.getName() + "($0) as mymap;"); //the argument does not matter
-        String query = "C = foreach B {"
-        + "generate mymap#'dba' * 10;"
-        + "};";
+        String query = "C = foreach B generate mymap#'dba' * 10; ";
 
         pigServer.registerQuery(query);
-        try {
-            pigServer.openIterator("C");
-            Assert.fail("Error expected.");
-        } catch (Exception e) {
-            e.getMessage().contains("Cannot determine");
+
+        Iterator<Tuple> iter = pigServer.openIterator("C");
+        if(!iter.hasNext()) Assert.fail("No output found");
+        int numIdentity = 0;
+        while(iter.hasNext()){
+            Tuple t = iter.next();
+            Assert.assertEquals(new Integer(12340), (Integer)t.get(0));
+            ++numIdentity;
         }
+        Assert.assertEquals(LOOP_COUNT, numIdentity);
     }
 
     @Test
diff --git a/test/org/apache/pig/test/TestLineageFindRelVisitor.java b/test/org/apache/pig/test/TestLineageFindRelVisitor.java
index 2b1f388..e8e6aeb 100644
--- a/test/org/apache/pig/test/TestLineageFindRelVisitor.java
+++ b/test/org/apache/pig/test/TestLineageFindRelVisitor.java
@@ -20,10 +20,14 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import junit.framework.Assert;
 
+import java.io.File;
 import java.io.IOException;
 import java.lang.reflect.Method;
+import java.util.Iterator;
+import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.pig.FuncSpec;
@@ -35,6 +39,15 @@
 import org.apache.pig.data.DataByteArray;
 import org.apache.pig.data.Tuple;
 import org.apache.pig.impl.io.FileSpec;
+import org.apache.pig.impl.logicalLayer.FrontendException;
+import org.apache.pig.newplan.DependencyOrderWalker;
+import org.apache.pig.newplan.OperatorPlan;
+import org.apache.pig.newplan.PlanWalker;
+import org.apache.pig.newplan.ReverseDependencyOrderWalker;
+import org.apache.pig.newplan.logical.expression.CastExpression;
+import org.apache.pig.newplan.logical.expression.LogicalExpressionPlan;
+import org.apache.pig.newplan.logical.expression.LogicalExpressionVisitor;
+import org.apache.pig.newplan.logical.optimizer.AllExpressionVisitor;
 import org.apache.pig.newplan.logical.relational.LOLoad;
 import org.apache.pig.newplan.logical.relational.LogicalPlan;
 import org.apache.pig.newplan.logical.visitor.LineageFindRelVisitor;
@@ -49,7 +62,8 @@
 
     private PigServer pig ;
 
-    public TestLineageFindRelVisitor() throws Throwable {
+    @Before
+    public void setUp() throws Exception{
         pig = new PigServer(Util.getLocalTestMode()) ;
     }
 
@@ -80,6 +94,13 @@
         }
     }
 
+    public static class ToTupleWithCustomLoadCaster extends org.apache.pig.builtin.TOTUPLE {
+        @Override
+        public LoadCaster getLoadCaster() throws IOException {
+            return new SillyLoadCasterWithExtraConstructor("ignored");
+        }
+    }
+
     @Test
     public void testhaveIdenticalCasters() throws Exception {
         LogicalPlan lp = new LogicalPlan();
@@ -134,7 +155,7 @@
                            (Boolean) testMethod.invoke(lineageFindRelVisitor,
                                      casterWithExtraConstuctorSpec, casterWithExtraConstuctorSpec) );
 
-        Assert.assertEquals("Loader should be instantiated at most once.", SillyLoaderWithLoadCasterWithExtraConstructor.counter, 1);
+        Assert.assertEquals("Loader should be instantiated at most once.", 1, SillyLoaderWithLoadCasterWithExtraConstructor.counter);
     }
 
     @Test
@@ -164,4 +185,139 @@
                 new String[] {"('aaa', 'bbb')"});
         Util.checkQueryOutputs(actualResults.iterator(), expectedResults);
     }
+
+    @Test
+    public void testUDFForwardingLoadCaster() throws Exception {
+        Storage.Data data = Storage.resetData(pig);
+        data.set("input",
+                Storage.tuple(new DataByteArray("aaa")),
+                Storage.tuple(new DataByteArray("bbb")));
+        pig.setBatchOn();
+        String query = "A = load 'input' using mock.Storage() as (a1:bytearray);"
+            + "B = foreach A GENERATE TOTUPLE(a1) as tupleA;"
+            + "C = foreach B GENERATE (chararray) tupleA.a1;"  //using loadcaster
+            + "store C into 'output' using mock.Storage();";
+
+        LogicalPlan lp = Util.parse(query, pig.getPigContext());
+        Util.optimizeNewLP(lp);
+
+        CastFinder cf = new CastFinder(lp);
+        cf.visit();
+        Assert.assertEquals("There should be only one typecast expression.", 1, cf.casts.size());
+        Assert.assertEquals("Loadcaster should be coming from the Load", "mock.Storage", cf.casts.get(0).getFuncSpec().getClassName());
+
+        pig.registerQuery(query);
+        pig.executeBatch();
+
+        List<Tuple> actualResults = data.get("output");
+        List<Tuple> expectedResults = Util.getTuplesFromConstantTupleStrings(
+                new String[] {"('aaa')", "('bbb')"});
+        Util.checkQueryOutputs(actualResults.iterator(), expectedResults);
+    }
+
+    @Test
+    public void testUDFgetLoadCaster() throws Exception {
+        Storage.Data data = Storage.resetData(pig);
+        data.set("input",
+                Storage.tuple(new DataByteArray("aaa")),
+                Storage.tuple(new DataByteArray("bbb")));
+        pig.setBatchOn();
+        String query = "A = load 'input' using mock.Storage() as (a1:bytearray);"
+            + "B = foreach A GENERATE org.apache.pig.test.TestLineageFindRelVisitor$ToTupleWithCustomLoadCaster(a1) as tupleA;"
+            + "C = foreach B GENERATE (chararray) tupleA.a1;" //using loadcaster
+            + "store C into 'output' using mock.Storage();";
+
+        pig.registerQuery(query);
+        pig.executeBatch();
+
+        LogicalPlan lp = Util.parse(query, pig.getPigContext());
+        Util.optimizeNewLP(lp);
+
+        CastFinder cf = new CastFinder(lp);
+        cf.visit();
+        Assert.assertEquals("There should be only one typecast expression.", 1, cf.casts.size());
+        Assert.assertEquals("Loadcaster should be coming from the UDF", "org.apache.pig.test.TestLineageFindRelVisitor$ToTupleWithCustomLoadCaster", cf.casts.get(0).getFuncSpec().getClassName());
+
+        List<Tuple> actualResults = data.get("output");
+        List<Tuple> expectedResults = Util.getTuplesFromConstantTupleStrings(
+                new String[] {"('aaa')", "('bbb')"});
+        Util.checkQueryOutputs(actualResults.iterator(), expectedResults);
+    }
+
+    @Test
+    public void testUDFForwardingLoadCasterWithMultipleParams() throws Exception{
+        File inputfile = Util.createFile(new String[]{"123","456","789"});
+
+        pig.registerQuery("A = load '"
+                + inputfile.toString()
+                + "' using PigStorage() as (a1:bytearray);\n");
+        pig.registerQuery("B = load '"
+                + inputfile.toString()
+                + "' using PigStorage() as (b1:bytearray);\n");
+        pig.registerQuery("C = join A by a1, B by b1;\n");
+        pig.registerQuery("D = FOREACH C GENERATE TOTUPLE(a1,b1) as tupleD;\n");
+        pig.registerQuery("E = FOREACH D GENERATE (chararray) tupleD.a1;\n");
+        Iterator<Tuple> iter  = pig.openIterator("E");
+
+        Assert.assertEquals("123", iter.next().get(0));
+        Assert.assertEquals("456", iter.next().get(0));
+        Assert.assertEquals("789", iter.next().get(0));
+    }
+
+    @Test
+    public void testNegativeUDFForwardingLoadCasterWithMultipleParams() throws Exception {
+        File inputfile = Util.createFile(new String[]{"123","456","789"});
+
+        pig.registerQuery("A = load '"
+                + inputfile.toString()
+                + "' using PigStorage() as (a1:bytearray);\n");
+        pig.registerQuery("B = load '"
+                + inputfile.toString()
+                + "' using org.apache.pig.test.TestLineageFindRelVisitor$SillyLoaderWithLoadCasterWithExtraConstructor2() as (b1:bytearray);\n");
+        pig.registerQuery("C = join A by a1, B by b1;\n");
+        pig.registerQuery("D = FOREACH C GENERATE TOTUPLE(a1,b1) as tupleD;\n");
+        pig.registerQuery("E = FOREACH D GENERATE (chararray) tupleD.a1;\n");
+        try {
+            Iterator<Tuple> iter  = pig.openIterator("E");
+
+            // this should fail since above typecast cannot determine which
+            // loadcaster to use (one from PigStorage and another from
+            // SillyLoaderWithLoadCasterWithExtraConstructor2)
+            fail("Above typecast should fail since it cannot determine which loadcaster to use.");
+        } catch (IOException e) {
+            Assert.assertTrue(e.getMessage().contains("Unable to open iterator for alias E"));
+        }
+
+
+    }
+
+    /**
+     * Find all casts in the plan (Copied from TestTypeCheckingValidatorNewLP.java)
+     */
+    class CastFinder extends AllExpressionVisitor {
+        List<CastExpression> casts = new ArrayList<CastExpression>();
+
+        public CastFinder(OperatorPlan plan)
+                throws FrontendException {
+            super(plan, new DependencyOrderWalker(plan));
+        }
+
+        @Override
+        protected LogicalExpressionVisitor getVisitor(
+                LogicalExpressionPlan exprPlan) throws FrontendException {
+            return new CastExpFinder(exprPlan, new ReverseDependencyOrderWalker(exprPlan));
+        }
+
+        class CastExpFinder extends LogicalExpressionVisitor{
+            protected CastExpFinder(OperatorPlan p, PlanWalker walker)
+            throws FrontendException {
+                super(p, walker);
+            }
+
+            @Override
+            public void visit(CastExpression cExp){
+                casts.add(cExp);
+            }
+        }
+    }
 }
diff --git a/test/org/apache/pig/test/TestTypeCheckingValidatorNewLP.java b/test/org/apache/pig/test/TestTypeCheckingValidatorNewLP.java
index ba87cba..6f76aeb 100644
--- a/test/org/apache/pig/test/TestTypeCheckingValidatorNewLP.java
+++ b/test/org/apache/pig/test/TestTypeCheckingValidatorNewLP.java
@@ -4067,12 +4067,12 @@
 
         @Test
         public void testUDFNoInnerSchema() throws FrontendException {
-            String query = "a= load '1.txt';"
+            String query = "a= load '1.txt' using PigStorage(':') ;"
                 + "b = foreach a generate "+TestUDFTupleNullInnerSchema.class.getName()+"($0);"
                 + "c = foreach b generate flatten($0);"
                 + "d = foreach c generate $0 + 1;";
 
-            checkLastForeachCastLoadFunc(query, null, 0);
+            checkLastForeachCastLoadFunc(query, "PigStorage(':')");
         }
 
         //see PIG-1990