PIG-5157: Upgrade to Spark 2.0 (nkollar via liyunzhang)

git-svn-id: https://svn.apache.org/repos/asf/pig/trunk@1802347 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/build.xml b/build.xml
index d274af9..a99ca39 100644
--- a/build.xml
+++ b/build.xml
@@ -207,8 +207,7 @@
     <property name="ivy.repo.dir" value="${user.home}/ivyrepo" />
     <property name="ivy.dir" location="ivy" />
     <property name="loglevel" value="quiet" />
-    <loadproperties srcfile="${ivy.dir}/libraries.properties"/>
-
+    <loadproperties srcfile="${ivy.dir}/libraries.properties" />
 
     <!--
       Hadoop master version
@@ -241,6 +240,11 @@
         </then>
     </if>
     <property name="hbaseversion" value="1" />
+    <property name="sparkversion" value="1" />
+
+    <condition property="src.exclude.dir" value="**/Spark2*.java" else="**/Spark1*.java">
+        <equals arg1="${sparkversion}" arg2="1"/>
+    </condition>
 
     <property name="src.shims.dir" value="${basedir}/shims/src/hadoop${hadoopversion}" />
     <property name="src.shims.test.dir" value="${basedir}/shims/test/hadoop${hadoopversion}" />
@@ -556,7 +560,7 @@
         <echo>*** Building Main Sources ***</echo>
         <echo>*** To compile with all warnings enabled, supply -Dall.warnings=1 on command line ***</echo>
         <echo>*** Else, you will only be warned about deprecations ***</echo>
-        <echo>*** Hadoop version used: ${hadoopversion} ; HBase version used: ${hbaseversion} ***</echo>
+        <echo>*** Hadoop version used: ${hadoopversion} ; HBase version used: ${hbaseversion} ; Spark version used: ${sparkversion} ***</echo>
         <compileSources sources="${src.dir};${src.gen.dir};${src.lib.dir}/bzip2;${src.shims.dir}"
             excludes="${src.exclude.dir}" dist="${build.classes}" cp="classpath" warnings="${javac.args.warnings}" />
         <copy todir="${build.classes}/META-INF">
@@ -688,7 +692,7 @@
     <!-- ================================================================== -->
     <!-- Make pig.jar                                                       -->
     <!-- ================================================================== -->
-    <target name="jar" depends="compile,ivy-buildJar" description="Create pig core jar">
+    <target name="jar-simple" depends="compile,ivy-buildJar" description="Create pig core jar">
         <buildJar svnString="${svn.revision}" outputFile="${output.jarfile.core}" includedJars="core.dependencies.jar"/>
         <buildJar svnString="${svn.revision}" outputFile="${output.jarfile.withouthadoop}" includedJars="runtime.dependencies-withouthadoop.jar"/>
         <antcall target="copyCommonDependencies"/>
@@ -788,6 +792,35 @@
         </sequential>
     </macrodef>
 
+    <target name="jar-core" depends="compile,ivy-buildJar" description="Create only pig core jar">
+        <buildJar svnString="${svn.revision}" outputFile="${output.jarfile.core}" includedJars="core.dependencies.jar"/>
+    </target>
+
+    <target name="jar" description="Create pig jar with Spark 1 and 2">
+        <echo>Compiling against Spark 2</echo>
+        <antcall target="clean" inheritRefs="true" inheritall="true"/>
+        <propertyreset name="sparkversion" value="2"/>
+        <propertyreset name="src.exclude.dir" value="**/Spark1*.java" />
+        <antcall target="jar-core" inheritRefs="true" inheritall="true"/>
+        <move file="${output.jarfile.core}" tofile="${basedir}/_pig-shims.jar"/>
+
+        <echo>Compiling against Spark 1</echo>
+        <antcall target="clean" inheritRefs="true" inheritall="true"/>
+        <propertyreset name="sparkversion" value="1"/>
+        <propertyreset name="src.exclude.dir" value="**/Spark2*.java" />
+        <antcall target="jar-simple" inheritRefs="true" inheritall="true"/>
+        <jar update="yes" jarfile="${output.jarfile.core}">
+            <zipfileset src="${basedir}/_pig-shims.jar" includes="**/Spark2*.class"/>
+        </jar>
+        <jar update="yes" jarfile="${output.jarfile.backcompat-core-h2}">
+            <zipfileset src="${basedir}/_pig-shims.jar" includes="**/Spark2*.class"/>
+        </jar>
+        <jar update="yes" jarfile="${output.jarfile.withouthadoop-h2}">
+            <zipfileset src="${basedir}/_pig-shims.jar" includes="**/Spark2*.class"/>
+        </jar>
+        <delete file="${basedir}/_pig-shims.jar"/>
+    </target>
+
     <!-- ================================================================== -->
     <!-- macrodef: buildJar                                                 -->
     <!-- ================================================================== -->
@@ -1655,7 +1688,7 @@
 
      <target name="ivy-resolve" depends="ivy-init" unless="ivy.resolved" description="Resolve Ivy dependencies">
        <property name="ivy.resolved" value="true"/>
-       <echo>*** Ivy resolve with Hadoop ${hadoopversion} and HBase ${hbaseversion} ***</echo>
+       <echo>*** Ivy resolve with Hadoop ${hadoopversion}, Spark ${sparkversion} and HBase ${hbaseversion} ***</echo>
        <ivy:resolve log="${loglevel}" settingsRef="${ant.project.name}.ivy.settings" conf="compile"/>
        <ivy:report toDir="build/ivy/report"/>
      </target>
@@ -1664,7 +1697,7 @@
        <ivy:retrieve settingsRef="${ant.project.name}.ivy.settings" log="${loglevel}"
                  pattern="${build.ivy.lib.dir}/${ivy.artifact.retrieve.pattern}" conf="compile"/>
        <ivy:retrieve settingsRef="${ant.project.name}.ivy.settings" log="${loglevel}"
-                 pattern="${ivy.lib.dir.spark}/[artifact]-[revision](-[classifier]).[ext]" conf="spark"/>
+                 pattern="${ivy.lib.dir.spark}/[artifact]-[revision](-[classifier]).[ext]" conf="spark${sparkversion}"/>
        <ivy:cachepath pathid="compile.classpath" conf="compile"/>
      </target>
 
diff --git a/ivy.xml b/ivy.xml
index 3f2c943..db722a5 100644
--- a/ivy.xml
+++ b/ivy.xml
@@ -40,7 +40,8 @@
     <conf name="buildJar" extends="compile,test" visibility="private"/>
     <conf name="hadoop2" visibility="private"/>
     <conf name="hbase1" visibility="private"/>
-    <conf name="spark" visibility="private" />
+    <conf name="spark1" visibility="private" />
+    <conf name="spark2" visibility="private" />
   </configurations>
   <publications>
     <artifact name="pig" conf="master"/>
@@ -407,8 +408,8 @@
 
     <dependency org="com.twitter" name="parquet-pig-bundle" rev="${parquet-pig-bundle.version}" conf="compile->master"/>
 
-    <!-- for Spark integration -->
-    <dependency org="org.apache.spark" name="spark-core_2.11" rev="${spark.version}" conf="spark->default">
+    <!-- for Spark 1.x integration -->
+    <dependency org="org.apache.spark" name="spark-core_2.11" rev="${spark1.version}" conf="spark1->default">
         <exclude org="org.eclipse.jetty.orbit" module="javax.servlet"/>
         <exclude org="org.eclipse.jetty.orbit" module="javax.transaction"/>
         <exclude org="org.eclipse.jetty.orbit" module="javax.mail.glassfish"/>
@@ -418,12 +419,28 @@
         <exclude org="jline" module="jline"/>
         <exclude org="com.google.guava" />
     </dependency>
-    <dependency org="org.apache.spark" name="spark-yarn_2.11" rev="${spark.version}" conf="spark->default">
+    <dependency org="org.apache.spark" name="spark-yarn_2.11" rev="${spark1.version}" conf="spark1->default">
         <exclude org="org.apache.hadoop" />
     </dependency>
+
+    <!-- for Spark 2.x integration -->
+    <dependency org="org.apache.spark" name="spark-core_2.11" rev="${spark2.version}" conf="spark2->default">
+      <exclude org="org.eclipse.jetty.orbit" module="javax.servlet"/>
+      <exclude org="org.eclipse.jetty.orbit" module="javax.transaction"/>
+      <exclude org="org.eclipse.jetty.orbit" module="javax.mail.glassfish"/>
+      <exclude org="org.eclipse.jetty.orbit" module="javax.activation"/>
+      <exclude org="org.apache.hadoop" />
+      <exclude org="com.esotericsoftware.kryo" />
+      <exclude org="jline" module="jline"/>
+      <exclude org="com.google.guava" />
+    </dependency>
+    <dependency org="org.apache.spark" name="spark-yarn_2.11" rev="${spark2.version}" conf="spark2->default">
+      <exclude org="org.apache.hadoop" />
+    </dependency>
+
     <dependency org="asm" name="asm" rev="${asm.version}" conf="compile->master"/>
-    <dependency org="javax.servlet" name="javax.servlet-api" rev="3.0.1" conf="spark->default"/>
-    <dependency org="org.scala-lang.modules" name="scala-xml_2.11" rev="${scala-xml.version}" conf="spark->default"/>
+    <dependency org="javax.servlet" name="javax.servlet-api" rev="3.0.1" conf="spark1->default;spark2->default"/>
+    <dependency org="org.scala-lang.modules" name="scala-xml_2.11" rev="${scala-xml.version}" conf="spark1->default;spark2->default"/>
 
     <!-- for Tez integration -->
     <dependency org="org.apache.tez" name="tez" rev="${tez.version}"
diff --git a/ivy/libraries.properties b/ivy/libraries.properties
index c2aed45..de1f324 100644
--- a/ivy/libraries.properties
+++ b/ivy/libraries.properties
@@ -73,7 +73,8 @@
 rats-lib.version=0.5.1
 slf4j-api.version=1.6.1
 slf4j-log4j12.version=1.6.1
-spark.version=1.6.1
+spark1.version=1.6.1
+spark2.version=2.1.1
 xerces.version=2.10.0
 xalan.version=2.7.1
 wagon-http.version=1.0-beta-2
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/FlatMapFunctionAdapter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/FlatMapFunctionAdapter.java
new file mode 100644
index 0000000..c1d297f
--- /dev/null
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/FlatMapFunctionAdapter.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.pig.backend.hadoop.executionengine.spark;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+public interface FlatMapFunctionAdapter<R, T> extends Serializable {
+    Iterator<T> call(final R r) throws Exception;
+}
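
The adapter above exists because Spark 1.x's FlatMapFunction.call returns an Iterable while Spark 2.x's returns an Iterator; converters code against this interface and let the shim layer supply the version-specific wrapper. A minimal sketch of that pattern, assuming a hypothetical UppercaseShimExample class with an UppercaseFunction and a toUpper helper (none of which are part of this patch):

```java
// Sketch only: class and method names below are illustrative, not part of this patch.
import java.util.Collections;
import java.util.Iterator;

import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
import org.apache.spark.api.java.JavaRDD;

public class UppercaseShimExample {

    // Version-neutral logic: each input line yields one upper-cased output line.
    static class UppercaseFunction implements FlatMapFunctionAdapter<String, String> {
        @Override
        public Iterator<String> call(String line) {
            return Collections.singletonList(line.toUpperCase()).iterator();
        }
    }

    // SparkShims wraps the adapter in whichever FlatMapFunction shape the
    // Spark version on the classpath expects (Iterable on 1.x, Iterator on 2.x).
    static JavaRDD<String> toUpper(JavaRDD<String> lines) {
        return lines.flatMap(SparkShims.getInstance().flatMapFunction(new UppercaseFunction()));
    }
}
```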
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/JobGraphBuilder.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/JobGraphBuilder.java
index 5eac045..fac6679 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/JobGraphBuilder.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/JobGraphBuilder.java
@@ -81,7 +81,7 @@
     private Map<Class<? extends PhysicalOperator>, RDDConverter> convertMap = null;
     private SparkPigStats sparkStats = null;
     private JavaSparkContext sparkContext = null;
-    private JobMetricsListener jobMetricsListener = null;
+    private JobStatisticCollector jobStatisticCollector = null;
     private String jobGroupID = null;
     private Set<Integer> seenJobIDs = new HashSet<Integer>();
     private SparkOperPlan sparkPlan = null;
@@ -91,14 +91,14 @@
     private PigContext pc;
 
     public JobGraphBuilder(SparkOperPlan plan, Map<Class<? extends PhysicalOperator>, RDDConverter> convertMap,
-                           SparkPigStats sparkStats, JavaSparkContext sparkContext, JobMetricsListener
-            jobMetricsListener, String jobGroupID, JobConf jobConf, PigContext pc) {
+                           SparkPigStats sparkStats, JavaSparkContext sparkContext, JobStatisticCollector
+            jobStatisticCollector, String jobGroupID, JobConf jobConf, PigContext pc) {
         super(plan, new DependencyOrderWalker<SparkOperator, SparkOperPlan>(plan, true));
         this.sparkPlan = plan;
         this.convertMap = convertMap;
         this.sparkStats = sparkStats;
         this.sparkContext = sparkContext;
-        this.jobMetricsListener = jobMetricsListener;
+        this.jobStatisticCollector = jobStatisticCollector;
         this.jobGroupID = jobGroupID;
         this.jobConf = jobConf;
         this.pc = pc;
@@ -223,7 +223,7 @@
                             }
                         }
                         SparkStatsUtil.waitForJobAddStats(jobIDs.get(i++), poStore, sparkOperator,
-                                jobMetricsListener, sparkContext, sparkStats);
+                                jobStatisticCollector, sparkContext, sparkStats);
                     }
                 } else {
                     for (POStore poStore : poStores) {
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/JobMetricsListener.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/JobMetricsListener.java
deleted file mode 100644
index f813412..0000000
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/JobMetricsListener.java
+++ /dev/null
@@ -1,227 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.pig.backend.hadoop.executionengine.spark;
-
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
-import com.google.common.collect.Sets;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.scheduler.SparkListener;
-import org.apache.spark.scheduler.SparkListenerApplicationEnd;
-import org.apache.spark.scheduler.SparkListenerApplicationStart;
-import org.apache.spark.scheduler.SparkListenerBlockManagerAdded;
-import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved;
-import org.apache.spark.scheduler.SparkListenerBlockUpdated;
-import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate;
-import org.apache.spark.scheduler.SparkListenerExecutorAdded;
-import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate;
-import org.apache.spark.scheduler.SparkListenerExecutorRemoved;
-import org.apache.spark.scheduler.SparkListenerJobEnd;
-import org.apache.spark.scheduler.SparkListenerJobStart;
-import org.apache.spark.scheduler.SparkListenerStageCompleted;
-import org.apache.spark.scheduler.SparkListenerStageSubmitted;
-import org.apache.spark.scheduler.SparkListenerTaskEnd;
-import org.apache.spark.scheduler.SparkListenerTaskGettingResult;
-import org.apache.spark.scheduler.SparkListenerTaskStart;
-import org.apache.spark.scheduler.SparkListenerUnpersistRDD;
-
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-public class JobMetricsListener implements SparkListener {
-
-    private static final Log LOG = LogFactory.getLog(JobMetricsListener.class);
-
-    private final Map<Integer, int[]> jobIdToStageId = Maps.newHashMap();
-    private final Map<Integer, Integer> stageIdToJobId = Maps.newHashMap();
-    private final Map<Integer, Map<String, List<TaskMetrics>>> allJobMetrics = Maps.newHashMap();
-    private final Set<Integer> finishedJobIds = Sets.newHashSet();
-
-    @Override
-    public void onStageCompleted(SparkListenerStageCompleted stageCompleted) {
-//        uncomment and remove the code onTaskEnd until we fix PIG-5157. It is better to update taskMetrics of stage when stage completes
-//        if we update taskMetrics in onTaskEnd(), it consumes lot of memory.
-//        int stageId = stageCompleted.stageInfo().stageId();
-//        int stageAttemptId = stageCompleted.stageInfo().attemptId();
-//        String stageIdentifier = stageId + "_" + stageAttemptId;
-//        Integer jobId = stageIdToJobId.get(stageId);
-//        if (jobId == null) {
-//            LOG.warn("Cannot find job id for stage[" + stageId + "].");
-//        } else {
-//            Map<String, List<TaskMetrics>> jobMetrics = allJobMetrics.get(jobId);
-//            if (jobMetrics == null) {
-//                jobMetrics = Maps.newHashMap();
-//                allJobMetrics.put(jobId, jobMetrics);
-//            }
-//            List<TaskMetrics> stageMetrics = jobMetrics.get(stageIdentifier);
-//            if (stageMetrics == null) {
-//                stageMetrics = Lists.newLinkedList();
-//                jobMetrics.put(stageIdentifier, stageMetrics);
-//            }
-//            // uncomment until we fix PIG-5157. after we upgrade to spark2.0 StageInfo().taskMetrics() api is available
-//            // stageMetrics.add(stageCompleted.stageInfo().taskMetrics());
-//        }
-    }
-
-    @Override
-    public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) {
-
-    }
-
-    @Override
-    public void onTaskStart(SparkListenerTaskStart taskStart) {
-
-    }
-
-    @Override
-    public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) {
-
-    }
-
-    @Override
-    public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) {
-
-    }
-
-    @Override
-    public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) {
-
-    }
-
-    @Override
-    public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated){
-
-    }
-
-    @Override
-    public synchronized void onTaskEnd(SparkListenerTaskEnd taskEnd) {
-        int stageId = taskEnd.stageId();
-        int stageAttemptId = taskEnd.stageAttemptId();
-        String stageIdentifier = stageId + "_" + stageAttemptId;
-        Integer jobId = stageIdToJobId.get(stageId);
-        if (jobId == null) {
-            LOG.warn("Cannot find job id for stage[" + stageId + "].");
-        } else {
-            Map<String, List<TaskMetrics>> jobMetrics = allJobMetrics.get(jobId);
-            if (jobMetrics == null) {
-                jobMetrics = Maps.newHashMap();
-                allJobMetrics.put(jobId, jobMetrics);
-            }
-            List<TaskMetrics> stageMetrics = jobMetrics.get(stageIdentifier);
-            if (stageMetrics == null) {
-                stageMetrics = Lists.newLinkedList();
-                jobMetrics.put(stageIdentifier, stageMetrics);
-            }
-            stageMetrics.add(taskEnd.taskMetrics());
-        }
-    }
-
-    @Override
-    public synchronized void onJobStart(SparkListenerJobStart jobStart) {
-        int jobId = jobStart.jobId();
-        int size = jobStart.stageIds().size();
-        int[] intStageIds = new int[size];
-        for (int i = 0; i < size; i++) {
-            Integer stageId = (Integer) jobStart.stageIds().apply(i);
-            intStageIds[i] = stageId;
-            stageIdToJobId.put(stageId, jobId);
-        }
-        jobIdToStageId.put(jobId, intStageIds);
-    }
-
-    @Override
-    public synchronized void onJobEnd(SparkListenerJobEnd jobEnd) {
-        finishedJobIds.add(jobEnd.jobId());
-        notify();
-    }
-
-    @Override
-    public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) {
-
-    }
-
-    @Override
-    public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) {
-
-    }
-
-    @Override
-    public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) {
-
-    }
-
-    @Override
-    public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) {
-
-    }
-
-    @Override
-    public void onApplicationStart(SparkListenerApplicationStart applicationStart) {
-
-    }
-
-    @Override
-    public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) {
-
-    }
-
-    @Override
-    public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) {
-
-    }
-
-
-    public synchronized Map<String, List<TaskMetrics>> getJobMetric(int jobId) {
-        return allJobMetrics.get(jobId);
-    }
-
-    public synchronized boolean waitForJobToEnd(int jobId) throws InterruptedException {
-        if (finishedJobIds.contains(jobId)) {
-            finishedJobIds.remove(jobId);
-            return true;
-        }
-
-        wait();
-        return false;
-    }
-
-    public synchronized void cleanup(int jobId) {
-        allJobMetrics.remove(jobId);
-        jobIdToStageId.remove(jobId);
-        Iterator<Map.Entry<Integer, Integer>> iterator = stageIdToJobId.entrySet().iterator();
-        while (iterator.hasNext()) {
-            Map.Entry<Integer, Integer> entry = iterator.next();
-            if (entry.getValue() == jobId) {
-                iterator.remove();
-            }
-        }
-    }
-
-    public synchronized void reset() {
-        stageIdToJobId.clear();
-        jobIdToStageId.clear();
-        allJobMetrics.clear();
-        finishedJobIds.clear();
-    }
-}
\ No newline at end of file
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/JobStatisticCollector.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/JobStatisticCollector.java
new file mode 100644
index 0000000..8e16eac
--- /dev/null
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/JobStatisticCollector.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.pig.backend.hadoop.executionengine.spark;
+
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.scheduler.SparkListener;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public class JobStatisticCollector {
+
+    private final Map<Integer, int[]> jobIdToStageId = Maps.newHashMap();
+    private final Map<Integer, Integer> stageIdToJobId = Maps.newHashMap();
+    private final Map<Integer, Map<String, List<TaskMetrics>>> allJobStatistics = Maps.newHashMap();
+    private final Set<Integer> finishedJobIds = Sets.newHashSet();
+
+    private SparkListener sparkListener;
+
+    public SparkListener getSparkListener() {
+        if (sparkListener == null) {
+            sparkListener = SparkShims.getInstance()
+                    .getJobMetricsListener(jobIdToStageId, stageIdToJobId, allJobStatistics, finishedJobIds);
+        }
+        return sparkListener;
+    }
+
+    public Map<String, List<TaskMetrics>> getJobMetric(int jobId) {
+        synchronized (sparkListener) {
+            return allJobStatistics.get(jobId);
+        }
+    }
+
+    public boolean waitForJobToEnd(int jobId) throws InterruptedException {
+        synchronized (sparkListener) {
+            if (finishedJobIds.contains(jobId)) {
+                finishedJobIds.remove(jobId);
+                return true;
+            }
+
+            sparkListener.wait();
+            return false;
+        }
+    }
+
+    public void cleanup(int jobId) {
+        synchronized (sparkListener) {
+            allJobStatistics.remove(jobId);
+            jobIdToStageId.remove(jobId);
+            Iterator<Map.Entry<Integer, Integer>> iterator = stageIdToJobId.entrySet().iterator();
+            while (iterator.hasNext()) {
+                Map.Entry<Integer, Integer> entry = iterator.next();
+                if (entry.getValue() == jobId) {
+                    iterator.remove();
+                }
+            }
+        }
+    }
+
+    public void reset() {
+        synchronized (sparkListener) {
+            stageIdToJobId.clear();
+            jobIdToStageId.clear();
+            allJobStatistics.clear();
+            finishedJobIds.clear();
+        }
+    }
+}
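
JobStatisticCollector replaces the removed JobMetricsListener: it owns the job/stage bookkeeping maps and defers the actual SparkListener to SparkShims, so the same collector works against both listener APIs. A rough sketch of the wiring relied on elsewhere in this patch (the ListenerWiringExample class and its methods are illustrative; error handling omitted):

```java
// Sketch only: hypothetical illustration of how the collector and the
// shim-provided SparkListener fit together.
import org.apache.pig.backend.hadoop.executionengine.spark.JobStatisticCollector;
import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
import org.apache.spark.SparkContext;

public class ListenerWiringExample {

    // Register the version-specific listener that feeds the collector's maps.
    static void register(SparkContext sc, JobStatisticCollector collector) {
        SparkShims.getInstance().addSparkListener(sc, collector.getSparkListener());
    }

    // Block until onJobEnd() notifies for this job, then read and release its metrics.
    static void awaitAndClean(JobStatisticCollector collector, int jobId)
            throws InterruptedException {
        while (!collector.waitForJobToEnd(jobId)) {
            // waitForJobToEnd returns false when woken up by some other job ending
        }
        collector.getJobMetric(jobId); // map of "stageId_attemptId" -> task metrics
        collector.cleanup(jobId);
    }
}
```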
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/PairFlatMapFunctionAdapter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/PairFlatMapFunctionAdapter.java
new file mode 100644
index 0000000..296413b
--- /dev/null
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/PairFlatMapFunctionAdapter.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.pig.backend.hadoop.executionengine.spark;
+
+import scala.Tuple2;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+public interface PairFlatMapFunctionAdapter<T, K, V> extends Serializable {
+    Iterator<Tuple2<K, V>> call(T t) throws Exception;
+}
\ No newline at end of file
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/Spark1Shims.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/Spark1Shims.java
new file mode 100644
index 0000000..c7974fe
--- /dev/null
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/Spark1Shims.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.pig.backend.hadoop.executionengine.spark;
+
+import com.google.common.base.Optional;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.tools.pigstats.PigStats;
+import org.apache.pig.tools.pigstats.spark.SparkJobStats;
+import org.apache.pig.tools.pigstats.spark.Spark1JobStats;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.scheduler.SparkListener;
+import org.apache.spark.scheduler.SparkListenerApplicationEnd;
+import org.apache.spark.scheduler.SparkListenerApplicationStart;
+import org.apache.spark.scheduler.SparkListenerBlockManagerAdded;
+import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved;
+import org.apache.spark.scheduler.SparkListenerBlockUpdated;
+import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate;
+import org.apache.spark.scheduler.SparkListenerExecutorAdded;
+import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate;
+import org.apache.spark.scheduler.SparkListenerExecutorRemoved;
+import org.apache.spark.scheduler.SparkListenerJobEnd;
+import org.apache.spark.scheduler.SparkListenerJobStart;
+import org.apache.spark.scheduler.SparkListenerStageCompleted;
+import org.apache.spark.scheduler.SparkListenerStageSubmitted;
+import org.apache.spark.scheduler.SparkListenerTaskEnd;
+import org.apache.spark.scheduler.SparkListenerTaskGettingResult;
+import org.apache.spark.scheduler.SparkListenerTaskStart;
+import org.apache.spark.scheduler.SparkListenerUnpersistRDD;
+import scala.Tuple2;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public class Spark1Shims extends SparkShims {
+    @Override
+    public <T, R> FlatMapFunction<T, R> flatMapFunction(final FlatMapFunctionAdapter<T, R> function) {
+        return new FlatMapFunction<T, R>() {
+            @Override
+            public Iterable<R> call(final T t) throws Exception {
+                return new Iterable<R>() {
+                    @Override
+                    public Iterator<R> iterator() {
+                        try {
+                            return function.call(t);
+                        } catch (Exception e) {
+                            throw new RuntimeException(e);
+                        }
+                    }
+                };
+
+            }
+        };
+    }
+
+    @Override
+    public <T, K, V> PairFlatMapFunction<T, K, V> pairFlatMapFunction(final PairFlatMapFunctionAdapter<T, K, V> function) {
+        return new PairFlatMapFunction<T, K, V>() {
+            @Override
+            public Iterable<Tuple2<K, V>> call(final T t) throws Exception {
+                return new Iterable<Tuple2<K, V>>() {
+                    @Override
+                    public Iterator<Tuple2<K, V>> iterator() {
+                        try {
+                            return function.call(t);
+                        } catch (Exception e) {
+                            throw new RuntimeException(e);
+                        }
+                    }
+                };
+
+            }
+        };
+    }
+
+    @Override
+    public RDD<Tuple> coalesce(RDD<Tuple> rdd, int numPartitions, boolean shuffle) {
+        return rdd.coalesce(numPartitions, shuffle, null);
+    }
+
+    @Override
+    public SparkJobStats sparkJobStats(int jobId, PigStats.JobGraph plan, Configuration conf) {
+        return new Spark1JobStats(jobId, plan, conf);
+    }
+
+    @Override
+    public SparkJobStats sparkJobStats(String jobId, PigStats.JobGraph plan, Configuration conf) {
+        return new Spark1JobStats(jobId, plan, conf);
+    }
+
+    @Override
+    public <T> OptionalWrapper<T> wrapOptional(T tuple) {
+        final Optional<T> t = (Optional<T>) tuple;
+
+        return new OptionalWrapper<T>() {
+            @Override
+            public boolean isPresent() {
+                return t.isPresent();
+            }
+
+            @Override
+            public T get() {
+                return t.get();
+            }
+        };
+    }
+
+    private static class JobMetricsListener implements SparkListener {
+        private final Log LOG = LogFactory.getLog(JobMetricsListener.class);
+
+        private Map<Integer, int[]> jobIdToStageId;
+        private Map<Integer, Integer> stageIdToJobId;
+        private Map<Integer, Map<String, List<TaskMetrics>>> allJobMetrics;
+        private Set<Integer> finishedJobIds;
+
+        JobMetricsListener(final Map<Integer, int[]> jobIdToStageId,
+                           final Map<Integer, Integer> stageIdToJobId,
+                           final Map<Integer, Map<String, List<TaskMetrics>>> allJobMetrics,
+                           final Set<Integer> finishedJobIds) {
+            this.jobIdToStageId = jobIdToStageId;
+            this.stageIdToJobId = stageIdToJobId;
+            this.allJobMetrics = allJobMetrics;
+            this.finishedJobIds = finishedJobIds;
+        }
+
+        @Override
+        public void onStageCompleted(SparkListenerStageCompleted stageCompleted) {
+        }
+
+        @Override
+        public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) {
+        }
+
+        @Override
+        public void onTaskStart(SparkListenerTaskStart taskStart) {
+        }
+
+        @Override
+        public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) {
+        }
+
+        @Override
+        public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) {
+        }
+
+        @Override
+        public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) {
+        }
+
+        @Override
+        public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) {
+        }
+
+        @Override
+        public synchronized void onTaskEnd(SparkListenerTaskEnd taskEnd) {
+            int stageId = taskEnd.stageId();
+            int stageAttemptId = taskEnd.stageAttemptId();
+            String stageIdentifier = stageId + "_" + stageAttemptId;
+            Integer jobId = stageIdToJobId.get(stageId);
+            if (jobId == null) {
+                LOG.warn("Cannot find job id for stage[" + stageId + "].");
+            } else {
+                Map<String, List<TaskMetrics>> jobMetrics = allJobMetrics.get(jobId);
+                if (jobMetrics == null) {
+                    jobMetrics = Maps.newHashMap();
+                    allJobMetrics.put(jobId, jobMetrics);
+                }
+                List<TaskMetrics> stageMetrics = jobMetrics.get(stageIdentifier);
+                if (stageMetrics == null) {
+                    stageMetrics = Lists.newLinkedList();
+                    jobMetrics.put(stageIdentifier, stageMetrics);
+                }
+                stageMetrics.add(taskEnd.taskMetrics());
+            }
+        }
+
+        @Override
+        public synchronized void onJobStart(SparkListenerJobStart jobStart) {
+            int jobId = jobStart.jobId();
+            int size = jobStart.stageIds().size();
+            int[] intStageIds = new int[size];
+            for (int i = 0; i < size; i++) {
+                Integer stageId = (Integer) jobStart.stageIds().apply(i);
+                intStageIds[i] = stageId;
+                stageIdToJobId.put(stageId, jobId);
+            }
+            jobIdToStageId.put(jobId, intStageIds);
+        }
+
+        @Override
+        public synchronized void onJobEnd(SparkListenerJobEnd jobEnd) {
+            finishedJobIds.add(jobEnd.jobId());
+            notify();
+        }
+
+        @Override
+        public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) {
+        }
+
+        @Override
+        public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) {
+        }
+
+        @Override
+        public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) {
+        }
+
+        @Override
+        public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) {
+        }
+
+        @Override
+        public void onApplicationStart(SparkListenerApplicationStart applicationStart) {
+        }
+
+        @Override
+        public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) {
+        }
+
+        @Override
+        public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) {
+        }
+    }
+
+    @Override
+    public SparkListener getJobMetricsListener(Map<Integer, int[]> jobIdToStageId,
+                                                          Map<Integer, Integer> stageIdToJobId,
+                                                          Map<Integer, Map<String, List<TaskMetrics>>> allJobMetrics,
+                                                          Set<Integer> finishedJobIds) {
+        return new JobMetricsListener(jobIdToStageId, stageIdToJobId, allJobMetrics, finishedJobIds);
+    }
+
+    @Override
+    public void addSparkListener(SparkContext sc, SparkListener sparkListener) {
+        sc.addSparkListener(sparkListener);
+    }
+}
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/Spark2Shims.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/Spark2Shims.java
new file mode 100644
index 0000000..5a85dff
--- /dev/null
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/Spark2Shims.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.pig.backend.hadoop.executionengine.spark;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.tools.pigstats.PigStats;
+import org.apache.pig.tools.pigstats.spark.Spark2JobStats;
+import org.apache.pig.tools.pigstats.spark.SparkJobStats;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.rdd.PartitionCoalescer;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.scheduler.SparkListener;
+import org.apache.spark.scheduler.SparkListenerJobEnd;
+import org.apache.spark.scheduler.SparkListenerJobStart;
+import org.apache.spark.scheduler.SparkListenerStageCompleted;
+import scala.Option;
+import scala.Tuple2;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public class Spark2Shims extends SparkShims {
+    @Override
+    public <T, R> FlatMapFunction flatMapFunction(final FlatMapFunctionAdapter<T, R> function) {
+        return new FlatMapFunction<T, R>() {
+            @Override
+            public Iterator<R> call(T t) throws Exception {
+                return function.call(t);
+            }
+        };
+    }
+
+    @Override
+    public <T, K, V> PairFlatMapFunction<T, K, V> pairFlatMapFunction(final PairFlatMapFunctionAdapter<T, K, V> function) {
+        return new PairFlatMapFunction<T, K, V>() {
+            @Override
+            public Iterator<Tuple2<K, V>> call(T t) throws Exception {
+                return function.call(t);
+            }
+        };
+    }
+
+    @Override
+    public RDD<Tuple> coalesce(RDD<Tuple> rdd, int numPartitions, boolean shuffle) {
+        return rdd.coalesce(numPartitions, shuffle, Option.<PartitionCoalescer>empty(), null);
+    }
+
+    @Override
+    public SparkJobStats sparkJobStats(int jobId, PigStats.JobGraph plan, Configuration conf) {
+        return new Spark2JobStats(jobId, plan, conf);
+    }
+
+    @Override
+    public SparkJobStats sparkJobStats(String jobId, PigStats.JobGraph plan, Configuration conf) {
+        return new Spark2JobStats(jobId, plan, conf);
+    }
+
+    @Override
+    public <T> OptionalWrapper<T> wrapOptional(T tuple) {
+        final Optional<T> t = (Optional<T>) tuple;
+
+        return new OptionalWrapper<T>() {
+            @Override
+            public boolean isPresent() {
+                return t.isPresent();
+            }
+
+            @Override
+            public T get() {
+                return t.get();
+            }
+        };
+    }
+
+    private static class JobMetricsListener extends SparkListener {
+        private final Log LOG = LogFactory.getLog(JobMetricsListener.class);
+
+        private Map<Integer, int[]> jobIdToStageId;
+        private Map<Integer, Integer> stageIdToJobId;
+        private Map<Integer, Map<String, List<TaskMetrics>>> allJobMetrics;
+        private Set<Integer> finishedJobIds;
+
+        JobMetricsListener(final Map<Integer, int[]> jobIdToStageId,
+                           final Map<Integer, Integer> stageIdToJobId,
+                           final Map<Integer, Map<String, List<TaskMetrics>>> allJobMetrics,
+                           final Set<Integer> finishedJobIds) {
+            this.jobIdToStageId = jobIdToStageId;
+            this.stageIdToJobId = stageIdToJobId;
+            this.allJobMetrics = allJobMetrics;
+            this.finishedJobIds = finishedJobIds;
+        }
+
+        @Override
+        public synchronized void onStageCompleted(SparkListenerStageCompleted stageCompleted) {
+            int stageId = stageCompleted.stageInfo().stageId();
+            int stageAttemptId = stageCompleted.stageInfo().attemptId();
+            String stageIdentifier = stageId + "_" + stageAttemptId;
+            Integer jobId = stageIdToJobId.get(stageId);
+            if (jobId == null) {
+                LOG.warn("Cannot find job id for stage[" + stageId + "].");
+            } else {
+                Map<String, List<TaskMetrics>> jobMetrics = allJobMetrics.get(jobId);
+                if (jobMetrics == null) {
+                    jobMetrics = Maps.newHashMap();
+                    allJobMetrics.put(jobId, jobMetrics);
+                }
+                List<TaskMetrics> stageMetrics = jobMetrics.get(stageIdentifier);
+                if (stageMetrics == null) {
+                    stageMetrics = Lists.newLinkedList();
+                    jobMetrics.put(stageIdentifier, stageMetrics);
+                }
+                stageMetrics.add(stageCompleted.stageInfo().taskMetrics());
+            }
+        }
+
+        @Override
+        public synchronized void onJobStart(SparkListenerJobStart jobStart) {
+            int jobId = jobStart.jobId();
+            int size = jobStart.stageIds().size();
+            int[] intStageIds = new int[size];
+            for (int i = 0; i < size; i++) {
+                Integer stageId = (Integer) jobStart.stageIds().apply(i);
+                intStageIds[i] = stageId;
+                stageIdToJobId.put(stageId, jobId);
+            }
+            jobIdToStageId.put(jobId, intStageIds);
+        }
+
+        @Override
+        public synchronized void onJobEnd(SparkListenerJobEnd jobEnd) {
+            finishedJobIds.add(jobEnd.jobId());
+            notify();
+        }
+    }
+
+    @Override
+    public SparkListener getJobMetricsListener(Map<Integer, int[]> jobIdToStageId,
+                                               Map<Integer, Integer> stageIdToJobId,
+                                               Map<Integer, Map<String, List<TaskMetrics>>> allJobMetrics,
+                                               Set<Integer> finishedJobIds) {
+        return new JobMetricsListener(jobIdToStageId, stageIdToJobId, allJobMetrics, finishedJobIds);
+    }
+
+    @Override
+    public void addSparkListener(SparkContext sc, SparkListener sparkListener) {
+        sc.addSparkListener(sparkListener);
+    }
+
+}
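
Spark1Shims and Spark2Shims also paper over the different RDD.coalesce arities (three arguments on Spark 1.x, four on Spark 2.x with an explicit PartitionCoalescer option). A caller goes through the shim instead of invoking RDD.coalesce directly; a minimal sketch with a hypothetical shrink helper:

```java
// Sketch only: CoalesceShimExample and shrink are illustrative names, not part of this patch.
import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
import org.apache.pig.data.Tuple;
import org.apache.spark.rdd.RDD;

public class CoalesceShimExample {

    // Reduce the number of partitions without a shuffle, on either Spark line.
    static RDD<Tuple> shrink(RDD<Tuple> rdd, int numPartitions) {
        return SparkShims.getInstance().coalesce(rdd, numPartitions, false);
    }
}
```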
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkLauncher.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkLauncher.java
index 237fd94..03dea96 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkLauncher.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkLauncher.java
@@ -136,11 +136,12 @@
 import org.apache.pig.tools.pigstats.spark.SparkPigStatusReporter;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.scheduler.JobLogger;
 import org.apache.spark.scheduler.StatsReportListener;
 
 import com.google.common.base.Joiner;
 
+import static org.apache.pig.backend.hadoop.executionengine.spark.SparkShims.SPARK_VERSION;
+
 /**
  * Main class that launches pig for Spark
  */
@@ -152,7 +153,7 @@
     // across jobs, because a
     // new SparkLauncher gets created for each job.
     private static JavaSparkContext sparkContext = null;
-    private static JobMetricsListener jobMetricsListener = new JobMetricsListener();
+    private static JobStatisticCollector jobStatisticCollector = new JobStatisticCollector();
     private String jobGroupID;
     private PigContext pigContext = null;
     private JobConf jobConf = null;
@@ -174,9 +175,10 @@
         SparkPigStats sparkStats = (SparkPigStats) pigContext
                 .getExecutionEngine().instantiatePigStats();
         sparkStats.initialize(pigContext, sparkplan, jobConf);
+        UDFContext.getUDFContext().addJobConf(jobConf);
         PigStats.start(sparkStats);
 
-        startSparkIfNeeded(pigContext);
+        startSparkIfNeeded(jobConf, pigContext);
 
         jobGroupID = String.format("%s-%s",sparkContext.getConf().getAppId(),
                 UUID.randomUUID().toString());
@@ -184,7 +186,7 @@
 
         sparkContext.setJobGroup(jobGroupID, "Pig query to Spark cluster",
                 false);
-        jobMetricsListener.reset();
+        jobStatisticCollector.reset();
 
         this.currentDirectoryPath = Paths.get(".").toAbsolutePath()
                 .normalize().toString()
@@ -231,7 +233,7 @@
         }
         uploadResources(sparkplan);
 
-        new JobGraphBuilder(sparkplan, convertMap, sparkStats, sparkContext, jobMetricsListener, jobGroupID, jobConf, pigContext).visit();
+        new JobGraphBuilder(sparkplan, convertMap, sparkStats, sparkContext, jobStatisticCollector, jobGroupID, jobConf, pigContext).visit();
         cleanUpSparkJob(sparkStats);
         sparkStats.finish();
         resetUDFContext();
@@ -539,7 +541,7 @@
      * Only one SparkContext may be active per JVM (SPARK-2243). When multiple threads start SparkLaucher,
      * the static member sparkContext should be initialized only once
      */
-    private static synchronized void startSparkIfNeeded(PigContext pc) throws PigException {
+    private static synchronized void startSparkIfNeeded(JobConf jobConf, PigContext pc) throws PigException {
         if (sparkContext == null) {
             String master = null;
             if (pc.getExecType().isLocal()) {
@@ -594,9 +596,9 @@
             checkAndConfigureDynamicAllocation(master, sparkConf);
 
             sparkContext = new JavaSparkContext(sparkConf);
-            sparkContext.sc().addSparkListener(new StatsReportListener());
-            sparkContext.sc().addSparkListener(new JobLogger());
-            sparkContext.sc().addSparkListener(jobMetricsListener);
+            jobConf.set(SPARK_VERSION, sparkContext.version());
+            SparkShims.getInstance().addSparkListener(sparkContext.sc(), jobStatisticCollector.getSparkListener());
+            SparkShims.getInstance().addSparkListener(sparkContext.sc(), new StatsReportListener());
         }
     }
 
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkShims.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkShims.java
new file mode 100644
index 0000000..b81bb1a
--- /dev/null
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/SparkShims.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.pig.backend.hadoop.executionengine.spark;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.pig.data.Tuple;
+import org.apache.pig.impl.util.UDFContext;
+import org.apache.pig.tools.pigstats.PigStats;
+import org.apache.pig.tools.pigstats.spark.SparkJobStats;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.scheduler.SparkListener;
+
+import java.io.Serializable;
+import java.lang.reflect.Constructor;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public abstract class SparkShims implements Serializable {
+    private static final Log LOG = LogFactory.getLog(SparkShims.class);
+    public static final String SPARK_VERSION = "pig.spark.version";
+
+    private static SparkShims sparkShims;
+
+    private static SparkShims loadShims(String sparkVersion) throws ReflectiveOperationException {
+        Class<?> sparkShimsClass;
+
+        if ("2".equals(sparkVersion)) {
+            LOG.info("Initializing shims for Spark 2.x");
+            sparkShimsClass = Class.forName("org.apache.pig.backend.hadoop.executionengine.spark.Spark2Shims");
+        } else {
+            LOG.info("Initializing shims for Spark 1.x");
+            sparkShimsClass = Class.forName("org.apache.pig.backend.hadoop.executionengine.spark.Spark1Shims");
+        }
+
+        Constructor c = sparkShimsClass.getConstructor();
+        return (SparkShims) c.newInstance();
+    }
+
+    public static SparkShims getInstance() {
+        if (sparkShims == null) {
+            String sparkVersion = UDFContext.getUDFContext().getJobConf().get(SPARK_VERSION, "");
+            LOG.info("Initializing SparkShims for Spark version: " + sparkVersion);
+            String sparkMajorVersion = getSparkMajorVersion(sparkVersion);
+            try {
+                sparkShims = loadShims(sparkMajorVersion);
+            } catch (ReflectiveOperationException e) {
+                throw new RuntimeException(e);
+            }
+        }
+        return sparkShims;
+    }
+
+    private static String getSparkMajorVersion(String sparkVersion) {
+        return sparkVersion.startsWith("2") ? "2" : "1";
+    }
+
+    public abstract <T, R> FlatMapFunction<T, R> flatMapFunction(FlatMapFunctionAdapter<T, R> function);
+
+    public abstract <T, K, V> PairFlatMapFunction<T, K, V> pairFlatMapFunction(PairFlatMapFunctionAdapter<T, K, V> function);
+
+    public abstract RDD<Tuple> coalesce(RDD<Tuple> rdd, int numPartitions, boolean shuffle);
+
+    public abstract SparkJobStats sparkJobStats(int jobId, PigStats.JobGraph plan, Configuration conf);
+
+    public abstract SparkJobStats sparkJobStats(String jobId, PigStats.JobGraph plan, Configuration conf);
+
+    public abstract <T> OptionalWrapper<T> wrapOptional(T tuple);
+
+    public abstract SparkListener getJobMetricsListener(Map<Integer, int[]> jobIdToStageId,
+                                                        Map<Integer, Integer> stageIdToJobId,
+                                                        Map<Integer, Map<String, List<TaskMetrics>>> allJobMetrics,
+                                                        Set<Integer> finishedJobIds);
+
+    public abstract void addSparkListener(SparkContext sc, SparkListener sparkListener);
+
+    public interface OptionalWrapper<T> {
+        boolean isPresent();
+
+        T get();
+    }
+}
\ No newline at end of file
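
SparkShims resolves the concrete shim reflectively from the pig.spark.version value that SparkLauncher writes into the job configuration (see the SparkLauncher hunk above). A minimal sketch of that flow, using "2.1.1" as an illustrative version string and assuming the Spark 2 build artifacts are on the classpath:

```java
// Sketch only: in the real flow SparkLauncher sets pig.spark.version from
// sparkContext.version(); the literal below is just for illustration.
import org.apache.hadoop.mapred.JobConf;
import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
import org.apache.pig.impl.util.UDFContext;

public class ShimSelectionExample {
    public static void main(String[] args) {
        JobConf conf = new JobConf();
        conf.set(SparkShims.SPARK_VERSION, "2.1.1");
        UDFContext.getUDFContext().addJobConf(conf);

        // getInstance() maps "2.1.1" to major version "2" and reflectively
        // loads Spark2Shims (requires the Spark 2 classes on the classpath).
        SparkShims shims = SparkShims.getInstance();
        System.out.println(shims.getClass().getSimpleName());
    }
}
```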
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/CollectedGroupConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/CollectedGroupConverter.java
index 83311df..7933324 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/CollectedGroupConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/CollectedGroupConverter.java
@@ -24,9 +24,10 @@
 import org.apache.pig.backend.executionengine.ExecException;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POCollectedGroup;
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import org.apache.pig.data.Tuple;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 @SuppressWarnings({"serial"})
@@ -39,48 +40,40 @@
         RDD<Tuple> rdd = predecessors.get(0);
         CollectedGroupFunction collectedGroupFunction
                 = new CollectedGroupFunction(physicalOperator);
-        return rdd.toJavaRDD().mapPartitions(collectedGroupFunction, true)
+        return rdd.toJavaRDD().mapPartitions(SparkShims.getInstance().flatMapFunction(collectedGroupFunction), true)
                 .rdd();
     }
 
     private static class CollectedGroupFunction
-            implements FlatMapFunction<Iterator<Tuple>, Tuple> {
+            implements FlatMapFunctionAdapter<Iterator<Tuple>, Tuple> {
 
         private POCollectedGroup poCollectedGroup;
 
         public long current_val;
-        public boolean proceed;
 
         private CollectedGroupFunction(POCollectedGroup poCollectedGroup) {
             this.poCollectedGroup = poCollectedGroup;
             this.current_val = 0;
         }
 
-        public Iterable<Tuple> call(final Iterator<Tuple> input) {
-
-            return new Iterable<Tuple>() {
+        @Override
+        public Iterator<Tuple> call(final Iterator<Tuple> input) {
+            return new OutputConsumerIterator(input) {
 
                 @Override
-                public Iterator<Tuple> iterator() {
+                protected void attach(Tuple tuple) {
+                    poCollectedGroup.setInputs(null);
+                    poCollectedGroup.attachInput(tuple);
+                }
 
-                    return new OutputConsumerIterator(input) {
+                @Override
+                protected Result getNextResult() throws ExecException {
+                    return poCollectedGroup.getNextTuple();
+                }
 
-                        @Override
-                        protected void attach(Tuple tuple) {
-                            poCollectedGroup.setInputs(null);
-                            poCollectedGroup.attachInput(tuple);
-                        }
-
-                        @Override
-                        protected Result getNextResult() throws ExecException {
-                            return poCollectedGroup.getNextTuple();
-                        }
-
-                        @Override
-                        protected void endOfInput() {
-                            poCollectedGroup.setEndOfInput(true);
-                        }
-                    };
+                @Override
+                protected void endOfInput() {
+                    poCollectedGroup.setEndOfInput(true);
                 }
             };
         }
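
This refactoring repeats across the converters below: Spark 2.0 changed FlatMapFunction.call to return an Iterator instead of an Iterable, so each converter now implements the version-neutral FlatMapFunctionAdapter (whose call returns an Iterator) and asks SparkShims for the concrete Spark function. The adapter and shim classes are introduced elsewhere in this patch; the sketch below only illustrates the bridging idea, with sketch type names chosen here and only the Spark 2 delegate shown (the Spark 1 flavour, compiled in its own source tree, would wrap the returned iterator in an Iterable).

```java
import java.io.Serializable;
import java.util.Iterator;

import org.apache.spark.api.java.function.FlatMapFunction;

// Illustrative sketch only -- the real FlatMapFunctionAdapter and SparkShims
// implementations are introduced elsewhere in this patch.
interface FlatMapFunctionAdapterSketch<T, R> extends Serializable {
    Iterator<R> call(T input) throws Exception;
}

// Spark 2.x flavour: FlatMapFunction.call already returns an Iterator, so the
// shim can delegate directly. The Spark 1.x flavour (compiled only when
// sparkversion=1) would instead wrap the Iterator in an Iterable.
class Spark2FlatMapSketch {
    static <T, R> FlatMapFunction<T, R> flatMapFunction(
            final FlatMapFunctionAdapterSketch<T, R> adapter) {
        return new FlatMapFunction<T, R>() {
            @Override
            public Iterator<R> call(T input) throws Exception {
                return adapter.call(input);
            }
        };
    }
}
```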
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/FRJoinConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/FRJoinConverter.java
index 382258e..6cd01cf 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/FRJoinConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/FRJoinConverter.java
@@ -32,10 +32,11 @@
 import org.apache.pig.backend.executionengine.ExecException;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POFRJoin;
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkPigContext;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import org.apache.pig.data.Tuple;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 @SuppressWarnings("serial")
@@ -53,7 +54,7 @@
         attachReplicatedInputs((POFRJoinSpark) poFRJoin);
 
         FRJoinFunction frJoinFunction = new FRJoinFunction(poFRJoin);
-        return rdd.toJavaRDD().mapPartitions(frJoinFunction, true).rdd();
+        return rdd.toJavaRDD().mapPartitions(SparkShims.getInstance().flatMapFunction(frJoinFunction), true).rdd();
     }
 
     private void attachReplicatedInputs(POFRJoinSpark poFRJoin) {
@@ -67,7 +68,7 @@
     }
 
     private static class FRJoinFunction implements
-            FlatMapFunction<Iterator<Tuple>, Tuple>, Serializable {
+            FlatMapFunctionAdapter<Iterator<Tuple>, Tuple>, Serializable {
 
         private POFRJoin poFRJoin;
         private FRJoinFunction(POFRJoin poFRJoin) {
@@ -75,29 +76,22 @@
         }
 
         @Override
-        public Iterable<Tuple> call(final Iterator<Tuple> input) throws Exception {
-
-            return new Iterable<Tuple>() {
+        public Iterator<Tuple> call(final Iterator<Tuple> input) {
+            return new OutputConsumerIterator(input) {
 
                 @Override
-                public Iterator<Tuple> iterator() {
-                    return new OutputConsumerIterator(input) {
+                protected void attach(Tuple tuple) {
+                    poFRJoin.setInputs(null);
+                    poFRJoin.attachInput(tuple);
+                }
 
-                        @Override
-                        protected void attach(Tuple tuple) {
-                            poFRJoin.setInputs(null);
-                            poFRJoin.attachInput(tuple);
-                        }
+                @Override
+                protected Result getNextResult() throws ExecException {
+                    return poFRJoin.getNextTuple();
+                }
 
-                        @Override
-                        protected Result getNextResult() throws ExecException {
-                            return poFRJoin.getNextTuple();
-                        }
-
-                        @Override
-                        protected void endOfInput() {
-                        }
-                    };
+                @Override
+                protected void endOfInput() {
                 }
             };
         }
@@ -107,4 +101,4 @@
     public void setReplicatedInputs(Set<String> replicatedInputs) {
         this.replicatedInputs = replicatedInputs;
     }
-}
\ No newline at end of file
+}
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/ForEachConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/ForEachConverter.java
index b58415e..01581d0 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/ForEachConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/ForEachConverter.java
@@ -29,13 +29,14 @@
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.expressionOperators.POUserFunc;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POForEach;
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
 import org.apache.pig.backend.hadoop.executionengine.spark.KryoSerializer;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import org.apache.pig.data.SchemaTupleBackend;
 import org.apache.pig.data.Tuple;
 import org.apache.pig.impl.PigContext;
 import org.apache.pig.impl.util.ObjectSerializer;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 /**
@@ -60,11 +61,11 @@
         RDD<Tuple> rdd = predecessors.get(0);
         ForEachFunction forEachFunction = new ForEachFunction(physicalOperator, confBytes);
 
-        return rdd.toJavaRDD().mapPartitions(forEachFunction, true).rdd();
+        return rdd.toJavaRDD().mapPartitions(SparkShims.getInstance().flatMapFunction(forEachFunction), true).rdd();
     }
 
     private static class ForEachFunction implements
-            FlatMapFunction<Iterator<Tuple>, Tuple>, Serializable {
+            FlatMapFunctionAdapter<Iterator<Tuple>, Tuple>, Serializable {
 
         private POForEach poForEach;
         private byte[] confBytes;
@@ -75,7 +76,8 @@
             this.confBytes = confBytes;
         }
 
-        public Iterable<Tuple> call(final Iterator<Tuple> input) {
+        @Override
+        public Iterator<Tuple> call(final Iterator<Tuple> input) {
 
             initialize();
 
@@ -90,29 +92,21 @@
                     }
                 }
             }
-
-
-            return new Iterable<Tuple>() {
+            return new OutputConsumerIterator(input) {
 
                 @Override
-                public Iterator<Tuple> iterator() {
-                    return new OutputConsumerIterator(input) {
+                protected void attach(Tuple tuple) {
+                    poForEach.setInputs(null);
+                    poForEach.attachInput(tuple);
+                }
 
-                        @Override
-                        protected void attach(Tuple tuple) {
-                            poForEach.setInputs(null);
-                            poForEach.attachInput(tuple);
-                        }
+                @Override
+                protected Result getNextResult() throws ExecException {
+                    return poForEach.getNextTuple();
+                }
 
-                        @Override
-                        protected Result getNextResult() throws ExecException {
-                            return poForEach.getNextTuple();
-                        }
-
-                        @Override
-                        protected void endOfInput() {
-                        }
-                    };
+                @Override
+                protected void endOfInput() {
                 }
             };
         }
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/GlobalRearrangeConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/GlobalRearrangeConverter.java
index 130c8b9..805c017 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/GlobalRearrangeConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/GlobalRearrangeConverter.java
@@ -26,6 +26,7 @@
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.pig.backend.executionengine.ExecException;
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkPigContext;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import org.apache.pig.backend.hadoop.executionengine.spark.operator.POGlobalRearrangeSpark;
@@ -34,7 +35,6 @@
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.rdd.CoGroupedRDD;
 import org.apache.spark.rdd.RDD;
@@ -127,7 +127,7 @@
     }
 
     private static class RemoveValueFunction implements
-            FlatMapFunction<Iterator<Tuple2<Tuple, Object>>, Tuple>, Serializable {
+            FlatMapFunctionAdapter<Iterator<Tuple2<Tuple, Object>>, Tuple>, Serializable {
 
         private class Tuple2TransformIterable implements Iterable<Tuple> {
 
@@ -148,8 +148,8 @@
         }
 
         @Override
-        public Iterable<Tuple> call(Iterator<Tuple2<Tuple, Object>> input) {
-            return new Tuple2TransformIterable(input);
+        public Iterator<Tuple> call(Iterator<Tuple2<Tuple, Object>> input) {
+            return new Tuple2TransformIterable(input).iterator();
         }
     }
 
@@ -330,7 +330,7 @@
                             }
                         }
                     });
-                    ++ i;
+                    ++i;
                 }
 
                 Tuple out = tf.newTuple(2);
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/LimitConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/LimitConverter.java
index fe1b54c..1a277fc 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/LimitConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/LimitConverter.java
@@ -24,9 +24,10 @@
 import org.apache.pig.backend.executionengine.ExecException;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POLimit;
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import org.apache.pig.data.Tuple;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 @SuppressWarnings({ "serial" })
@@ -38,11 +39,11 @@
         SparkUtil.assertPredecessorSize(predecessors, poLimit, 1);
         RDD<Tuple> rdd = predecessors.get(0);
         LimitFunction limitFunction = new LimitFunction(poLimit);
-        RDD<Tuple> rdd2 = rdd.coalesce(1, false, null);
-        return rdd2.toJavaRDD().mapPartitions(limitFunction, false).rdd();
+        RDD<Tuple> rdd2 = SparkShims.getInstance().coalesce(rdd, 1, false);
+        return rdd2.toJavaRDD().mapPartitions(SparkShims.getInstance().flatMapFunction(limitFunction), false).rdd();
     }
 
-    private static class LimitFunction implements FlatMapFunction<Iterator<Tuple>, Tuple> {
+    private static class LimitFunction implements FlatMapFunctionAdapter<Iterator<Tuple>, Tuple> {
 
         private final POLimit poLimit;
 
@@ -51,28 +52,22 @@
         }
 
         @Override
-        public Iterable<Tuple> call(final Iterator<Tuple> tuples) {
+        public Iterator<Tuple> call(final Iterator<Tuple> tuples) {
+            return new OutputConsumerIterator(tuples) {
 
-            return new Iterable<Tuple>() {
+                @Override
+                protected void attach(Tuple tuple) {
+                    poLimit.setInputs(null);
+                    poLimit.attachInput(tuple);
+                }
 
-                public Iterator<Tuple> iterator() {
-                    return new OutputConsumerIterator(tuples) {
+                @Override
+                protected Result getNextResult() throws ExecException {
+                    return poLimit.getNextTuple();
+                }
 
-                        @Override
-                        protected void attach(Tuple tuple) {
-                            poLimit.setInputs(null);
-                            poLimit.attachInput(tuple);
-                        }
-
-                        @Override
-                        protected Result getNextResult() throws ExecException {
-                            return poLimit.getNextTuple();
-                        }
-
-                        @Override
-                        protected void endOfInput() {
-                        }
-                    };
+                @Override
+                protected void endOfInput() {
                 }
             };
         }
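
Besides the flat-map change, LimitConverter now obtains the single-partition RDD through SparkShims.getInstance().coalesce(...): the Scala RDD.coalesce signature gained a PartitionCoalescer parameter in Spark 2.x, so the old direct call rdd.coalesce(1, false, null) does not compile against both versions. Below is a minimal sketch of what the Spark 2 side of such a shim could look like, assuming it is acceptable to route through the unchanged JavaRDD API; the real SparkShims.coalesce added by this patch may be implemented differently.

```java
import org.apache.spark.rdd.RDD;

// Hypothetical sketch; not the actual Spark2Shims implementation from this patch.
class Spark2CoalesceSketch {
    static <T> RDD<T> coalesce(RDD<T> rdd, int numPartitions, boolean shuffle) {
        // JavaRDD.coalesce(int, boolean) kept the same shape in Spark 1.x and 2.x,
        // so routing through it sidesteps the changed Scala RDD.coalesce signature.
        return rdd.toJavaRDD().coalesce(numPartitions, shuffle).rdd();
    }
}
```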
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/MergeCogroupConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/MergeCogroupConverter.java
index adf78ec..b3cefa8 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/MergeCogroupConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/MergeCogroupConverter.java
@@ -24,9 +24,10 @@
 import org.apache.pig.backend.executionengine.ExecException;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POMergeCogroup;
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import org.apache.pig.data.Tuple;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 
@@ -37,38 +38,32 @@
         SparkUtil.assertPredecessorSize(predecessors, physicalOperator, 1);
         RDD<Tuple> rdd = predecessors.get(0);
         MergeCogroupFunction mergeCogroupFunction = new MergeCogroupFunction(physicalOperator);
-        return rdd.toJavaRDD().mapPartitions(mergeCogroupFunction, true).rdd();
+        return rdd.toJavaRDD().mapPartitions(SparkShims.getInstance().flatMapFunction(mergeCogroupFunction), true).rdd();
     }
 
     private static class MergeCogroupFunction implements
-            FlatMapFunction<Iterator<Tuple>, Tuple>, Serializable {
+            FlatMapFunctionAdapter<Iterator<Tuple>, Tuple>, Serializable {
 
         private POMergeCogroup poMergeCogroup;
 
         @Override
-        public Iterable<Tuple> call(final Iterator<Tuple> input) throws Exception {
-            return new Iterable<Tuple>() {
+        public Iterator<Tuple> call(final Iterator<Tuple> input) {
+            return new OutputConsumerIterator(input) {
 
                 @Override
-                public Iterator<Tuple> iterator() {
-                    return new OutputConsumerIterator(input) {
+                protected void attach(Tuple tuple) {
+                    poMergeCogroup.setInputs(null);
+                    poMergeCogroup.attachInput(tuple);
+                }
 
-                        @Override
-                        protected void attach(Tuple tuple) {
-                            poMergeCogroup.setInputs(null);
-                            poMergeCogroup.attachInput(tuple);
-                        }
+                @Override
+                protected Result getNextResult() throws ExecException {
+                    return poMergeCogroup.getNextTuple();
+                }
 
-                        @Override
-                        protected Result getNextResult() throws ExecException {
-                            return poMergeCogroup.getNextTuple();
-                        }
-
-                        @Override
-                        protected void endOfInput() {
-                            poMergeCogroup.setEndOfInput(true);
-                        }
-                    };
+                @Override
+                protected void endOfInput() {
+                    poMergeCogroup.setEndOfInput(true);
                 }
             };
         }
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/MergeJoinConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/MergeJoinConverter.java
index d1c43b1..9e37e8c 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/MergeJoinConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/MergeJoinConverter.java
@@ -25,9 +25,10 @@
 import org.apache.pig.backend.executionengine.ExecException;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POMergeJoin;
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import org.apache.pig.data.Tuple;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 
@@ -43,11 +44,11 @@
         RDD<Tuple> rdd = predecessors.get(0);
         MergeJoinFunction mergeJoinFunction = new MergeJoinFunction(poMergeJoin);
 
-        return rdd.toJavaRDD().mapPartitions(mergeJoinFunction, true).rdd();
+        return rdd.toJavaRDD().mapPartitions(SparkShims.getInstance().flatMapFunction(mergeJoinFunction), true).rdd();
     }
 
     private static class MergeJoinFunction implements
-            FlatMapFunction<Iterator<Tuple>, Tuple>, Serializable {
+            FlatMapFunctionAdapter<Iterator<Tuple>, Tuple>, Serializable {
 
         private POMergeJoin poMergeJoin;
 
@@ -55,29 +56,24 @@
             this.poMergeJoin = poMergeJoin;
         }
 
-        public Iterable<Tuple> call(final Iterator<Tuple> input) {
+        @Override
+        public Iterator<Tuple> call(final Iterator<Tuple> input) {
+            return new OutputConsumerIterator(input) {
 
-            return new Iterable<Tuple>() {
                 @Override
-                public Iterator<Tuple> iterator() {
-                    return new OutputConsumerIterator(input) {
+                protected void attach(Tuple tuple) {
+                    poMergeJoin.setInputs(null);
+                    poMergeJoin.attachInput(tuple);
+                }
 
-                        @Override
-                        protected void attach(Tuple tuple) {
-                            poMergeJoin.setInputs(null);
-                            poMergeJoin.attachInput(tuple);
-                        }
+                @Override
+                protected Result getNextResult() throws ExecException {
+                    return poMergeJoin.getNextTuple();
+                }
 
-                        @Override
-                        protected Result getNextResult() throws ExecException {
-                            return poMergeJoin.getNextTuple();
-                        }
-
-                        @Override
-                        protected void endOfInput() {
-                            poMergeJoin.setEndOfInput(true);
-                        }
-                    };
+                @Override
+                protected void endOfInput() {
+                    poMergeJoin.setEndOfInput(true);
                 }
             };
         }
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/PoissonSampleConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/PoissonSampleConverter.java
index e003bbd..772338b 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/PoissonSampleConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/PoissonSampleConverter.java
@@ -19,10 +19,11 @@
 
 import org.apache.pig.backend.executionengine.ExecException;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
-import org.apache.pig.backend.hadoop.executionengine.spark.operator.POPoissonSampleSpark;
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
+import org.apache.pig.backend.hadoop.executionengine.spark.operator.POPoissonSampleSpark;
 import org.apache.pig.data.Tuple;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 import java.io.IOException;
@@ -37,10 +38,10 @@
         SparkUtil.assertPredecessorSize(predecessors, po, 1);
         RDD<Tuple> rdd = predecessors.get(0);
         PoissionSampleFunction poissionSampleFunction = new PoissionSampleFunction(po);
-        return rdd.toJavaRDD().mapPartitions(poissionSampleFunction, false).rdd();
+        return rdd.toJavaRDD().mapPartitions(SparkShims.getInstance().flatMapFunction(poissionSampleFunction), false).rdd();
     }
 
-    private static class PoissionSampleFunction implements FlatMapFunction<Iterator<Tuple>, Tuple> {
+    private static class PoissionSampleFunction implements FlatMapFunctionAdapter<Iterator<Tuple>, Tuple> {
 
         private final POPoissonSampleSpark po;
 
@@ -49,29 +50,23 @@
         }
 
         @Override
-        public Iterable<Tuple> call(final Iterator<Tuple> tuples) {
+        public Iterator<Tuple> call(final Iterator<Tuple> tuples) {
+            return new OutputConsumerIterator(tuples) {
 
-            return new Iterable<Tuple>() {
+                @Override
+                protected void attach(Tuple tuple) {
+                    po.setInputs(null);
+                    po.attachInput(tuple);
+                }
 
-                public Iterator<Tuple> iterator() {
-                    return new OutputConsumerIterator(tuples) {
+                @Override
+                protected Result getNextResult() throws ExecException {
+                    return po.getNextTuple();
+                }
 
-                        @Override
-                        protected void attach(Tuple tuple) {
-                            po.setInputs(null);
-                            po.attachInput(tuple);
-                        }
-
-                        @Override
-                        protected Result getNextResult() throws ExecException {
-                            return po.getNextTuple();
-                        }
-
-                        @Override
-                        protected void endOfInput() {
-                            po.setEndOfInput(true);
-                        }
-                    };
+                @Override
+                protected void endOfInput() {
+                    po.setEndOfInput(true);
                 }
             };
         }
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SecondaryKeySortUtil.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SecondaryKeySortUtil.java
index 00d29b4..cd5aef6 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SecondaryKeySortUtil.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SecondaryKeySortUtil.java
@@ -22,6 +22,9 @@
 import java.util.Iterator;
 import java.util.Objects;
 
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import scala.Tuple2;
 
 import org.apache.commons.logging.Log;
@@ -30,13 +33,11 @@
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.POStatus;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POPackage;
-import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import org.apache.pig.data.Tuple;
 import org.apache.pig.impl.io.NullableTuple;
 import org.apache.pig.impl.io.PigNullableWritable;
 import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 /**
@@ -56,13 +57,13 @@
         JavaPairRDD<IndexedKey, Tuple> sorted = pairRDD.repartitionAndSortWithinPartitions(
                 new IndexedKeyPartitioner(partitionNums));
         //Package tuples with same indexedkey as the result: (key,(val1,val2,val3,...))
-        return sorted.mapPartitions(new AccumulateByKey(pkgOp), true).rdd();
+        return sorted.mapPartitions(SparkShims.getInstance().flatMapFunction(new AccumulateByKey(pkgOp)), true).rdd();
     }
 
     //Package tuples with same indexedkey as the result: (key,(val1,val2,val3,...))
     //Send (key,Iterator) to POPackage, use POPackage#getNextTuple to get the result
-    private static class AccumulateByKey implements FlatMapFunction<Iterator<Tuple2<IndexedKey, Tuple>>, Tuple>,
-            Serializable {
+    private static class AccumulateByKey
+            implements FlatMapFunctionAdapter<Iterator<Tuple2<IndexedKey, Tuple>>, Tuple>, Serializable {
         private POPackage pkgOp;
 
         public AccumulateByKey(POPackage pkgOp) {
@@ -70,7 +71,7 @@
         }
 
         @Override
-        public Iterable<Tuple> call(final Iterator<Tuple2<IndexedKey, Tuple>> it) throws Exception {
+        public Iterator<Tuple> call(final Iterator<Tuple2<IndexedKey, Tuple>> it) {
             return new Iterable<Tuple>() {
                 Object curKey = null;
                 ArrayList curValues = new ArrayList();
@@ -132,7 +133,7 @@
                         }
                     };
                 }
-            };
+            }.iterator();
         }
 
         private Tuple restructTuple(final Object curKey, final ArrayList<Tuple> curValues) {
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SkewedJoinConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SkewedJoinConverter.java
index c55ba31..64f2fc2 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SkewedJoinConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SkewedJoinConverter.java
@@ -25,11 +25,12 @@
 import java.util.Map;
 import java.util.HashMap;
 
-import com.google.common.base.Optional;
 import com.google.common.collect.Maps;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
 import org.apache.pig.backend.hadoop.executionengine.spark.SparkPigContext;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
 import org.apache.pig.data.DataBag;
 import org.apache.pig.impl.builtin.PartitionSkewedKeys;
 import org.apache.pig.impl.util.Pair;
@@ -54,7 +55,6 @@
 import org.apache.pig.impl.util.MultiMap;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 public class SkewedJoinConverter implements
@@ -103,7 +103,7 @@
 
         // with partition id
         StreamPartitionIndexKeyFunction streamFun = new StreamPartitionIndexKeyFunction(this, keyDist, defaultParallelism);
-        JavaRDD<Tuple2<PartitionIndexedKey, Tuple>> streamIdxKeyJavaRDD = rdd2.toJavaRDD().flatMap(streamFun);
+        JavaRDD<Tuple2<PartitionIndexedKey, Tuple>> streamIdxKeyJavaRDD = rdd2.toJavaRDD().flatMap(SparkShims.getInstance().flatMapFunction(streamFun));
 
         // Tuple2 RDD to Pair RDD
         JavaPairRDD<PartitionIndexedKey, Tuple> streamIndexedJavaPairRDD = new JavaPairRDD<PartitionIndexedKey, Tuple>(
@@ -146,7 +146,7 @@
      * @param <R> be generic because it can be Optional<Tuple> or Tuple
      */
     private static class ToValueFunction<L, R> implements
-            FlatMapFunction<Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>>, Tuple>, Serializable {
+            FlatMapFunctionAdapter<Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>>, Tuple>, Serializable {
 
         private boolean[] innerFlags;
         private int[] schemaSize;
@@ -188,7 +188,7 @@
                             Tuple leftTuple = tf.newTuple();
                             if (!innerFlags[0]) {
                                 // left should be Optional<Tuple>
-                                Optional<Tuple> leftOption = (Optional<Tuple>) left;
+                                SparkShims.OptionalWrapper<L> leftOption = SparkShims.getInstance().wrapOptional(left);
                                 if (!leftOption.isPresent()) {
                                     // Add an empty left record for RIGHT OUTER JOIN.
                                     // Notice: if it is a skewed, only join the first reduce key
@@ -200,7 +200,7 @@
                                         return this.next();
                                     }
                                 } else {
-                                    leftTuple = leftOption.get();
+                                    leftTuple = (Tuple) leftOption.get();
                                 }
                             } else {
                                 leftTuple = (Tuple) left;
@@ -212,13 +212,13 @@
                             Tuple rightTuple = tf.newTuple();
                             if (!innerFlags[1]) {
                                 // right should be Optional<Tuple>
-                                Optional<Tuple> rightOption = (Optional<Tuple>) right;
+                                SparkShims.OptionalWrapper<R> rightOption = SparkShims.getInstance().wrapOptional(right);
                                 if (!rightOption.isPresent()) {
                                     for (int i = 0; i < schemaSize[1]; i++) {
                                         rightTuple.append(null);
                                     }
                                 } else {
-                                    rightTuple = rightOption.get();
+                                    rightTuple = (Tuple) rightOption.get();
                                 }
                             } else {
                                 rightTuple = (Tuple) right;
@@ -234,17 +234,17 @@
                             return result;
                         } catch (Exception e) {
                             log.warn(e);
+                            return null;
                         }
-                        return null;
                     }
                 };
             }
         }
 
         @Override
-        public Iterable<Tuple> call(
-                Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>> input) {
-            return new Tuple2TransformIterable(input);
+        public Iterator<Tuple> call(
+                Iterator<Tuple2<PartitionIndexedKey, Tuple2<L, R>>> input) throws Exception {
+            return new Tuple2TransformIterable(input).iterator();
         }
 
         private boolean isFirstReduceKey(PartitionIndexedKey pKey) {
@@ -413,7 +413,7 @@
 
                 return tuple_KeyValue;
             } catch (Exception e) {
-                System.out.print(e);
+                log.warn(e);
                 return null;
             }
         }
@@ -469,7 +469,7 @@
      * <p>
      * see: https://wiki.apache.org/pig/PigSkewedJoinSpec
      */
-    private static class StreamPartitionIndexKeyFunction implements FlatMapFunction<Tuple, Tuple2<PartitionIndexedKey, Tuple>> {
+    private static class StreamPartitionIndexKeyFunction implements FlatMapFunctionAdapter<Tuple, Tuple2<PartitionIndexedKey, Tuple>> {
 
         private SkewedJoinConverter poSkewedJoin;
         private final Broadcast<List<Tuple>> keyDist;
@@ -487,7 +487,8 @@
             this.defaultParallelism = defaultParallelism;
         }
 
-        public Iterable<Tuple2<PartitionIndexedKey, Tuple>> call(Tuple tuple) throws Exception {
+        @Override
+        public Iterator<Tuple2<PartitionIndexedKey, Tuple>> call(Tuple tuple) throws Exception {
             if (!initialized) {
                 Integer[] reducers = new Integer[1];
                 reducerMap = loadKeyDistribution(keyDist, reducers);
@@ -526,12 +527,12 @@
                 l.add(new Tuple2(pIndexKey, tuple));
             }
 
-            return l;
+            return l.iterator();
         }
     }
 
     /**
-     * user defined spark partitioner for skewed join
+     * User defined spark partitioner for skewed join
      */
     private static class SkewedJoinPartitioner extends Partitioner {
         private int numPartitions;
@@ -568,12 +569,8 @@
     }
 
     /**
-     * use parallelism from keyDist or the default parallelism to
+     * Use parallelism from keyDist or the default parallelism to
      * create user defined partitioner
-     *
-     * @param keyDist
-     * @param defaultParallelism
-     * @return
      */
     private SkewedJoinPartitioner buildPartitioner(Broadcast<List<Tuple>> keyDist, Integer defaultParallelism) {
         Integer parallelism = -1;
@@ -588,12 +585,7 @@
     }
 
     /**
-     * do all kinds of Join (inner, left outer, right outer, full outer)
-     *
-     * @param skewIndexedJavaPairRDD
-     * @param streamIndexedJavaPairRDD
-     * @param partitioner
-     * @return
+     * Do all kinds of Join (inner, left outer, right outer, full outer)
      */
     private JavaRDD<Tuple> doJoin(
             JavaPairRDD<PartitionIndexedKey, Tuple> skewIndexedJavaPairRDD,
@@ -616,25 +608,22 @@
             JavaPairRDD<PartitionIndexedKey, Tuple2<Tuple, Tuple>> resultKeyValue = skewIndexedJavaPairRDD.
                     join(streamIndexedJavaPairRDD, partitioner);
 
-            return resultKeyValue.mapPartitions(toValueFun);
+            return resultKeyValue.mapPartitions(SparkShims.getInstance().flatMapFunction(toValueFun));
         } else if (innerFlags[0] && !innerFlags[1]) {
             // left outer join
-            JavaPairRDD<PartitionIndexedKey, Tuple2<Tuple, Optional<Tuple>>> resultKeyValue = skewIndexedJavaPairRDD.
-                    leftOuterJoin(streamIndexedJavaPairRDD, partitioner);
-
-            return resultKeyValue.mapPartitions(toValueFun);
+            return skewIndexedJavaPairRDD
+                    .leftOuterJoin(streamIndexedJavaPairRDD, partitioner)
+                    .mapPartitions(SparkShims.getInstance().flatMapFunction(toValueFun));
         } else if (!innerFlags[0] && innerFlags[1]) {
             // right outer join
-            JavaPairRDD<PartitionIndexedKey, Tuple2<Optional<Tuple>, Tuple>> resultKeyValue = skewIndexedJavaPairRDD.
-                    rightOuterJoin(streamIndexedJavaPairRDD, partitioner);
-
-            return resultKeyValue.mapPartitions(toValueFun);
+            return skewIndexedJavaPairRDD
+                    .rightOuterJoin(streamIndexedJavaPairRDD, partitioner)
+                    .mapPartitions(SparkShims.getInstance().flatMapFunction(toValueFun));
         } else {
             // full outer join
-            JavaPairRDD<PartitionIndexedKey, Tuple2<Optional<Tuple>, Optional<Tuple>>> resultKeyValue = skewIndexedJavaPairRDD.
-                    fullOuterJoin(streamIndexedJavaPairRDD, partitioner);
-
-            return resultKeyValue.mapPartitions(toValueFun);
+            return skewIndexedJavaPairRDD
+                    .fullOuterJoin(streamIndexedJavaPairRDD, partitioner)
+                    .mapPartitions(SparkShims.getInstance().flatMapFunction(toValueFun));
         }
     }
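
ToValueFunction above no longer casts the outer-join values to Guava's com.google.common.base.Optional directly; Spark 2.x switched the Java outer-join API to org.apache.spark.api.java.Optional, so the converter now goes through SparkShims.OptionalWrapper. The wrapper itself is defined elsewhere in this patch; the following is only a sketch of the idea, with illustrative names and the Spark 2 flavour shown.

```java
// Illustrative sketch; the real SparkShims.OptionalWrapper may differ.
interface OptionalWrapperSketch<T> {
    boolean isPresent();
    T get();
}

// Spark 2.x flavour: wraps org.apache.spark.api.java.Optional. The Spark 1.x
// flavour (compiled only when sparkversion=1) would wrap Guava's Optional.
class Spark2OptionalSketch<T> implements OptionalWrapperSketch<T> {
    private final org.apache.spark.api.java.Optional<T> delegate;

    Spark2OptionalSketch(org.apache.spark.api.java.Optional<T> delegate) {
        this.delegate = delegate;
    }

    @Override
    public boolean isPresent() {
        return delegate.isPresent();
    }

    @Override
    public T get() {
        return delegate.get();
    }
}
```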
 
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SortConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SortConverter.java
index baabfa0..a8b51a5 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SortConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SortConverter.java
@@ -22,25 +22,26 @@
 import java.util.Iterator;
 import java.util.List;
 
+import org.apache.pig.backend.hadoop.executionengine.spark.FlatMapFunctionAdapter;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkPigContext;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import scala.Tuple2;
 import scala.runtime.AbstractFunction1;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POSort;
-import org.apache.pig.backend.hadoop.executionengine.spark.SparkPigContext;
-import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import org.apache.pig.data.Tuple;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 @SuppressWarnings("serial")
 public class SortConverter implements RDDConverter<Tuple, Tuple, POSort> {
     private static final Log LOG = LogFactory.getLog(SortConverter.class);
 
-    private static final FlatMapFunction<Iterator<Tuple2<Tuple, Object>>, Tuple> TO_VALUE_FUNCTION = new ToValueFunction();
+    private static final FlatMapFunctionAdapter<Iterator<Tuple2<Tuple, Object>>, Tuple> TO_VALUE_FUNCTION = new ToValueFunction();
 
     @Override
     public RDD<Tuple> convert(List<RDD<Tuple>> predecessors, POSort sortOperator)
@@ -57,13 +58,13 @@
 
         JavaPairRDD<Tuple, Object> sorted = r.sortByKey(
                 sortOperator.getMComparator(), true, parallelism);
-        JavaRDD<Tuple> mapped = sorted.mapPartitions(TO_VALUE_FUNCTION);
+        JavaRDD<Tuple> mapped = sorted.mapPartitions(SparkShims.getInstance().flatMapFunction(TO_VALUE_FUNCTION));
 
         return mapped.rdd();
     }
 
     private static class ToValueFunction implements
-            FlatMapFunction<Iterator<Tuple2<Tuple, Object>>, Tuple>, Serializable {
+            FlatMapFunctionAdapter<Iterator<Tuple2<Tuple, Object>>, Tuple>, Serializable {
 
         private class Tuple2TransformIterable implements Iterable<Tuple> {
 
@@ -84,8 +85,8 @@
         }
 
         @Override
-        public Iterable<Tuple> call(Iterator<Tuple2<Tuple, Object>> input) {
-            return new Tuple2TransformIterable(input);
+        public Iterator<Tuple> call(Iterator<Tuple2<Tuple, Object>> input) {
+            return new Tuple2TransformIterable(input).iterator();
         }
     }
 
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SparkSampleSortConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SparkSampleSortConverter.java
index 3166fdc..7e00e1f 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SparkSampleSortConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/SparkSampleSortConverter.java
@@ -19,18 +19,18 @@
 
 import java.io.IOException;
 import java.io.Serializable;
-import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
 
+import org.apache.pig.backend.hadoop.executionengine.spark.PairFlatMapFunctionAdapter;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
+import org.apache.spark.api.java.function.Function2;
 import scala.Tuple2;
 import scala.runtime.AbstractFunction1;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import org.apache.pig.backend.executionengine.ExecException;
-import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POSort;
-import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
 import org.apache.pig.backend.hadoop.executionengine.spark.operator.POSampleSortSpark;
 import org.apache.pig.data.BagFactory;
 import org.apache.pig.data.DataBag;
@@ -40,12 +40,11 @@
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.rdd.RDD;
-  /*
-   sort the sample data and convert the sample data to the format (all,{(sampleEle1),(sampleEle2),...})
 
-   */
+/*
+ Sort the sample data and convert it to the format (all,{(sampleEle1),(sampleEle2),...})
+ */
 @SuppressWarnings("serial")
 public class SparkSampleSortConverter implements RDDConverter<Tuple, Tuple, POSampleSortSpark> {
     private static final Log LOG = LogFactory.getLog(SparkSampleSortConverter.class);
@@ -66,14 +65,14 @@
          //sort sample data
         JavaPairRDD<Tuple, Object> sorted = r.sortByKey(true);
          //convert every element in sample data from element to (all, element) format
-        JavaPairRDD<String, Tuple> mapped = sorted.mapPartitionsToPair(new AggregateFunction());
+        JavaPairRDD<String, Tuple> mapped = sorted.mapPartitionsToPair(SparkShims.getInstance().pairFlatMapFunction(new AggregateFunction()));
         //use groupByKey to aggregate all values( the format will be ((all),{(sampleEle1),(sampleEle2),...} )
         JavaRDD<Tuple> groupByKey= mapped.groupByKey().map(new ToValueFunction());
         return  groupByKey.rdd();
     }
 
 
-    private static class MergeFunction implements org.apache.spark.api.java.function.Function2<Tuple, Tuple, Tuple>
+    private static class MergeFunction implements Function2<Tuple, Tuple, Tuple>
             , Serializable {
 
         @Override
@@ -89,7 +88,7 @@
     // input: Tuple2<Tuple,Object>
     // output: Tuple2("all", Tuple)
     private static class AggregateFunction implements
-            PairFlatMapFunction<Iterator<Tuple2<Tuple, Object>>, String,Tuple>, Serializable {
+            PairFlatMapFunctionAdapter<Iterator<Tuple2<Tuple, Object>>, String, Tuple>, Serializable {
 
         private class Tuple2TransformIterable implements Iterable<Tuple2<String,Tuple>> {
 
@@ -111,8 +110,8 @@
         }
 
         @Override
-        public Iterable<Tuple2<String, Tuple>> call(Iterator<Tuple2<Tuple, Object>> input) throws Exception {
-            return new Tuple2TransformIterable(input);
+        public Iterator<Tuple2<String, Tuple>> call(Iterator<Tuple2<Tuple, Object>> input) throws Exception {
+            return new Tuple2TransformIterable(input).iterator();
         }
 
     }
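
The same Iterable-to-Iterator change applies to PairFlatMapFunction, which is why AggregateFunction now implements PairFlatMapFunctionAdapter and is wrapped via SparkShims.getInstance().pairFlatMapFunction(...). Below is a sketch mirroring the flat-map adapter shown earlier, again with illustrative names and only a Spark 2 delegate; the real classes are added elsewhere in this patch.

```java
import java.util.Iterator;

import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;

// Illustrative sketch only; the real PairFlatMapFunctionAdapter/SparkShims
// classes are added elsewhere in this patch.
interface PairFlatMapFunctionAdapterSketch<T, K, V> extends java.io.Serializable {
    Iterator<Tuple2<K, V>> call(T input) throws Exception;
}

// Spark 2.x flavour: PairFlatMapFunction.call returns an Iterator of pairs.
class Spark2PairFlatMapSketch {
    static <T, K, V> PairFlatMapFunction<T, K, V> pairFlatMapFunction(
            final PairFlatMapFunctionAdapterSketch<T, K, V> adapter) {
        return new PairFlatMapFunction<T, K, V>() {
            @Override
            public Iterator<Tuple2<K, V>> call(T input) throws Exception {
                return adapter.call(input);
            }
        };
    }
}
```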
diff --git a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/StreamConverter.java b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/StreamConverter.java
index 3a50d48..364deed 100644
--- a/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/StreamConverter.java
+++ b/src/org/apache/pig/backend/hadoop/executionengine/spark/converter/StreamConverter.java
@@ -25,9 +25,8 @@
 import org.apache.pig.backend.executionengine.ExecException;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POStream;
-import org.apache.pig.backend.hadoop.executionengine.spark.SparkUtil;
+import org.apache.pig.backend.hadoop.executionengine.spark.*;
 import org.apache.pig.data.Tuple;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.rdd.RDD;
 
 public class StreamConverter implements
@@ -35,44 +34,40 @@
 
     @Override
     public RDD<Tuple> convert(List<RDD<Tuple>> predecessors,
-            POStream poStream) throws IOException {
+                              POStream poStream) throws IOException {
         SparkUtil.assertPredecessorSize(predecessors, poStream, 1);
         RDD<Tuple> rdd = predecessors.get(0);
         StreamFunction streamFunction = new StreamFunction(poStream);
-        return rdd.toJavaRDD().mapPartitions(streamFunction, true).rdd();
+        return rdd.toJavaRDD().mapPartitions(SparkShims.getInstance().flatMapFunction(streamFunction), true).rdd();
     }
 
     private static class StreamFunction implements
-            FlatMapFunction<Iterator<Tuple>, Tuple>, Serializable {
+            FlatMapFunctionAdapter<Iterator<Tuple>, Tuple>, Serializable {
         private POStream poStream;
 
         private StreamFunction(POStream poStream) {
             this.poStream = poStream;
         }
 
-        public Iterable<Tuple> call(final Iterator<Tuple> input) {
-            return new Iterable<Tuple>() {
+        @Override
+        public Iterator<Tuple> call(final Iterator<Tuple> input) {
+            return new OutputConsumerIterator(input) {
+
                 @Override
-                public Iterator<Tuple> iterator() {
-                    return new OutputConsumerIterator(input) {
+                protected void attach(Tuple tuple) {
+                    poStream.setInputs(null);
+                    poStream.attachInput(tuple);
+                }
 
-                        @Override
-                        protected void attach(Tuple tuple) {
-                            poStream.setInputs(null);
-                            poStream.attachInput(tuple);
-                        }
+                @Override
+                protected Result getNextResult() throws ExecException {
+                    Result result = poStream.getNextTuple();
+                    return result;
+                }
 
-                        @Override
-                        protected Result getNextResult() throws ExecException {
-                            Result result = poStream.getNextTuple();
-                            return result;
-                        }
-
-                        @Override
-                        protected void endOfInput() {
-                            poStream.setFetchable(true);
-                        }
-                    };
+                @Override
+                protected void endOfInput() {
+                    poStream.setFetchable(true);
                 }
             };
         }
diff --git a/src/org/apache/pig/tools/pigstats/spark/Spark1JobStats.java b/src/org/apache/pig/tools/pigstats/spark/Spark1JobStats.java
new file mode 100644
index 0000000..521ddec
--- /dev/null
+++ b/src/org/apache/pig/tools/pigstats/spark/Spark1JobStats.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.pig.tools.pigstats.spark;
+
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.pig.tools.pigstats.PigStats;
+import org.apache.pig.tools.pigstats.PigStatsUtil;
+import org.apache.spark.executor.ShuffleReadMetrics;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import scala.Option;
+
+import java.util.List;
+import java.util.Map;
+
+public class Spark1JobStats extends SparkJobStats {
+    public Spark1JobStats(int jobId, PigStats.JobGraph plan, Configuration conf) {
+        super(jobId, plan, conf);
+    }
+
+    public Spark1JobStats(String jobId, PigStats.JobGraph plan, Configuration conf) {
+        super(jobId, plan, conf);
+    }
+
+    @Override
+    protected Map<String, Long> combineTaskMetrics(Map<String, List<TaskMetrics>> jobMetric) {
+        Map<String, Long> results = Maps.newLinkedHashMap();
+
+        long executorDeserializeTime = 0;
+        long executorRunTime = 0;
+        long resultSize = 0;
+        long jvmGCTime = 0;
+        long resultSerializationTime = 0;
+        long memoryBytesSpilled = 0;
+        long diskBytesSpilled = 0;
+        long bytesRead = 0;
+        long bytesWritten = 0;
+        long remoteBlocksFetched = 0;
+        long localBlocksFetched = 0;
+        long fetchWaitTime = 0;
+        long remoteBytesRead = 0;
+        long shuffleBytesWritten = 0;
+        long shuffleWriteTime = 0;
+        boolean inputMetricExist = false;
+        boolean outputMetricExist = false;
+        boolean shuffleReadMetricExist = false;
+        boolean shuffleWriteMetricExist = false;
+
+        for (List<TaskMetrics> stageMetric : jobMetric.values()) {
+            if (stageMetric != null) {
+                for (TaskMetrics taskMetrics : stageMetric) {
+                    if (taskMetrics != null) {
+                        executorDeserializeTime += taskMetrics.executorDeserializeTime();
+                        executorRunTime += taskMetrics.executorRunTime();
+                        resultSize += taskMetrics.resultSize();
+                        jvmGCTime += taskMetrics.jvmGCTime();
+                        resultSerializationTime += taskMetrics.resultSerializationTime();
+                        memoryBytesSpilled += taskMetrics.memoryBytesSpilled();
+                        diskBytesSpilled += taskMetrics.diskBytesSpilled();
+                        if (!taskMetrics.inputMetrics().isEmpty()) {
+                            inputMetricExist = true;
+                            bytesRead += taskMetrics.inputMetrics().get().bytesRead();
+                        }
+
+                        if (!taskMetrics.outputMetrics().isEmpty()) {
+                            outputMetricExist = true;
+                            bytesWritten += taskMetrics.outputMetrics().get().bytesWritten();
+                        }
+
+                        Option<ShuffleReadMetrics> shuffleReadMetricsOption = taskMetrics.shuffleReadMetrics();
+                        if (!shuffleReadMetricsOption.isEmpty()) {
+                            shuffleReadMetricExist = true;
+                            remoteBlocksFetched += shuffleReadMetricsOption.get().remoteBlocksFetched();
+                            localBlocksFetched += shuffleReadMetricsOption.get().localBlocksFetched();
+                            fetchWaitTime += shuffleReadMetricsOption.get().fetchWaitTime();
+                            remoteBytesRead += shuffleReadMetricsOption.get().remoteBytesRead();
+                        }
+
+                        Option<ShuffleWriteMetrics> shuffleWriteMetricsOption = taskMetrics.shuffleWriteMetrics();
+                        if (!shuffleWriteMetricsOption.isEmpty()) {
+                            shuffleWriteMetricExist = true;
+                            shuffleBytesWritten += shuffleWriteMetricsOption.get().shuffleBytesWritten();
+                            shuffleWriteTime += shuffleWriteMetricsOption.get().shuffleWriteTime();
+                        }
+
+                    }
+                }
+            }
+        }
+
+        results.put("ExcutorDeserializeTime", executorDeserializeTime);
+        results.put("ExecutorRunTime", executorRunTime);
+        results.put("ResultSize", resultSize);
+        results.put("JvmGCTime", jvmGCTime);
+        results.put("ResultSerializationTime", resultSerializationTime);
+        results.put("MemoryBytesSpilled", memoryBytesSpilled);
+        results.put("DiskBytesSpilled", diskBytesSpilled);
+        if (inputMetricExist) {
+            results.put("BytesRead", bytesRead);
+            hdfsBytesRead = bytesRead;
+            counters.incrCounter(FS_COUNTER_GROUP, PigStatsUtil.HDFS_BYTES_READ, hdfsBytesRead);
+        }
+
+        if (outputMetricExist) {
+            results.put("BytesWritten", bytesWritten);
+            hdfsBytesWritten = bytesWritten;
+            counters.incrCounter(FS_COUNTER_GROUP, PigStatsUtil.HDFS_BYTES_WRITTEN, hdfsBytesWritten);
+        }
+
+        if (shuffleReadMetricExist) {
+            results.put("RemoteBlocksFetched", remoteBlocksFetched);
+            results.put("LocalBlocksFetched", localBlocksFetched);
+            results.put("TotalBlocksFetched", localBlocksFetched + remoteBlocksFetched);
+            results.put("FetchWaitTime", fetchWaitTime);
+            results.put("RemoteBytesRead", remoteBytesRead);
+        }
+
+        if (shuffleWriteMetricExist) {
+            results.put("ShuffleBytesWritten", shuffleBytesWritten);
+            results.put("ShuffleWriteTime", shuffleWriteTime);
+        }
+
+        return results;
+    }
+}
diff --git a/src/org/apache/pig/tools/pigstats/spark/Spark2JobStats.java b/src/org/apache/pig/tools/pigstats/spark/Spark2JobStats.java
new file mode 100644
index 0000000..7c20382
--- /dev/null
+++ b/src/org/apache/pig/tools/pigstats/spark/Spark2JobStats.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.pig.tools.pigstats.spark;
+
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.pig.tools.pigstats.PigStats;
+import org.apache.pig.tools.pigstats.PigStatsUtil;
+import org.apache.spark.executor.ShuffleReadMetrics;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+
+import java.util.List;
+import java.util.Map;
+
+public class Spark2JobStats extends SparkJobStats {
+    public Spark2JobStats(int jobId, PigStats.JobGraph plan, Configuration conf) {
+        super(jobId, plan, conf);
+    }
+
+    public Spark2JobStats(String jobId, PigStats.JobGraph plan, Configuration conf) {
+        super(jobId, plan, conf);
+    }
+
+    @Override
+    protected Map<String, Long> combineTaskMetrics(Map<String, List<TaskMetrics>> jobMetric) {
+        Map<String, Long> results = Maps.newLinkedHashMap();
+
+        long executorDeserializeTime = 0;
+        long executorRunTime = 0;
+        long resultSize = 0;
+        long jvmGCTime = 0;
+        long resultSerializationTime = 0;
+        long memoryBytesSpilled = 0;
+        long diskBytesSpilled = 0;
+        long bytesRead = 0;
+        long bytesWritten = 0;
+        long remoteBlocksFetched = 0;
+        long localBlocksFetched = 0;
+        long fetchWaitTime = 0;
+        long remoteBytesRead = 0;
+        long shuffleBytesWritten = 0;
+        long shuffleWriteTime = 0;
+
+        for (List<TaskMetrics> stageMetric : jobMetric.values()) {
+            if (stageMetric != null) {
+                for (TaskMetrics taskMetrics : stageMetric) {
+                    if (taskMetrics != null) {
+                        executorDeserializeTime += taskMetrics.executorDeserializeTime();
+                        executorRunTime += taskMetrics.executorRunTime();
+                        resultSize += taskMetrics.resultSize();
+                        jvmGCTime += taskMetrics.jvmGCTime();
+                        resultSerializationTime += taskMetrics.resultSerializationTime();
+                        memoryBytesSpilled += taskMetrics.memoryBytesSpilled();
+                        diskBytesSpilled += taskMetrics.diskBytesSpilled();
+                        bytesRead += taskMetrics.inputMetrics().bytesRead();
+
+                        bytesWritten += taskMetrics.outputMetrics().bytesWritten();
+
+                        ShuffleReadMetrics shuffleReadMetrics = taskMetrics.shuffleReadMetrics();
+                        remoteBlocksFetched += shuffleReadMetrics.remoteBlocksFetched();
+                        localBlocksFetched += shuffleReadMetrics.localBlocksFetched();
+                        fetchWaitTime += shuffleReadMetrics.fetchWaitTime();
+                        remoteBytesRead += shuffleReadMetrics.remoteBytesRead();
+
+                        ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
+                        shuffleBytesWritten += shuffleWriteMetrics.shuffleBytesWritten();
+                        shuffleWriteTime += shuffleWriteMetrics.shuffleWriteTime();
+                    }
+                }
+            }
+        }
+
+        results.put("ExcutorDeserializeTime", executorDeserializeTime);
+        results.put("ExecutorRunTime", executorRunTime);
+        results.put("ResultSize", resultSize);
+        results.put("JvmGCTime", jvmGCTime);
+        results.put("ResultSerializationTime", resultSerializationTime);
+        results.put("MemoryBytesSpilled", memoryBytesSpilled);
+        results.put("DiskBytesSpilled", diskBytesSpilled);
+
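+        // Input/output byte totals are always reported under Spark 2, so the stats entries and
+        // HDFS byte counters are recorded unconditionally rather than behind existence flags.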
+        results.put("BytesRead", bytesRead);
+        hdfsBytesRead = bytesRead;
+        counters.incrCounter(FS_COUNTER_GROUP, PigStatsUtil.HDFS_BYTES_READ, hdfsBytesRead);
+
+        results.put("BytesWritten", bytesWritten);
+        hdfsBytesWritten = bytesWritten;
+        counters.incrCounter(FS_COUNTER_GROUP, PigStatsUtil.HDFS_BYTES_WRITTEN, hdfsBytesWritten);
+
+        results.put("RemoteBlocksFetched", remoteBlocksFetched);
+        results.put("LocalBlocksFetched", localBlocksFetched);
+        results.put("TotalBlocksFetched", localBlocksFetched + remoteBlocksFetched);
+        results.put("FetchWaitTime", fetchWaitTime);
+        results.put("RemoteBytesRead", remoteBytesRead);
+
+        results.put("ShuffleBytesWritten", shuffleBytesWritten);
+        results.put("ShuffleWriteTime", shuffleWriteTime);
+
+        return results;
+    }
+}
diff --git a/src/org/apache/pig/tools/pigstats/spark/SparkJobStats.java b/src/org/apache/pig/tools/pigstats/spark/SparkJobStats.java
index c8cc031..6545cbc 100644
--- a/src/org/apache/pig/tools/pigstats/spark/SparkJobStats.java
+++ b/src/org/apache/pig/tools/pigstats/spark/SparkJobStats.java
@@ -21,30 +21,30 @@
 import java.util.List;
 import java.util.Map;
 
-import org.apache.pig.tools.pigstats.*;
-import scala.Option;
-
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.mapred.Counters;
 import org.apache.pig.PigWarning;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POLoad;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POStore;
-import org.apache.pig.backend.hadoop.executionengine.spark.JobMetricsListener;
+import org.apache.pig.backend.hadoop.executionengine.spark.JobStatisticCollector;
 import org.apache.pig.backend.hadoop.executionengine.spark.plan.SparkOperator;
 import org.apache.pig.impl.logicalLayer.FrontendException;
 import org.apache.pig.newplan.PlanVisitor;
-import org.apache.spark.executor.ShuffleReadMetrics;
-import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.pig.tools.pigstats.InputStats;
+import org.apache.pig.tools.pigstats.JobStats;
+import org.apache.pig.tools.pigstats.OutputStats;
+import org.apache.pig.tools.pigstats.PigStats;
+import org.apache.pig.tools.pigstats.PigStatsUtil;
 import org.apache.spark.executor.TaskMetrics;
 
 import com.google.common.collect.Maps;
 
-public class SparkJobStats extends JobStats {
+public abstract class SparkJobStats extends JobStats {
 
     private int jobId;
     private Map<String, Long> stats = Maps.newLinkedHashMap();
     private boolean disableCounter;
-    private Counters counters = null;
+    protected Counters counters = null;
     public static String FS_COUNTER_GROUP = "FS_GROUP";
     private Map<String, SparkCounter<Map<String, Long>>> warningCounters = null;
 
@@ -58,6 +58,7 @@
         setConf(conf);
     }
 
+    @Override
     public void setConf(Configuration conf) {
         super.setConf(conf);
         disableCounter = conf.getBoolean("pig.disable.counter", false);
@@ -65,7 +66,7 @@
     }
 
     public void addOutputInfo(POStore poStore, boolean success,
-                              JobMetricsListener jobMetricsListener) {
+                              JobStatisticCollector jobStatisticCollector) {
         if (!poStore.isTmpStore()) {
             long bytes = getOutputSize(poStore, conf);
             long recordsCount = -1;
@@ -99,9 +100,9 @@
         inputs.add(inputStats);
     }
 
-    public void collectStats(JobMetricsListener jobMetricsListener) {
-        if (jobMetricsListener != null) {
-            Map<String, List<TaskMetrics>> taskMetrics = jobMetricsListener.getJobMetric(jobId);
+    public void collectStats(JobStatisticCollector jobStatisticCollector) {
+        if (jobStatisticCollector != null) {
+            Map<String, List<TaskMetrics>> taskMetrics = jobStatisticCollector.getJobMetric(jobId);
             if (taskMetrics == null) {
                 throw new RuntimeException("No task metrics available for jobId " + jobId);
             }
@@ -109,110 +110,12 @@
         }
     }
 
+    protected abstract Map<String, Long> combineTaskMetrics(Map<String, List<TaskMetrics>> jobMetric);
+
     public Map<String, Long> getStats() {
         return stats;
     }
 
-    private Map<String, Long> combineTaskMetrics(Map<String, List<TaskMetrics>> jobMetric) {
-        Map<String, Long> results = Maps.newLinkedHashMap();
-
-        long executorDeserializeTime = 0;
-        long executorRunTime = 0;
-        long resultSize = 0;
-        long jvmGCTime = 0;
-        long resultSerializationTime = 0;
-        long memoryBytesSpilled = 0;
-        long diskBytesSpilled = 0;
-        long bytesRead = 0;
-        long bytesWritten = 0;
-        long remoteBlocksFetched = 0;
-        long localBlocksFetched = 0;
-        long fetchWaitTime = 0;
-        long remoteBytesRead = 0;
-        long shuffleBytesWritten = 0;
-        long shuffleWriteTime = 0;
-        boolean inputMetricExist = false;
-        boolean outputMetricExist = false;
-        boolean shuffleReadMetricExist = false;
-        boolean shuffleWriteMetricExist = false;
-
-        for (List<TaskMetrics> stageMetric : jobMetric.values()) {
-            if (stageMetric != null) {
-                for (TaskMetrics taskMetrics : stageMetric) {
-                    if (taskMetrics != null) {
-                        executorDeserializeTime += taskMetrics.executorDeserializeTime();
-                        executorRunTime += taskMetrics.executorRunTime();
-                        resultSize += taskMetrics.resultSize();
-                        jvmGCTime += taskMetrics.jvmGCTime();
-                        resultSerializationTime += taskMetrics.resultSerializationTime();
-                        memoryBytesSpilled += taskMetrics.memoryBytesSpilled();
-                        diskBytesSpilled += taskMetrics.diskBytesSpilled();
-                        if (!taskMetrics.inputMetrics().isEmpty()) {
-                            inputMetricExist = true;
-                            bytesRead += taskMetrics.inputMetrics().get().bytesRead();
-                        }
-
-                        if (!taskMetrics.outputMetrics().isEmpty()) {
-                            outputMetricExist = true;
-                            bytesWritten += taskMetrics.outputMetrics().get().bytesWritten();
-                        }
-
-                        Option<ShuffleReadMetrics> shuffleReadMetricsOption = taskMetrics.shuffleReadMetrics();
-                        if (!shuffleReadMetricsOption.isEmpty()) {
-                            shuffleReadMetricExist = true;
-                            remoteBlocksFetched += shuffleReadMetricsOption.get().remoteBlocksFetched();
-                            localBlocksFetched += shuffleReadMetricsOption.get().localBlocksFetched();
-                            fetchWaitTime += shuffleReadMetricsOption.get().fetchWaitTime();
-                            remoteBytesRead += shuffleReadMetricsOption.get().remoteBytesRead();
-                        }
-
-                        Option<ShuffleWriteMetrics> shuffleWriteMetricsOption = taskMetrics.shuffleWriteMetrics();
-                        if (!shuffleWriteMetricsOption.isEmpty()) {
-                            shuffleWriteMetricExist = true;
-                            shuffleBytesWritten += shuffleWriteMetricsOption.get().shuffleBytesWritten();
-                            shuffleWriteTime += shuffleWriteMetricsOption.get().shuffleWriteTime();
-                        }
-
-                    }
-                }
-            }
-        }
-
-        results.put("EexcutorDeserializeTime", executorDeserializeTime);
-        results.put("ExecutorRunTime", executorRunTime);
-        results.put("ResultSize", resultSize);
-        results.put("JvmGCTime", jvmGCTime);
-        results.put("ResultSerializationTime", resultSerializationTime);
-        results.put("MemoryBytesSpilled", memoryBytesSpilled);
-        results.put("DiskBytesSpilled", diskBytesSpilled);
-        if (inputMetricExist) {
-            results.put("BytesRead", bytesRead);
-            hdfsBytesRead = bytesRead;
-            counters.incrCounter(FS_COUNTER_GROUP, PigStatsUtil.HDFS_BYTES_READ, hdfsBytesRead);
-        }
-
-        if (outputMetricExist) {
-            results.put("BytesWritten", bytesWritten);
-            hdfsBytesWritten = bytesWritten;
-            counters.incrCounter(FS_COUNTER_GROUP, PigStatsUtil.HDFS_BYTES_WRITTEN, hdfsBytesWritten);
-        }
-
-        if (shuffleReadMetricExist) {
-            results.put("RemoteBlocksFetched", remoteBlocksFetched);
-            results.put("LocalBlocksFetched", localBlocksFetched);
-            results.put("TotalBlocksFetched", localBlocksFetched + remoteBlocksFetched);
-            results.put("FetchWaitTime", fetchWaitTime);
-            results.put("RemoteBytesRead", remoteBytesRead);
-        }
-
-        if (shuffleWriteMetricExist) {
-            results.put("ShuffleBytesWritten", shuffleBytesWritten);
-            results.put("ShuffleWriteTime", shuffleWriteTime);
-        }
-
-        return results;
-    }
-
     @Override
     public String getJobId() {
         return String.valueOf(jobId);
diff --git a/src/org/apache/pig/tools/pigstats/spark/SparkPigStats.java b/src/org/apache/pig/tools/pigstats/spark/SparkPigStats.java
index 61ccbcc..bd864ed 100644
--- a/src/org/apache/pig/tools/pigstats/spark/SparkPigStats.java
+++ b/src/org/apache/pig/tools/pigstats/spark/SparkPigStats.java
@@ -32,7 +32,8 @@
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POLoad;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POStore;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.util.PlanHelper;
-import org.apache.pig.backend.hadoop.executionengine.spark.JobMetricsListener;
+import org.apache.pig.backend.hadoop.executionengine.spark.JobStatisticCollector;
+import org.apache.pig.backend.hadoop.executionengine.spark.SparkShims;
 import org.apache.pig.backend.hadoop.executionengine.spark.operator.NativeSparkOperator;
 import org.apache.pig.backend.hadoop.executionengine.spark.plan.SparkOperPlan;
 import org.apache.pig.backend.hadoop.executionengine.spark.plan.SparkOperator;
@@ -69,14 +70,14 @@
     }
 
     public void addJobStats(POStore poStore, SparkOperator sparkOperator, int jobId,
-                            JobMetricsListener jobMetricsListener,
+                            JobStatisticCollector jobStatisticCollector,
                             JavaSparkContext sparkContext) {
         boolean isSuccess = SparkStatsUtil.isJobSuccess(jobId, sparkContext);
-        SparkJobStats jobStats = new SparkJobStats(jobId, jobPlan, conf);
+        SparkJobStats jobStats = SparkShims.getInstance().sparkJobStats(jobId, jobPlan, conf);
         jobStats.setSuccessful(isSuccess);
-        jobStats.collectStats(jobMetricsListener);
-        jobStats.addOutputInfo(poStore, isSuccess, jobMetricsListener);
-        addInputInfoForSparkOper(sparkOperator, jobStats, isSuccess, jobMetricsListener, conf);
+        jobStats.collectStats(jobStatisticCollector);
+        jobStats.addOutputInfo(poStore, isSuccess, jobStatisticCollector);
+        addInputInfoForSparkOper(sparkOperator, jobStats, isSuccess, jobStatisticCollector, conf);
         jobStats.initWarningCounters();
         jobSparkOperatorMap.put(jobStats, sparkOperator);
 
@@ -85,22 +86,22 @@
 
 
     public void addFailJobStats(POStore poStore, SparkOperator sparkOperator, String jobId,
-                                JobMetricsListener jobMetricsListener,
+                                JobStatisticCollector jobStatisticCollector,
                                 JavaSparkContext sparkContext,
                                 Exception e) {
         boolean isSuccess = false;
-        SparkJobStats jobStats = new SparkJobStats(jobId, jobPlan, conf);
+        SparkJobStats jobStats = SparkShims.getInstance().sparkJobStats(jobId, jobPlan, conf);
         jobStats.setSuccessful(isSuccess);
-        jobStats.collectStats(jobMetricsListener);
-        jobStats.addOutputInfo(poStore, isSuccess, jobMetricsListener);
-        addInputInfoForSparkOper(sparkOperator, jobStats, isSuccess, jobMetricsListener, conf);
+        jobStats.collectStats(jobStatisticCollector);
+        jobStats.addOutputInfo(poStore, isSuccess, jobStatisticCollector);
+        addInputInfoForSparkOper(sparkOperator, jobStats, isSuccess, jobStatisticCollector, conf);
         jobSparkOperatorMap.put(jobStats, sparkOperator);
         jobPlan.add(jobStats);
         jobStats.setBackendException(e);
     }
 
     public void addNativeJobStats(NativeSparkOperator sparkOperator, String jobId, boolean isSuccess, Exception e) {
-        SparkJobStats jobStats = new SparkJobStats(jobId, jobPlan, conf);
+        SparkJobStats jobStats = SparkShims.getInstance().sparkJobStats(jobId, jobPlan, conf);
         jobStats.setSuccessful(isSuccess);
         jobSparkOperatorMap.put(jobStats, sparkOperator);
         jobPlan.add(jobStats);
@@ -229,7 +230,7 @@
     private void addInputInfoForSparkOper(SparkOperator sparkOperator,
                                           SparkJobStats jobStats,
                                           boolean isSuccess,
-                                          JobMetricsListener jobMetricsListener,
+                                          JobStatisticCollector jobStatisticCollector,
                                           Configuration conf) {
         //to avoid repetition
         if (sparkOperatorsSet.contains(sparkOperator)) {
diff --git a/src/org/apache/pig/tools/pigstats/spark/SparkStatsUtil.java b/src/org/apache/pig/tools/pigstats/spark/SparkStatsUtil.java
index 1541264..12aae3e 100644
--- a/src/org/apache/pig/tools/pigstats/spark/SparkStatsUtil.java
+++ b/src/org/apache/pig/tools/pigstats/spark/SparkStatsUtil.java
@@ -26,7 +26,7 @@
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POSplit;
 import org.apache.pig.backend.hadoop.executionengine.physicalLayer.relationalOperators.POStore;
 import org.apache.pig.backend.hadoop.executionengine.spark.JobGraphBuilder;
-import org.apache.pig.backend.hadoop.executionengine.spark.JobMetricsListener;
+import org.apache.pig.backend.hadoop.executionengine.spark.JobStatisticCollector;
 import org.apache.pig.backend.hadoop.executionengine.spark.operator.NativeSparkOperator;
 import org.apache.pig.backend.hadoop.executionengine.spark.plan.SparkOperator;
 import org.apache.pig.tools.pigstats.PigStatsUtil;
@@ -44,7 +44,7 @@
 
     public static void waitForJobAddStats(int jobID,
                                           POStore poStore, SparkOperator sparkOperator,
-                                          JobMetricsListener jobMetricsListener,
+                                          JobStatisticCollector jobStatisticCollector,
                                           JavaSparkContext sparkContext,
                                           SparkPigStats sparkPigStats)
             throws InterruptedException {
@@ -55,20 +55,17 @@
         // "event bus" thread updating it's internal listener and
         // this driver thread calling SparkStatusTracker.
         // To workaround this, we will wait for this job to "finish".
-        jobMetricsListener.waitForJobToEnd(jobID);
-        sparkPigStats.addJobStats(poStore, sparkOperator, jobID, jobMetricsListener,
+        jobStatisticCollector.waitForJobToEnd(jobID);
+        sparkPigStats.addJobStats(poStore, sparkOperator, jobID, jobStatisticCollector,
                 sparkContext);
-        jobMetricsListener.cleanup(jobID);
+        jobStatisticCollector.cleanup(jobID);
     }
 
     public static void addFailJobStats(String jobID,
                                        POStore poStore, SparkOperator sparkOperator,
                                        SparkPigStats sparkPigStats,
                                        Exception e) {
-        JobMetricsListener jobMetricsListener = null;
-        JavaSparkContext sparkContext = null;
-        sparkPigStats.addFailJobStats(poStore, sparkOperator, jobID, jobMetricsListener,
-                sparkContext, e);
+        sparkPigStats.addFailJobStats(poStore, sparkOperator, jobID, null, null, e);
     }
 
     public static String getCounterName(POStore store) {
diff --git a/test/org/apache/pig/test/TestPigRunner.java b/test/org/apache/pig/test/TestPigRunner.java
index ec08417..ac106e0 100644
--- a/test/org/apache/pig/test/TestPigRunner.java
+++ b/test/org/apache/pig/test/TestPigRunner.java
@@ -60,7 +60,6 @@
 import org.apache.pig.tools.pigstats.PigStatsUtil;
 import org.apache.pig.tools.pigstats.mapreduce.MRJobStats;
 import org.apache.pig.tools.pigstats.mapreduce.MRPigStatsUtil;
-import org.apache.pig.tools.pigstats.spark.SparkJobStats;
 import org.junit.AfterClass;
 import org.junit.Assume;
 import org.junit.Before;