[SPARK-48035][SQL][FOLLOWUP] Fix try_add/try_multiply being semantically equal to add/multiply
### What changes were proposed in this pull request?
- This is a follow-up to the previous PR: https://github.com/apache/spark/pull/46307.
- With this change, the evalMode check is performed inside the `collectOperands` partial function instead of introducing a separate helper function.
### Why are the changes needed?
- Better code quality and readability.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- Existing unit tests.
### Was this patch authored or co-authored using generative AI tooling?
- No
Closes #46414 from db-scnakandala/db-scnakandala/master.
Authored-by: Supun Nakandala <supun.nakandala@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 2759f5a..de15ec4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -1378,20 +1378,6 @@
}
reorderResult
}
-
- /**
- * Helper method to collect the evaluation mode of the commutative expressions. This is
- * used by the canonicalized methods of [[Add]] and [[Multiply]] operators to ensure that
- * all operands have the same evaluation mode before reordering the operands.
- */
- protected def collectEvalModes(
- e: Expression,
- f: PartialFunction[CommutativeExpression, Seq[EvalMode.Value]]
- ): Seq[EvalMode.Value] = e match {
- case c: CommutativeExpression if f.isDefinedAt(c) =>
- f(c) ++ c.children.flatMap(collectEvalModes(_, f))
- case _ => Nil
- }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 91c10a5..a085a4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -452,14 +452,12 @@
copy(left = newLeft, right = newRight)
override lazy val canonicalized: Expression = {
- val evalModes = collectEvalModes(this, {case Add(_, _, evalMode) => Seq(evalMode)})
- lazy val reorderResult = buildCanonicalizedPlan(
- { case Add(l, r, _) => Seq(l, r) },
+ val reorderResult = buildCanonicalizedPlan(
+ { case Add(l, r, em) if em == evalMode => Seq(l, r) },
{ case (l: Expression, r: Expression) => Add(l, r, evalMode)},
Some(evalMode)
)
- if (resolved && evalModes.forall(_ == evalMode) && reorderResult.resolved &&
- reorderResult.dataType == dataType) {
+ if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) {
reorderResult
} else {
// SPARK-40903: Avoid reordering decimal Add for canonicalization if the result data type is
@@ -609,16 +607,11 @@
newLeft: Expression, newRight: Expression): Multiply = copy(left = newLeft, right = newRight)
override lazy val canonicalized: Expression = {
- val evalModes = collectEvalModes(this, {case Multiply(_, _, evalMode) => Seq(evalMode)})
- if (evalModes.forall(_ == evalMode)) {
- buildCanonicalizedPlan(
- { case Multiply(l, r, _) => Seq(l, r) },
- { case (l: Expression, r: Expression) => Multiply(l, r, evalMode)},
- Some(evalMode)
- )
- } else {
- withCanonicalizedChildren
- }
+ buildCanonicalizedPlan(
+ { case Multiply(l, r, em) if em == evalMode => Seq(l, r) },
+ { case (l: Expression, r: Expression) => Multiply(l, r, evalMode) },
+ Some(evalMode)
+ )
}
}