[SPARK-49864][SQL] Improve message of BINARY_ARITHMETIC_OVERFLOW
### What changes were proposed in this pull request?
BINARY_ARITHMETIC_OVERFLOW did not include a suggestion for bypassing the error. This PR improves on that.
### Why are the changes needed?
All errors should suggest a way to overcome the issue, so that customers can fix problems more easily.
### Does this PR introduce _any_ user-facing change?
Yes, change in error message.
### How was this patch tested?
Tests added for all paths for bytes.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48335 from mihailom-db/binary_arithmetic_overflow.
Authored-by: Mihailo Milosevic <mihailo.milosevic@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 940ed17..eca3587 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -127,7 +127,7 @@
},
"BINARY_ARITHMETIC_OVERFLOW" : {
"message" : [
- "<value1> <symbol> <value2> caused overflow."
+ "<value1> <symbol> <value2> caused overflow. Use <functionName> to ignore overflow problem and return NULL."
],
"sqlState" : "22003"
},
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index f889c3e..d8ba1fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -294,12 +294,18 @@
case ByteType | ShortType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val tmpResult = ctx.freshName("tmpResult")
+ val try_suggestion = symbol match {
+ case "+" => "try_add"
+ case "-" => "try_subtract"
+ case "*" => "try_multiply"
+ case _ => ""
+ }
val overflowCheck = if (failOnError) {
val javaType = CodeGenerator.boxedType(dataType)
s"""
|if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) {
| throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(
- | $eval1, "$symbol", $eval2);
+ | $eval1, "$symbol", $eval2, "$try_suggestion");
|}
""".stripMargin
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 4a23e97..5e3aa3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -610,13 +610,17 @@
}
def binaryArithmeticCauseOverflowError(
- eval1: Short, symbol: String, eval2: Short): SparkArithmeticException = {
+ eval1: Short,
+ symbol: String,
+ eval2: Short,
+ suggestedFunc: String): SparkArithmeticException = {
new SparkArithmeticException(
errorClass = "BINARY_ARITHMETIC_OVERFLOW",
messageParameters = Map(
"value1" -> toSQLValue(eval1, ShortType),
"symbol" -> symbol,
- "value2" -> toSQLValue(eval2, ShortType)),
+ "value2" -> toSQLValue(eval2, ShortType),
+ "functionName" -> toSQLId(suggestedFunc)),
context = Array.empty,
summary = "")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
index 19b1b5d..1c860e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
@@ -24,27 +24,27 @@
import org.apache.spark.sql.types.Decimal.DecimalIsConflicted
private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering {
- private def checkOverflow(res: Int, x: Byte, y: Byte, op: String): Unit = {
+ private def checkOverflow(res: Int, x: Byte, y: Byte, op: String, hint: String): Unit = {
if (res > Byte.MaxValue || res < Byte.MinValue) {
- throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y)
+ throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y, hint)
}
}
override def plus(x: Byte, y: Byte): Byte = {
val tmp = x + y
- checkOverflow(tmp, x, y, "+")
+ checkOverflow(tmp, x, y, "+", "try_add")
tmp.toByte
}
override def minus(x: Byte, y: Byte): Byte = {
val tmp = x - y
- checkOverflow(tmp, x, y, "-")
+ checkOverflow(tmp, x, y, "-", "try_subtract")
tmp.toByte
}
override def times(x: Byte, y: Byte): Byte = {
val tmp = x * y
- checkOverflow(tmp, x, y, "*")
+ checkOverflow(tmp, x, y, "*", "try_multiply")
tmp.toByte
}
@@ -55,27 +55,27 @@
private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering {
- private def checkOverflow(res: Int, x: Short, y: Short, op: String): Unit = {
+ private def checkOverflow(res: Int, x: Short, y: Short, op: String, hint: String): Unit = {
if (res > Short.MaxValue || res < Short.MinValue) {
- throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y)
+ throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y, hint)
}
}
override def plus(x: Short, y: Short): Short = {
val tmp = x + y
- checkOverflow(tmp, x, y, "+")
+ checkOverflow(tmp, x, y, "+", "try_add")
tmp.toShort
}
override def minus(x: Short, y: Short): Short = {
val tmp = x - y
- checkOverflow(tmp, x, y, "-")
+ checkOverflow(tmp, x, y, "-", "try_subtract")
tmp.toShort
}
override def times(x: Short, y: Short): Short = {
val tmp = x * y
- checkOverflow(tmp, x, y, "*")
+ checkOverflow(tmp, x, y, "*", "try_multiply")
tmp.toShort
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 00dfd34..9d1448d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -767,7 +767,40 @@
parameters = Map(
"value1" -> "127S",
"symbol" -> "+",
- "value2" -> "5S"),
+ "value2" -> "5S",
+ "functionName" -> "`try_add`"),
+ sqlState = "22003")
+ }
+ }
+
+ test("BINARY_ARITHMETIC_OVERFLOW: byte minus byte result overflow") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ checkError(
+ exception = intercept[SparkArithmeticException] {
+ sql(s"select -2Y - 127Y").collect()
+ },
+ condition = "BINARY_ARITHMETIC_OVERFLOW",
+ parameters = Map(
+ "value1" -> "-2S",
+ "symbol" -> "-",
+ "value2" -> "127S",
+ "functionName" -> "`try_subtract`"),
+ sqlState = "22003")
+ }
+ }
+
+ test("BINARY_ARITHMETIC_OVERFLOW: byte multiply byte result overflow") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ checkError(
+ exception = intercept[SparkArithmeticException] {
+ sql(s"select 127Y * 5Y").collect()
+ },
+ condition = "BINARY_ARITHMETIC_OVERFLOW",
+ parameters = Map(
+ "value1" -> "127S",
+ "symbol" -> "*",
+ "value2" -> "5S",
+ "functionName" -> "`try_multiply`"),
sqlState = "22003")
}
}