[MINOR] improvement(test): A better computation logic for WriteAndReadMetricsTest without using reflection (#1563)
### What changes were proposed in this pull request?
Use a better computation logic for WriteAndReadMetricsTest without using reflection.
### Why are the changes needed?
No need to use reflection, which will be quite confusing sometimes.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing UTs.
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java
index c7b014d..ec03ddb 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java
@@ -17,20 +17,17 @@
package org.apache.uniffle.test;
-import java.lang.reflect.InvocationTargetException;
-import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import scala.collection.Seq;
-
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.scheduler.SparkListener;
+import org.apache.spark.scheduler.SparkListenerTaskEnd;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
-import org.apache.spark.status.AppStatusStore;
-import org.apache.spark.status.api.v1.StageData;
import org.junit.jupiter.api.Test;
public class WriteAndReadMetricsTest extends SimpleTestBase {
@@ -42,6 +39,10 @@
@Override
public Map<String, Long> runTest(SparkSession spark, String fileName) throws Exception {
+ // Instantiate WriteAndReadMetricsSparkListener and add it to SparkContext
+ WriteAndReadMetricsSparkListener listener = new WriteAndReadMetricsSparkListener();
+ spark.sparkContext().addSparkListener(listener);
+
// take a rest to make sure shuffle server is registered
Thread.sleep(3000);
@@ -63,8 +64,8 @@
// take a rest to make sure all task metrics are updated before read stageData
Thread.sleep(100);
for (int stageId : spark.sparkContext().statusTracker().getJobInfo(0).get().stageIds()) {
- long writeRecords = getFirstStageData(spark, stageId).shuffleWriteRecords();
- long readRecords = getFirstStageData(spark, stageId).shuffleReadRecords();
+ long writeRecords = listener.getWriteRecords(stageId);
+ long readRecords = listener.getReadRecords(stageId);
result.put(stageId + "-write-records", writeRecords);
result.put(stageId + "-read-records", readRecords);
}
@@ -72,31 +73,31 @@
return result;
}
- private StageData getFirstStageData(SparkSession spark, int stageId)
- throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
- AppStatusStore statestore = spark.sparkContext().statusStore();
- try {
- return ((Seq<StageData>)
- statestore
- .getClass()
- .getDeclaredMethod("stageData", int.class, boolean.class)
- .invoke(statestore, stageId, false))
- .toList()
- .head();
- } catch (Exception e) {
- return ((Seq<StageData>)
- statestore
- .getClass()
- .getDeclaredMethod(
- "stageData",
- int.class,
- boolean.class,
- List.class,
- boolean.class,
- double[].class)
- .invoke(statestore, stageId, false, new ArrayList<>(), true, new double[] {}))
- .toList()
- .head();
+ private static class WriteAndReadMetricsSparkListener extends SparkListener {
+ private HashMap<Integer, Long> stageIdToWriteRecords = new HashMap<>();
+ private HashMap<Integer, Long> stageIdToReadRecords = new HashMap<>();
+
+ @Override
+ public void onTaskEnd(SparkListenerTaskEnd event) {
+ int stageId = event.stageId();
+ TaskMetrics taskMetrics = event.taskMetrics();
+ if (taskMetrics != null) {
+ long writeRecords = taskMetrics.shuffleWriteMetrics().recordsWritten();
+ long readRecords = taskMetrics.shuffleReadMetrics().recordsRead();
+ // Accumulate writeRecords and readRecords for the given stageId
+ stageIdToWriteRecords.put(
+ stageId, stageIdToWriteRecords.getOrDefault(stageId, 0L) + writeRecords);
+ stageIdToReadRecords.put(
+ stageId, stageIdToReadRecords.getOrDefault(stageId, 0L) + readRecords);
+ }
+ }
+
+ public long getWriteRecords(int stageId) {
+ return stageIdToWriteRecords.getOrDefault(stageId, 0L);
+ }
+
+ public long getReadRecords(int stageId) {
+ return stageIdToReadRecords.getOrDefault(stageId, 0L);
}
}
}