[SPARK-26806][SS] EventTimeStats.merge should handle zeros correctly
## What changes were proposed in this pull request?
Right now, `EventTimeStats.merge` doesn't handle `zero.merge(zero)` correctly: the merged count stays `0`, so the incremental-average update divides by zero and `avg` becomes `NaN`. Once that happens, anything merged with the result keeps `avg` as `NaN`. We eventually call `NaN.toLong`, which returns `0`, and the user sees the following incorrect report:
```
"eventTime" : {
"avg" : "1970-01-01T00:00:00.000Z",
"max" : "2019-01-31T12:57:00.000Z",
"min" : "2019-01-30T18:44:04.000Z",
"watermark" : "1970-01-01T00:00:00.000Z"
}
```
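For illustration, here is a minimal, self-contained sketch of the pre-fix arithmetic (the `MergeNaNDemo` object and the literal timestamp are invented for this example; the update formula is the one removed in the diff below):
```scala
// Minimal sketch of the old merge arithmetic, assuming EventTimeStats.zero is
// (max = Long.MinValue, min = Long.MaxValue, avg = 0.0, count = 0L).
object MergeNaNDemo {
  def main(args: Array[String]): Unit = {
    var avg = 0.0   // zero.avg
    var count = 0L  // zero.count

    // zero.merge(zero): the merged count is still 0, so we divide by zero.
    count += 0L
    avg += (0.0 - avg) * 0L / count   // 0.0 / 0 => Double.NaN
    println(avg)                      // NaN

    // Merging real data afterwards cannot recover, because NaN is sticky.
    val thatAvg = 1548852244000.0     // an arbitrary event-time in epoch millis
    count += 1L
    avg += (thatAvg - avg) * 1L / count
    println(avg)                      // still NaN
    println(avg.toLong)               // NaN.toLong == 0, rendered as "1970-01-01T00:00:00.000Z"
  }
}
```
Running this prints `NaN`, `NaN`, and `0`, which is where the `1970-01-01` timestamps above come from.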
This issue was reported by liancheng.
This PR fixes the above issue.
## How was this patch tested?
New unit tests covering `zero merge zero`, `non-zero merge zero`, and `zero merge non-zero`.
Closes #23718 from zsxwing/merge-zero.
Authored-by: Shixiong Zhu <zsxwing@gmail.com>
Signed-off-by: Shixiong Zhu <zsxwing@gmail.com>
(cherry picked from commit 03a928cbecaf38bbbab3e6b957fcbb542771cfbd)
Signed-off-by: Shixiong Zhu <zsxwing@gmail.com>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index 55e7508..4069633 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -36,10 +36,19 @@
}
def merge(that: EventTimeStats): Unit = {
- this.max = math.max(this.max, that.max)
- this.min = math.min(this.min, that.min)
- this.count += that.count
- this.avg += (that.avg - this.avg) * that.count / this.count
+ if (that.count == 0) {
+ // no-op
+ } else if (this.count == 0) {
+ this.max = that.max
+ this.min = that.min
+ this.count = that.count
+ this.avg = that.avg
+ } else {
+ this.max = math.max(this.max, that.max)
+ this.min = math.min(this.min, that.min)
+ this.count += that.count
+ this.avg += (that.avg - this.avg) * that.count / this.count
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index 4f19fa0b..14a193f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -38,9 +38,9 @@
sqlContext.streams.active.foreach(_.stop())
}
- test("EventTimeStats") {
- val epsilon = 10E-6
+ private val epsilon = 10E-6
+ test("EventTimeStats") {
val stats = EventTimeStats(max = 100, min = 10, avg = 20.0, count = 5)
stats.add(80L)
stats.max should be (100)
@@ -57,7 +57,6 @@
}
test("EventTimeStats: avg on large values") {
- val epsilon = 10E-6
val largeValue = 10000000000L // 10B
// Make sure `largeValue` will cause overflow if we use a Long sum to calc avg.
assert(largeValue * largeValue != BigInt(largeValue) * BigInt(largeValue))
@@ -75,6 +74,33 @@
stats.avg should be ((largeValue + 0.5) +- epsilon)
}
+ test("EventTimeStats: zero merge zero") {
+ val stats = EventTimeStats.zero
+ val stats2 = EventTimeStats.zero
+ stats.merge(stats2)
+ stats should be (EventTimeStats.zero)
+ }
+
+ test("EventTimeStats: non-zero merge zero") {
+ val stats = EventTimeStats(max = 10, min = 1, avg = 5.0, count = 3)
+ val stats2 = EventTimeStats.zero
+ stats.merge(stats2)
+ stats.max should be (10L)
+ stats.min should be (1L)
+ stats.avg should be (5.0 +- epsilon)
+ stats.count should be (3L)
+ }
+
+ test("EventTimeStats: zero merge non-zero") {
+ val stats = EventTimeStats.zero
+ val stats2 = EventTimeStats(max = 10, min = 1, avg = 5.0, count = 3)
+ stats.merge(stats2)
+ stats.max should be (10L)
+ stats.min should be (1L)
+ stats.avg should be (5.0 +- epsilon)
+ stats.count should be (3L)
+ }
+
test("error on bad column") {
val inputData = MemoryStream[Int].toDF()
val e = intercept[AnalysisException] {