[SPARK-49836][SQL][SS] Fix possibly broken query when window is provided to window/session_window fn
### What changes were proposed in this pull request?
This PR fixes a correctness issue where operators are dropped from the query plan during analysis. It happens when a time window column is provided as the time column to the window()/session_window() function.
The rules `TimeWindowing` and `SessionWindowing` are responsible for resolving the time window functions. When the window function receives a `window` column as its time column parameter (in other words, when building a time window from another time window), the rule wraps the time column with the `WindowTime` function so that the rule `ResolveWindowTime` can resolve it further. (`TimeWindowing`/`SessionWindowing` will then resolve the function again against the result of `ResolveWindowTime`.)
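For illustration, a minimal DataFrame sketch of the "time window built from a time window" case (assuming a hypothetical `events` DataFrame with `eventTime` timestamp and `userId` columns):

```scala
import org.apache.spark.sql.functions.{col, count, sum, window}

// Inner aggregation: produces a struct column named "window".
val perMinute = events
  .groupBy(window(col("eventTime"), "1 minute"), col("userId"))
  .agg(count("*").as("clicks"))

// Passing that struct column as the time column of the outer window() is the
// case the rules rewrite via WindowTime before ResolveWindowTime finishes
// the resolution.
val perTenMinutes = perMinute
  .groupBy(window(col("window"), "10 minutes"), col("userId"))
  .agg(sum("clicks").as("clicks"))
```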
The issue is that the rules use `return` for the branch above, intending an "early return" since the other branch is much longer. This unfortunately does not work as intended: the intent is just to exit the current local scope (essentially, the end of the enclosing curly brace), but because the `return` sits inside a closure it is nonlocal, and it breaks the loop of execution on the "outer" side.
(I haven't debugged further, but it is clearly not behaving as intended.)
Quoting from Scala doc:
> Nonlocal returns are implemented by throwing and catching scala.runtime.NonLocalReturnException-s.
It's not super clear where the NonLocalReturnException is caught in the call stack; it may exit a much broader scope (context) than expected. Nonlocal returns are also deprecated as of Scala 3.2 and will likely be removed in the future.
https://dotty.epfl.ch/docs/reference/dropped-features/nonlocal-returns.html
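As a standalone illustration (not the analyzer code): a `return` inside a closure compiles to throwing `scala.runtime.NonLocalReturnControl` (the exception the quote refers to), and any intervening handler that catches `Throwable` will swallow it, making the "early return" silently disappear:

```scala
object NonLocalReturnDemo {
  // The nonlocal return works here: the compiler catches the control
  // exception at the method boundary, so this returns Some(4).
  def firstEven(xs: Seq[Int]): Option[Int] = {
    xs.foreach { x =>
      if (x % 2 == 0) return Some(x) // throws NonLocalReturnControl under the hood
    }
    None
  }

  // If anything between the closure and the method boundary catches
  // Throwable, the control exception never reaches its handler and the
  // early return is lost: this returns None instead of Some(4).
  def firstEvenGuarded(xs: Seq[Int]): Option[Int] = {
    try {
      xs.foreach { x =>
        if (x % 2 == 0) return Some(x)
      }
    } catch {
      case t: Throwable => println(s"intercepted: ${t.getClass.getName}")
    }
    None
  }

  def main(args: Array[String]): Unit = {
    println(firstEven(Seq(1, 3, 4)))        // Some(4)
    println(firstEvenGuarded(Seq(1, 3, 4))) // intercepted: ...NonLocalReturnControl, then None
  }
}
```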
Interestingly, it does not break every query with chained time window aggregations. Spark already has several tests using the DataFrame API, and they have not failed. The reproducer in the community report uses a SQL statement, where each aggregation is treated as a subquery.
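The reproducer shape, abbreviated (a hypothetical `events` temp view; the full reproducer is in the new tests below):

```scala
// Each aggregation is a subquery when expressed in SQL; this is the form
// that hit the bug, while the equivalent DataFrame chaining did not.
val result = spark.sql(
  """SELECT window(small_window, '10 minutes') AS large_window, sum(cnt) AS cnt
    |FROM (
    |  SELECT window(eventTime, '1 minute') AS small_window, count(*) AS cnt
    |  FROM events
    |  GROUP BY window
    |) per_minute
    |GROUP BY window
    |""".stripMargin)
```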
This PR fixes the rules to NOT use an early return, and instead to branch with a (large) if-else.
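Abbreviated, the shape of the change (the full diff follows):

```scala
// Before: nonlocal return from inside the rule
if (StructType.acceptsType(window.timeColumn.dataType)) {
  return p.transformExpressions {
    case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn))
  }
}
// ... long resolution logic ...

// After: the same branch is the `if` arm of a single expression, so the
// rule's result is produced by evaluation rather than by throwing
if (StructType.acceptsType(window.timeColumn.dataType)) {
  p.transformExpressions {
    case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn))
  }
} else {
  // ... long resolution logic ...
}
```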
### Why are the changes needed?
Described above.
### Does this PR introduce _any_ user-facing change?
Yes, this fixes possible query breakage. The impacted workloads may not be large, as chained time window aggregations are an advanced usage, and the bug does not break every query using them.
### How was this patch tested?
New UTs.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48309 from HeartSaVioR/SPARK-49836.
Lead-authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Co-authored-by: Andrzej Zera <andrzejzera@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
index e506a36..a8680d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala
@@ -87,85 +87,86 @@
val window = windowExpressions.head
+ // time window is provided as time column of window function, replace it with WindowTime
if (StructType.acceptsType(window.timeColumn.dataType)) {
- return p.transformExpressions {
+ p.transformExpressions {
case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn))
}
- }
-
- val metadata = window.timeColumn match {
- case a: Attribute => a.metadata
- case _ => Metadata.empty
- }
-
- val newMetadata = new MetadataBuilder()
- .withMetadata(metadata)
- .putBoolean(TimeWindow.marker, true)
- .build()
-
- def getWindow(i: Int, dataType: DataType): Expression = {
- val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
- val remainder = (timestamp - window.startTime) % window.slideDuration
- val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
- remainder + window.slideDuration)), Some(remainder))
- val windowStart = lastStart - i * window.slideDuration
- val windowEnd = windowStart + window.windowDuration
-
- // We make sure value fields are nullable since the dataType of TimeWindow defines them
- // as nullable.
- CreateNamedStruct(
- Literal(WINDOW_START) ::
- PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
- Literal(WINDOW_END) ::
- PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
- Nil)
- }
-
- val windowAttr = AttributeReference(
- WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
-
- if (window.windowDuration == window.slideDuration) {
- val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
- exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
-
- val replacedPlan = p transformExpressions {
- case t: TimeWindow => windowAttr
- }
-
- // For backwards compatibility we add a filter to filter out nulls
- val filterExpr = IsNotNull(window.timeColumn)
-
- replacedPlan.withNewChildren(
- Project(windowStruct +: child.output,
- Filter(filterExpr, child)) :: Nil)
} else {
- val overlappingWindows =
- math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
- val windows =
- Seq.tabulate(overlappingWindows)(i =>
- getWindow(i, window.timeColumn.dataType))
+ val metadata = window.timeColumn match {
+ case a: Attribute => a.metadata
+ case _ => Metadata.empty
+ }
- val projections = windows.map(_ +: child.output)
+ val newMetadata = new MetadataBuilder()
+ .withMetadata(metadata)
+ .putBoolean(TimeWindow.marker, true)
+ .build()
- // When the condition windowDuration % slideDuration = 0 is fulfilled,
- // the estimation of the number of windows becomes exact one,
- // which means all produced windows are valid.
- val filterExpr =
- if (window.windowDuration % window.slideDuration == 0) {
- IsNotNull(window.timeColumn)
+ def getWindow(i: Int, dataType: DataType): Expression = {
+ val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
+ val remainder = (timestamp - window.startTime) % window.slideDuration
+ val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
+ remainder + window.slideDuration)), Some(remainder))
+ val windowStart = lastStart - i * window.slideDuration
+ val windowEnd = windowStart + window.windowDuration
+
+ // We make sure value fields are nullable since the dataType of TimeWindow defines them
+ // as nullable.
+ CreateNamedStruct(
+ Literal(WINDOW_START) ::
+ PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
+ Literal(WINDOW_END) ::
+ PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
+ Nil)
+ }
+
+ val windowAttr = AttributeReference(
+ WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
+
+ if (window.windowDuration == window.slideDuration) {
+ val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
+ exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
+
+ val replacedPlan = p transformExpressions {
+ case t: TimeWindow => windowAttr
+ }
+
+ // For backwards compatibility we add a filter to filter out nulls
+ val filterExpr = IsNotNull(window.timeColumn)
+
+ replacedPlan.withNewChildren(
+ Project(windowStruct +: child.output,
+ Filter(filterExpr, child)) :: Nil)
} else {
- window.timeColumn >= windowAttr.getField(WINDOW_START) &&
- window.timeColumn < windowAttr.getField(WINDOW_END)
+ val overlappingWindows =
+ math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
+ val windows =
+ Seq.tabulate(overlappingWindows)(i =>
+ getWindow(i, window.timeColumn.dataType))
+
+ val projections = windows.map(_ +: child.output)
+
+ // When the condition windowDuration % slideDuration = 0 is fulfilled,
+ // the estimation of the number of windows becomes exact one,
+ // which means all produced windows are valid.
+ val filterExpr =
+ if (window.windowDuration % window.slideDuration == 0) {
+ IsNotNull(window.timeColumn)
+ } else {
+ window.timeColumn >= windowAttr.getField(WINDOW_START) &&
+ window.timeColumn < windowAttr.getField(WINDOW_END)
+ }
+
+ val substitutedPlan = Filter(filterExpr,
+ Expand(projections, windowAttr +: child.output, child))
+
+ val renamedPlan = p transformExpressions {
+ case t: TimeWindow => windowAttr
+ }
+
+ renamedPlan.withNewChildren(substitutedPlan :: Nil)
}
-
- val substitutedPlan = Filter(filterExpr,
- Expand(projections, windowAttr +: child.output, child))
-
- val renamedPlan = p transformExpressions {
- case t: TimeWindow => windowAttr
- }
-
- renamedPlan.withNewChildren(substitutedPlan :: Nil)
}
} else if (numWindowExpr > 1) {
throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
@@ -210,74 +211,74 @@
val session = sessionExpressions.head
if (StructType.acceptsType(session.timeColumn.dataType)) {
- return p transformExpressions {
+ p transformExpressions {
case t: SessionWindow => t.copy(timeColumn = WindowTime(session.timeColumn))
}
- }
-
- val metadata = session.timeColumn match {
- case a: Attribute => a.metadata
- case _ => Metadata.empty
- }
-
- val newMetadata = new MetadataBuilder()
- .withMetadata(metadata)
- .putBoolean(SessionWindow.marker, true)
- .build()
-
- val sessionAttr = AttributeReference(
- SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
-
- val sessionStart =
- PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
- val gapDuration = session.gapDuration match {
- case expr if expr.dataType == CalendarIntervalType =>
- expr
- case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
- Cast(expr, CalendarIntervalType)
- case other =>
- throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
- }
- val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
- session.timeColumn.dataType, LongType)
-
- // We make sure value fields are nullable since the dataType of SessionWindow defines them
- // as nullable.
- val literalSessionStruct = CreateNamedStruct(
- Literal(SESSION_START) ::
- PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
- .castNullable() ::
- Literal(SESSION_END) ::
- PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
- .castNullable() ::
- Nil)
-
- val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
- exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
-
- val replacedPlan = p transformExpressions {
- case s: SessionWindow => sessionAttr
- }
-
- val filterByTimeRange = if (gapDuration.foldable) {
- val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
- interval == null || interval.months + interval.days + interval.microseconds <= 0
} else {
- true
- }
+ val metadata = session.timeColumn match {
+ case a: Attribute => a.metadata
+ case _ => Metadata.empty
+ }
- // As same as tumbling window, we add a filter to filter out nulls.
- // And we also filter out events with negative or zero or invalid gap duration.
- val filterExpr = if (filterByTimeRange) {
- IsNotNull(session.timeColumn) &&
- (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
- } else {
- IsNotNull(session.timeColumn)
- }
+ val newMetadata = new MetadataBuilder()
+ .withMetadata(metadata)
+ .putBoolean(SessionWindow.marker, true)
+ .build()
- replacedPlan.withNewChildren(
- Filter(filterExpr,
- Project(sessionStruct +: child.output, child)) :: Nil)
+ val sessionAttr = AttributeReference(
+ SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
+
+ val sessionStart =
+ PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
+ val gapDuration = session.gapDuration match {
+ case expr if expr.dataType == CalendarIntervalType =>
+ expr
+ case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
+ Cast(expr, CalendarIntervalType)
+ case other =>
+ throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
+ }
+ val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
+ session.timeColumn.dataType, LongType)
+
+ // We make sure value fields are nullable since the dataType of SessionWindow defines them
+ // as nullable.
+ val literalSessionStruct = CreateNamedStruct(
+ Literal(SESSION_START) ::
+ PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
+ .castNullable() ::
+ Literal(SESSION_END) ::
+ PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
+ .castNullable() ::
+ Nil)
+
+ val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
+ exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
+
+ val replacedPlan = p transformExpressions {
+ case s: SessionWindow => sessionAttr
+ }
+
+ val filterByTimeRange = if (gapDuration.foldable) {
+ val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
+ interval == null || interval.months + interval.days + interval.microseconds <= 0
+ } else {
+ true
+ }
+
+ // As same as tumbling window, we add a filter to filter out nulls.
+ // And we also filter out events with negative or zero or invalid gap duration.
+ val filterExpr = if (filterByTimeRange) {
+ IsNotNull(session.timeColumn) &&
+ (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
+ } else {
+ IsNotNull(session.timeColumn)
+ }
+
+ replacedPlan.withNewChildren(
+ Filter(filterExpr,
+ Project(sessionStruct +: child.output, child)) :: Nil)
+ }
} else if (numWindowExpr > 1) {
throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
} else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
index 1ac1dda..6c1ca94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
@@ -547,4 +547,55 @@
}
}
}
+
+ test("SPARK-49836 using window fn with window as parameter should preserve parent operator") {
+ withTempView("clicks") {
+ val df = Seq(
+ // small window: [00:00, 01:00), user1, 2
+ ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
+ // small window: [01:00, 02:00), user2, 2
+ ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
+ // small window: [03:00, 04:00), user1, 1
+ ("2024-09-30 00:03:30", "user1"),
+ // small window: [11:00, 12:00), user1, 3
+ ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
+ ("2024-09-30 00:11:45", "user1")
+ ).toDF("eventTime", "userId")
+
+ // session window: (01:00, 09:00), user1, 3 / (02:00, 07:00), user2, 2 /
+ // (12:00, 12:05), user1, 3
+
+ df.createOrReplaceTempView("clicks")
+
+ val aggregatedData = spark.sql(
+ """
+ |SELECT
+ | userId,
+ | avg(cpu_large.numClicks) AS clicksPerSession
+ |FROM
+ |(
+ | SELECT
+ | session_window(small_window, '5 minutes') AS session,
+ | userId,
+ | sum(numClicks) AS numClicks
+ | FROM
+ | (
+ | SELECT
+ | window(eventTime, '1 minute') AS small_window,
+ | userId,
+ | count(*) AS numClicks
+ | FROM clicks
+ | GROUP BY window, userId
+ | ) cpu_small
+ | GROUP BY session_window, userId
+ |) cpu_large
+ |GROUP BY userId
+ |""".stripMargin)
+
+ checkAnswer(
+ aggregatedData,
+ Seq(Row("user1", 3), Row("user2", 2))
+ )
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index 6ee173b..c52d428 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import java.sql.Timestamp
import java.time.LocalDateTime
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -714,4 +715,56 @@
)
}
}
+
+ test("SPARK-49836 using window fn with window as parameter should preserve parent operator") {
+ withTempView("clicks") {
+ val df = Seq(
+ // small window: [00:00, 01:00), user1, 2
+ ("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
+ // small window: [01:00, 02:00), user2, 2
+ ("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
+ // small window: [07:00, 08:00), user1, 1
+ ("2024-09-30 00:07:00", "user1"),
+ // small window: [11:00, 12:00), user1, 3
+ ("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
+ ("2024-09-30 00:11:45", "user1")
+ ).toDF("eventTime", "userId")
+
+ // large window: [00:00, 10:00), user1, 3, [00:00, 10:00), user2, 2, [10:00, 20:00), user1, 3
+
+ df.createOrReplaceTempView("clicks")
+
+ val aggregatedData = spark.sql(
+ """
+ |SELECT
+ | cpu_large.large_window.end AS timestamp,
+ | avg(cpu_large.numClicks) AS avgClicksPerUser
+ |FROM
+ |(
+ | SELECT
+ | window(small_window, '10 minutes') AS large_window,
+ | userId,
+ | sum(numClicks) AS numClicks
+ | FROM
+ | (
+ | SELECT
+ | window(eventTime, '1 minute') AS small_window,
+ | userId,
+ | count(*) AS numClicks
+ | FROM clicks
+ | GROUP BY window, userId
+ | ) cpu_small
+ | GROUP BY window, userId
+ |) cpu_large
+ |GROUP BY timestamp
+ |""".stripMargin)
+
+ checkAnswer(
+ aggregatedData,
+ Seq(
+ Row(Timestamp.valueOf("2024-09-30 00:10:00"), 2.5),
+ Row(Timestamp.valueOf("2024-09-30 00:20:00"), 3))
+ )
+ }
+ }
}