[#1751][0.9] improvement: support gluten (#1753)

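* support Gluten: RssShuffleWriter now resolves its ShuffleHandleInfo via the
  new RssShuffleManager#getShuffleHandleInfo instead of taking it as a
  constructor argument, and the spark3 writer exposes isMemoryShuffleEnabled,
  writeImpl and internalCheckBlockSendResult as protected for Gluten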

* optimize

* fix bug

* nit

* fix spotless

* nit

* nit

* fix bug

* optimize

* optimize

* nit

* nit

* nit

* nit

* nit

* Update RssShuffleWriter.java
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 78bcc2c..45d338e 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -475,15 +475,6 @@
 
       int shuffleId = rssHandle.getShuffleId();
       String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
-      ShuffleHandleInfo shuffleHandleInfo;
-      if (shuffleManagerRpcServiceEnabled) {
-        // Get the ShuffleServer list from the Driver based on the shuffleId
-        shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
-      } else {
-        shuffleHandleInfo =
-            new ShuffleHandleInfo(
-                shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
-      }
       ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
       return new RssShuffleWriter<>(
           rssHandle.getAppId(),
@@ -496,8 +487,7 @@
           shuffleWriteClient,
           rssHandle,
           this::markFailedTask,
-          context,
-          shuffleHandleInfo);
+          context);
     } else {
       throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
     }
@@ -806,6 +796,27 @@
         .createShuffleManagerClient(ClientType.GRPC, host, port);
   }
 
+  /**
+   * Resolves the {@link ShuffleHandleInfo} of the given shuffle handle.
+   *
+   * <p>When the shuffle-manager RPC service is enabled the info is fetched from the Driver via
+   * {@code getRemoteShuffleHandleInfo}, so callers should resolve it once and reuse the result.
+   * Otherwise it is built locally from the handle's partition-to-server map and remote storage.
+   *
+   * @param rssHandle the shuffle handle to resolve
+   * @return the shuffle handle info for the handle's shuffle id
+   */
+  public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
+    int shuffleId = rssHandle.getShuffleId();
+    if (!shuffleManagerRpcServiceEnabled) {
+      // RPC service disabled: build the info directly from the handle.
+      return new ShuffleHandleInfo(
+          shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
+    }
+    // Get the ShuffleServer list from the Driver based on the shuffleId
+    return getRemoteShuffleHandleInfo(shuffleId);
+  }
+
   /**
    * Get the ShuffleServer list from the Driver based on the shuffleId
    *
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 9e64b2f..37576c1 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -188,8 +188,7 @@
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
-      TaskContext context,
-      ShuffleHandleInfo shuffleHandleInfo) {
+      TaskContext context) {
     this(
         appId,
         shuffleId,
@@ -201,9 +200,10 @@
         shuffleWriteClient,
         rssHandle,
         taskFailureCallback,
-        shuffleHandleInfo,
+        shuffleManager.getShuffleHandleInfo(rssHandle),
         context);
     BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
+    ShuffleHandleInfo shuffleHandleInfo = shuffleManager.getShuffleHandleInfo(rssHandle);
     final WriteBufferManager bufferManager =
         new WriteBufferManager(
             shuffleId,
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 6d9487c..700b769 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -141,7 +141,6 @@
   private boolean rssResubmitStage;
 
   private boolean taskBlockSendFailureRetryEnabled;
-
   private boolean shuffleManagerRpcServiceEnabled;
   /** A list of shuffleServer for Write failures */
   private Set<String> failuresShuffleServerIds;
@@ -514,15 +513,6 @@
     } else {
       writeMetrics = context.taskMetrics().shuffleWriteMetrics();
     }
-    ShuffleHandleInfo shuffleHandleInfo;
-    if (shuffleManagerRpcServiceEnabled) {
-      // Get the ShuffleServer list from the Driver based on the shuffleId
-      shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
-    } else {
-      shuffleHandleInfo =
-          new ShuffleHandleInfo(
-              shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
-    }
     String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
     LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId());
     return new RssShuffleWriter<>(
@@ -536,8 +526,7 @@
         shuffleWriteClient,
         rssHandle,
         this::markFailedTask,
-        context,
-        shuffleHandleInfo);
+        context);
   }
 
   @Override
@@ -656,17 +645,7 @@
     RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>) handle;
     final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
     int shuffleId = rssShuffleHandle.getShuffleId();
-    ShuffleHandleInfo shuffleHandleInfo;
-    if (shuffleManagerRpcServiceEnabled) {
-      // Get the ShuffleServer list from the Driver based on the shuffleId
-      shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
-    } else {
-      shuffleHandleInfo =
-          new ShuffleHandleInfo(
-              shuffleId,
-              rssShuffleHandle.getPartitionToServers(),
-              rssShuffleHandle.getRemoteStorage());
-    }
+    ShuffleHandleInfo shuffleHandleInfo = getShuffleHandleInfo(rssShuffleHandle);
     Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
         shuffleHandleInfo.getPartitionToServers();
     Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
@@ -1101,6 +1080,27 @@
         .createShuffleManagerClient(ClientType.GRPC, host, port);
   }
 
+  /**
+   * Resolves the {@link ShuffleHandleInfo} of the given shuffle handle.
+   *
+   * <p>When the shuffle-manager RPC service is enabled the info is fetched from the Driver via
+   * {@code getRemoteShuffleHandleInfo}, so callers should resolve it once and reuse the result.
+   * Otherwise it is built locally from the handle's partition-to-server map and remote storage.
+   *
+   * @param rssHandle the shuffle handle to resolve
+   * @return the shuffle handle info for the handle's shuffle id
+   */
+  public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
+    int shuffleId = rssHandle.getShuffleId();
+    if (!shuffleManagerRpcServiceEnabled) {
+      // RPC service disabled: build the info directly from the handle.
+      return new ShuffleHandleInfo(
+          shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
+    }
+    // Get the ShuffleServer list from the Driver based on the shuffleId
+    return getRemoteShuffleHandleInfo(shuffleId);
+  }
+
   /**
    * Get the ShuffleServer list from the Driver based on the shuffleId
    *
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 8a22b73..70ae3d8 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -95,6 +95,7 @@
 
   private final String appId;
   private final int shuffleId;
+  private final ShuffleHandleInfo shuffleHandleInfo;
   private WriteBufferManager bufferManager;
   private String taskId;
   private final int numMaps;
@@ -110,7 +111,8 @@
   private final ShuffleWriteClient shuffleWriteClient;
   private final Set<ShuffleServerInfo> shuffleServersForData;
   private final long[] partitionLengths;
-  private final boolean isMemoryShuffleEnabled;
+  // Gluten needs this variable
+  protected final boolean isMemoryShuffleEnabled;
   private final Function<String, Boolean> taskFailureCallback;
   private final Set<Long> blockIds = Sets.newConcurrentHashSet();
   private TaskContext taskContext;
@@ -195,6 +197,7 @@
     this.isMemoryShuffleEnabled =
         isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
     this.taskFailureCallback = taskFailureCallback;
+    this.shuffleHandleInfo = shuffleHandleInfo;
     this.taskContext = context;
     this.sparkConf = sparkConf;
     this.blockFailSentRetryEnabled =
@@ -204,6 +207,7 @@
             RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.defaultValue());
   }
 
+  // Gluten needs this constructor
   public RssShuffleWriter(
       String appId,
       int shuffleId,
@@ -215,8 +219,7 @@
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
-      TaskContext context,
-      ShuffleHandleInfo shuffleHandleInfo) {
+      TaskContext context) {
     this(
         appId,
         shuffleId,
@@ -228,7 +231,7 @@
         shuffleWriteClient,
         rssHandle,
         taskFailureCallback,
-        shuffleHandleInfo,
+        shuffleManager.getShuffleHandleInfo(rssHandle),
         context);
     BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
     final WriteBufferManager bufferManager =
@@ -264,7 +267,8 @@
     }
   }
 
-  private void writeImpl(Iterator<Product2<K, V>> records) {
+  // Gluten needs this method.
+  protected void writeImpl(Iterator<Product2<K, V>> records) {
     List<ShuffleBlockInfo> shuffleBlockInfos;
     boolean isCombine = shuffleDependency.mapSideCombine();
     Function1<V, C> createCombiner = null;
@@ -322,6 +326,16 @@
             + bufferManager.getManagerCostInfo());
   }
 
+  /**
+   * Runs {@code checkBlockSendResult} over every block id this writer has tracked.
+   *
+   * <p>Exposed as a protected hook so that subclasses (such as the Gluten writer) can run the check
+   * without access to the private block-id set.
+   */
+  protected void internalCheckBlockSendResult() {
+    checkBlockSendResult(blockIds);
+  }
+
   private void checkSentRecordCount(long recordCount) {
     if (recordCount != bufferManager.getRecordCount()) {
       String errorMsg =