[SPARK-49478][CONNECT] Handle null metrics in ConnectProgressExecutionListener
### What changes were proposed in this pull request?
Handling null `TaskMetrics` in `ConnectProgressExecutionListener` by reporting 0 `inputBytesRead` on null.
### Why are the changes needed?
On task end, `TaskMetrics` may be `null`, as in the case of task failure (see [here](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala#L83)). This can cause NPEs for failed tasks with null metrics.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added a new test for task done with `null` metrics.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47944 from davintjong-db/connect-progress-listener-null-metrics.
Lead-authored-by: Davin Tjong <davin.tjong@databricks.com>
Co-authored-by: Davin Tjong <107501978+davintjong-db@users.noreply.github.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
index a188176..72c77fd 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
@@ -144,7 +144,9 @@
tracker.stages.get(taskEnd.stageId).foreach { stage =>
stage.update { i =>
i.completedTasks += 1
- i.inputBytesRead += taskEnd.taskMetrics.inputMetrics.bytesRead
+ i.inputBytesRead += Option(taskEnd.taskMetrics)
+ .map(_.inputMetrics.bytesRead)
+ .getOrElse(0L)
}
}
// This should never become negative, simply reset to zero if it does.
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala
index 7c1b936..df5df23 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala
@@ -79,11 +79,16 @@
}
}
- test("taskDone") {
+ def testTaskDone(metricsPopulated: Boolean): Unit = {
val listener = new ConnectProgressExecutionListener
listener.registerJobTag(testTag)
listener.onJobStart(testJobStart)
+ val metricsOrNull = if (metricsPopulated) {
+ testStage1Task1Metrics
+ } else {
+ null
+ }
// Finish the tasks
val taskEnd = SparkListenerTaskEnd(
1,
@@ -92,7 +97,7 @@
Success,
testStage1Task1,
testStage1Task1ExecutorMetrics,
- testStage1Task1Metrics)
+ metricsOrNull)
val t = listener.trackedTags(testTag)
var yielded = false
@@ -117,7 +122,11 @@
assert(stages.map(_.numTasks).sum == 2)
assert(stages.map(_.completedTasks).sum == 1)
assert(stages.size == 2)
- assert(stages.map(_.inputBytesRead).sum == 500)
+ if (metricsPopulated) {
+ assert(stages.map(_.inputBytesRead).sum == 500)
+ } else {
+ assert(stages.map(_.inputBytesRead).sum == 0)
+ }
assert(
stages
.map(_.completed match {
@@ -140,7 +149,11 @@
assert(stages.map(_.numTasks).sum == 2)
assert(stages.map(_.completedTasks).sum == 1)
assert(stages.size == 2)
- assert(stages.map(_.inputBytesRead).sum == 500)
+ if (metricsPopulated) {
+ assert(stages.map(_.inputBytesRead).sum == 500)
+ } else {
+ assert(stages.map(_.inputBytesRead).sum == 0)
+ }
assert(
stages
.map(_.completed match {
@@ -153,4 +166,12 @@
assert(yielded, "Must updated with results")
}
+ test("taskDone - populated metrics") {
+ testTaskDone(metricsPopulated = true)
+ }
+
+ test("taskDone - null metrics") {
+ testTaskDone(metricsPopulated = false)
+ }
+
}