PIG-5194: HiveUDF fails with Spark exec type (szita)

git-svn-id: https://svn.apache.org/repos/asf/pig/trunk@1796647 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/CHANGES.txt b/CHANGES.txt
index 2b6c257..2b510c7 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -109,6 +109,8 @@
  
 BUG FIXES
 
+PIG-5194: HiveUDF fails with Spark exec type (szita)
+
 PIG-5231: PigStorage with -schema may produce inconsistent outputs with more fields (knoguchi)
 
 PIG-5224: Extra foreach from ColumnPrune preventing Accumulator usage (knoguchi)
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkLauncher.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkLauncher.java
index 391e83f..237fd94 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkLauncher.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkLauncher.java
@@ -20,6 +20,7 @@
 import java.io.File;
 import java.io.IOException;
 import java.io.PrintStream;
+import java.net.URL;
 import java.nio.file.Files;
 import java.nio.file.Paths;
 import java.util.ArrayList;
@@ -387,8 +388,8 @@
             for (String file : shipFiles.split(",")) {
                 File shipFile = new File(file.trim());
                 if (shipFile.exists()) {
-                    addResourceToSparkJobWorkingDirectory(shipFile,
-                            shipFile.getName(), ResourceType.FILE);
+                    addResourceToSparkJobWorkingDirectory(shipFile, shipFile.getName(),
+                            shipFile.getName().endsWith(".jar") ? ResourceType.JAR : ResourceType.FILE );
                 }
             }
         }
@@ -429,7 +430,7 @@
         Set<String> allJars = new HashSet<String>();
         LOG.info("Add default jars to Spark Job");
         allJars.addAll(JarManager.getDefaultJars());
-        LOG.info("Add extra jars to Spark Job");
+        LOG.info("Add script jars to Spark Job");
         for (String scriptJar : pigContext.scriptJars) {
             allJars.add(scriptJar);
         }
@@ -448,6 +449,11 @@
             allJars.add(scriptUDFJarFile.getAbsolutePath().toString());
         }
 
+        LOG.info("Add extra jars to Spark job");
+        for (URL extraJarUrl : pigContext.extraJars) {
+            allJars.add(extraJarUrl.getFile());
+        }
+
         //Upload all jars to spark working directory
         for (String jar : allJars) {
             File jarFile = new File(jar);
diff --git a/src/org/apache/pig/builtin/HiveUDAF.java b/src/org/apache/pig/builtin/HiveUDAF.java
index cf53d7c..b86159d 100644
--- a/src/org/apache/pig/builtin/HiveUDAF.java
+++ b/src/org/apache/pig/builtin/HiveUDAF.java
@@ -135,11 +135,11 @@
                     return;
                 }
 
-                if (m == Mode.PARTIAL1 || m == Mode.FINAL) {
+                if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2 || m == Mode.FINAL) {
                     intermediateOutputObjectInspector = evaluator.init(Mode.PARTIAL1, inputObjectInspectorAsArray);
                     intermediateOutputTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(intermediateOutputObjectInspector);
 
-                    if (m == Mode.FINAL) {
+                    if (m == Mode.PARTIAL2 || m == Mode.FINAL) {
                         intermediateInputObjectInspector = HiveUtils.createObjectInspector(intermediateOutputTypeInfo);
                         intermediateInputObjectInspectorAsArray = new ObjectInspector[] {intermediateInputObjectInspector};
                         outputObjectInspector = evaluator.init(Mode.FINAL, intermediateInputObjectInspectorAsArray);
@@ -208,20 +208,41 @@
     }
 
     static public class Initial extends EvalFunc<Tuple> {
+
+        private boolean inited = false;
+        private String funcName;
+        ConstantObjectInspectInfo constantsInfo;
+        private SchemaAndEvaluatorInfo schemaAndEvaluatorInfo = new SchemaAndEvaluatorInfo();
+        private static TupleFactory tf = TupleFactory.getInstance();
+
         public Initial(String funcName) {
+            this.funcName = funcName;
         }
-        public Initial(String funcName, String params) {
+        public Initial(String funcName, String params) throws IOException {
+            this.funcName = funcName;
+            constantsInfo = ConstantObjectInspectInfo.parse(params);
         }
         @Override
         public Tuple exec(Tuple input) throws IOException {
-
-            DataBag bg = (DataBag) input.get(0);
-            Tuple tp = null;
-            if(bg.iterator().hasNext()) {
-                tp = bg.iterator().next();
+            try {
+                if (!inited) {
+                    schemaAndEvaluatorInfo.init(getInputSchema(), instantiateUDAF(funcName), Mode.PARTIAL1, constantsInfo);
+                    inited = true;
+                }
+                DataBag b = (DataBag)input.get(0);
+                AggregationBuffer agg = schemaAndEvaluatorInfo.evaluator.getNewAggregationBuffer();
+                for (Iterator<Tuple> it = b.iterator(); it.hasNext();) {
+                    Tuple t = it.next();
+                    List inputs = schemaAndEvaluatorInfo.inputObjectInspector.getStructFieldsDataAsList(t);
+                    schemaAndEvaluatorInfo.evaluator.iterate(agg, inputs.toArray());
+                }
+                Object returnValue = schemaAndEvaluatorInfo.evaluator.terminatePartial(agg);
+                Tuple result = tf.newTuple();
+                result.append(HiveUtils.convertHiveToPig(returnValue, schemaAndEvaluatorInfo.intermediateOutputObjectInspector, null));
+                return result;
+            } catch (Exception e) {
+                throw new IOException(e);
             }
-
-            return tp;
         }
     }
 
@@ -244,15 +265,14 @@
         public Tuple exec(Tuple input) throws IOException {
             try {
                 if (!inited) {
-                    schemaAndEvaluatorInfo.init(getInputSchema(), instantiateUDAF(funcName), Mode.PARTIAL1, constantsInfo);
+                    schemaAndEvaluatorInfo.init(getInputSchema(), instantiateUDAF(funcName), Mode.PARTIAL2, constantsInfo);
                     inited = true;
                 }
                 DataBag b = (DataBag)input.get(0);
                 AggregationBuffer agg = schemaAndEvaluatorInfo.evaluator.getNewAggregationBuffer();
                 for (Iterator<Tuple> it = b.iterator(); it.hasNext();) {
                     Tuple t = it.next();
-                    List inputs = schemaAndEvaluatorInfo.inputObjectInspector.getStructFieldsDataAsList(t);
-                    schemaAndEvaluatorInfo.evaluator.iterate(agg, inputs.toArray());
+                    schemaAndEvaluatorInfo.evaluator.merge(agg, t.get(0));
                 }
                 Object returnValue = schemaAndEvaluatorInfo.evaluator.terminatePartial(agg);
                 Tuple result = tf.newTuple();