[#1751][0.9] improvement: support gluten (#1753)
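* support gluten: RssShuffleWriter's public constructor no longer takes a ShuffleHandleInfo; it resolves one via RssShuffleManager#getShuffleHandleInfo(rssHandle). In the Spark 3 writer, isMemoryShuffleEnabled and writeImpl become protected and internalCheckBlockSendResult is added, so Gluten can reuse them.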
* optimize
* fix bug
* nit
* fix spotless
* nit
* nit
* fix bug
* optimize
* optimize
* nit
* nit
* nit
* nit
* nit
* Update RssShuffleWriter.java
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 78bcc2c..45d338e 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -475,15 +475,6 @@
int shuffleId = rssHandle.getShuffleId();
String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
- ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled) {
- // Get the ShuffleServer list from the Driver based on the shuffleId
- shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
- } else {
- shuffleHandleInfo =
- new ShuffleHandleInfo(
- shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
- }
ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
return new RssShuffleWriter<>(
rssHandle.getAppId(),
@@ -496,8 +487,7 @@
shuffleWriteClient,
rssHandle,
this::markFailedTask,
- context,
- shuffleHandleInfo);
+ context);
} else {
throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
}
@@ -806,6 +796,18 @@
.createShuffleManagerClient(ClientType.GRPC, host, port);
}
+ /**
+ * Resolves the {@link ShuffleHandleInfo} (partition-to-server assignment plus remote storage) for
+ * the given shuffle handle.
+ *
+ * <p>When {@code shuffleManagerRpcServiceEnabled} is set, the info is fetched from the driver via
+ * {@code getRemoteShuffleHandleInfo}. Otherwise it is built locally from the partition servers and
+ * remote storage carried by {@code rssHandle}.
+ *
+ * <p>NOTE(review): the driver path presumably costs a remote round trip on every call (TODO confirm
+ * there is no caching). Callers should fetch the info once and reuse it. Calling this repeatedly
+ * for the same writer duplicates driver requests and may observe different assignments.
+ *
+ * @param rssHandle the RSS shuffle handle whose shuffle id, partition servers and remote storage
+ *     are used
+ * @return the shuffle handle info for {@code rssHandle.getShuffleId()}
+ */
+ public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
+ if (shuffleManagerRpcServiceEnabled) {
+ // Get the ShuffleServer list from the Driver based on the shuffleId
+ return getRemoteShuffleHandleInfo(rssHandle.getShuffleId());
+ } else {
+ // No RPC service: rebuild the info from what the handle already carries.
+ return new ShuffleHandleInfo(
+ rssHandle.getShuffleId(),
+ rssHandle.getPartitionToServers(),
+ rssHandle.getRemoteStorage());
+ }
+ }
+
/**
* Get the ShuffleServer list from the Driver based on the shuffleId
*
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 9e64b2f..37576c1 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -188,8 +188,7 @@
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
- TaskContext context,
- ShuffleHandleInfo shuffleHandleInfo) {
+ TaskContext context) {
this(
appId,
shuffleId,
@@ -201,9 +200,10 @@
shuffleWriteClient,
rssHandle,
taskFailureCallback,
- shuffleHandleInfo,
+ shuffleManager.getShuffleHandleInfo(rssHandle),
context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
+ ShuffleHandleInfo shuffleHandleInfo = shuffleManager.getShuffleHandleInfo(rssHandle);
final WriteBufferManager bufferManager =
new WriteBufferManager(
shuffleId,
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 6d9487c..700b769 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -141,7 +141,6 @@
private boolean rssResubmitStage;
private boolean taskBlockSendFailureRetryEnabled;
-
private boolean shuffleManagerRpcServiceEnabled;
/** A list of shuffleServer for Write failures */
private Set<String> failuresShuffleServerIds;
@@ -514,15 +513,6 @@
} else {
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
}
- ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled) {
- // Get the ShuffleServer list from the Driver based on the shuffleId
- shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
- } else {
- shuffleHandleInfo =
- new ShuffleHandleInfo(
- shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
- }
String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId());
return new RssShuffleWriter<>(
@@ -536,8 +526,7 @@
shuffleWriteClient,
rssHandle,
this::markFailedTask,
- context,
- shuffleHandleInfo);
+ context);
}
@Override
@@ -656,17 +645,7 @@
RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>) handle;
final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
- ShuffleHandleInfo shuffleHandleInfo;
- if (shuffleManagerRpcServiceEnabled) {
- // Get the ShuffleServer list from the Driver based on the shuffleId
- shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
- } else {
- shuffleHandleInfo =
- new ShuffleHandleInfo(
- shuffleId,
- rssShuffleHandle.getPartitionToServers(),
- rssShuffleHandle.getRemoteStorage());
- }
+ ShuffleHandleInfo shuffleHandleInfo = getShuffleHandleInfo(rssShuffleHandle);
Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
shuffleHandleInfo.getPartitionToServers();
Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
@@ -1101,6 +1080,18 @@
.createShuffleManagerClient(ClientType.GRPC, host, port);
}
+ /**
+ * Resolves the {@link ShuffleHandleInfo} (partition-to-server assignment plus remote storage) for
+ * the given shuffle handle. Used both when creating writers and when computing reader partition
+ * servers.
+ *
+ * <p>When {@code shuffleManagerRpcServiceEnabled} is set, the info is fetched from the driver via
+ * {@code getRemoteShuffleHandleInfo}. Otherwise it is built locally from the partition servers and
+ * remote storage carried by {@code rssHandle}.
+ *
+ * <p>NOTE(review): the driver path presumably costs a remote round trip on every call (TODO confirm
+ * there is no caching). Callers should fetch the info once and reuse it.
+ *
+ * @param rssHandle the RSS shuffle handle whose shuffle id, partition servers and remote storage
+ *     are used
+ * @return the shuffle handle info for {@code rssHandle.getShuffleId()}
+ */
+ public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
+ if (shuffleManagerRpcServiceEnabled) {
+ // Get the ShuffleServer list from the Driver based on the shuffleId
+ return getRemoteShuffleHandleInfo(rssHandle.getShuffleId());
+ } else {
+ // No RPC service: rebuild the info from what the handle already carries.
+ return new ShuffleHandleInfo(
+ rssHandle.getShuffleId(),
+ rssHandle.getPartitionToServers(),
+ rssHandle.getRemoteStorage());
+ }
+ }
+
/**
* Get the ShuffleServer list from the Driver based on the shuffleId
*
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 8a22b73..70ae3d8 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -95,6 +95,7 @@
private final String appId;
private final int shuffleId;
+ private final ShuffleHandleInfo shuffleHandleInfo;
private WriteBufferManager bufferManager;
private String taskId;
private final int numMaps;
@@ -110,7 +111,8 @@
private final ShuffleWriteClient shuffleWriteClient;
private final Set<ShuffleServerInfo> shuffleServersForData;
private final long[] partitionLengths;
- private final boolean isMemoryShuffleEnabled;
+ /**
+ * Result of {@code isMemoryShuffleEnabled(...)} applied to the configured
+ * {@code RssSparkConfig.RSS_STORAGE_TYPE} value, computed once in the constructor. Widened from
+ * private to protected because Gluten needs to read it.
+ */
+ protected final boolean isMemoryShuffleEnabled;
private final Function<String, Boolean> taskFailureCallback;
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
private TaskContext taskContext;
@@ -195,6 +197,7 @@
this.isMemoryShuffleEnabled =
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
this.taskFailureCallback = taskFailureCallback;
+ this.shuffleHandleInfo = shuffleHandleInfo;
this.taskContext = context;
this.sparkConf = sparkConf;
this.blockFailSentRetryEnabled =
@@ -204,6 +207,7 @@
RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.defaultValue());
}
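+ // Public constructor used by RssShuffleManager and by Gluten. The ShuffleHandleInfo is no longer
+ // passed in; it is resolved internally via shuffleManager.getShuffleHandleInfo(rssHandle).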
public RssShuffleWriter(
String appId,
int shuffleId,
@@ -215,8 +219,7 @@
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
- TaskContext context,
- ShuffleHandleInfo shuffleHandleInfo) {
+ TaskContext context) {
this(
appId,
shuffleId,
@@ -228,7 +231,7 @@
shuffleWriteClient,
rssHandle,
taskFailureCallback,
- shuffleHandleInfo,
+ shuffleManager.getShuffleHandleInfo(rssHandle),
context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
final WriteBufferManager bufferManager =
@@ -264,7 +267,8 @@
}
}
- private void writeImpl(Iterator<Product2<K, V>> records) {
+ // Widened from private to protected for use by Gluten (e.g. from a writer subclass).
+ protected void writeImpl(Iterator<Product2<K, V>> records) {
List<ShuffleBlockInfo> shuffleBlockInfos;
boolean isCombine = shuffleDependency.mapSideCombine();
Function1<V, C> createCombiner = null;
@@ -322,6 +326,11 @@
+ bufferManager.getManagerCostInfo());
}
+ /**
+ * Runs {@code checkBlockSendResult} over every block id this writer has tracked in
+ * {@code blockIds}. This is a protected, argument-free entry point added because Gluten needs to
+ * trigger the check itself.
+ */
+ protected void internalCheckBlockSendResult() {
+ this.checkBlockSendResult(this.blockIds);
+ }
+
private void checkSentRecordCount(long recordCount) {
if (recordCount != bufferManager.getRecordCount()) {
String errorMsg =