[SPARK-47297][SQL] Add collation support for format expressions
### What changes were proposed in this pull request?
Introduce collation awareness for format expressions: to_number, try_to_number, to_char, space.
### Why are the changes needed?
Add collation support for format expressions in Spark.
### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use collated strings within arguments for format functions: to_number, try_to_number, to_char, space.
### How was this patch tested?
End-to-end SQL tests in CollationSQLExpressionsSuite, covering result correctness and result data types for to_number, try_to_number, to_char, and space under the UTF8_BINARY, UTF8_BINARY_LCASE, UNICODE, and UNICODE_CI collations.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #46423 from uros-db/format-expressions.
Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
index 6d95d7e..e914190 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
@@ -26,6 +26,8 @@
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.util.ToNumberParser
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, DatetimeType, Decimal, DecimalType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@@ -47,7 +49,8 @@
DecimalType.USER_DEFAULT
}
- override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(StringTypeAnyCollation, StringTypeAnyCollation)
override def checkInputDataTypes(): TypeCheckResult = {
val inputTypeCheck = super.checkInputDataTypes()
@@ -247,8 +250,9 @@
inputExpr.dataType match {
case _: DatetimeType => DateFormatClass(inputExpr, format)
case _: BinaryType =>
- if (!(format.dataType == StringType && format.foldable)) {
- throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "format", StringType)
+ if (!(format.dataType.isInstanceOf[StringType] && format.foldable)) {
+ throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "format",
+ format.dataType)
}
val fmt = format.eval()
if (fmt == null) {
@@ -279,8 +283,8 @@
}
}
- override def dataType: DataType = StringType
- override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringType)
+ override def dataType: DataType = SQLConf.get.defaultStringType
+ override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringTypeAnyCollation)
override def checkInputDataTypes(): TypeCheckResult = {
val inputTypeCheck = super.checkInputDataTypes()
if (inputTypeCheck.isSuccess) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 0769c8e..c2ea17d 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1906,7 +1906,7 @@
case class StringSpace(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
- override def dataType: DataType = StringType
+ override def dataType: DataType = SQLConf.get.defaultStringType
override def inputTypes: Seq[DataType] = Seq(IntegerType)
override def nullSafeEval(s: Any): Any = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 596923d..4314ff9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -19,9 +19,10 @@
import scala.collection.immutable.Seq
+import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{MapType, StringType}
+import org.apache.spark.sql.types._
// scalastyle:off nonascii
class CollationSQLExpressionsSuite
@@ -330,6 +331,135 @@
})
}
+ test("Support StringSpace expression with collation") {
+ case class StringSpaceTestCase(
+ input: Int,
+ collationName: String,
+ result: String
+ )
+
+ val testCases = Seq(
+ StringSpaceTestCase(1, "UTF8_BINARY", " "),
+ StringSpaceTestCase(2, "UTF8_BINARY_LCASE", " "),
+ StringSpaceTestCase(3, "UNICODE", " "),
+ StringSpaceTestCase(4, "UNICODE_CI", " ")
+ )
+
+ // Supported collations
+ testCases.foreach(t => {
+ val query =
+ s"""
+ |select space(${t.input})
+ |""".stripMargin
+ // Result & data type
+ withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
+ val testQuery = sql(query)
+ checkAnswer(testQuery, Row(t.result))
+ val dataType = StringType(t.collationName)
+ assert(testQuery.schema.fields.head.dataType.sameType(dataType))
+ }
+ })
+ }
+
+ test("Support ToNumber & TryToNumber expressions with collation") {
+ case class ToNumberTestCase(
+ input: String,
+ collationName: String,
+ format: String,
+ result: Any,
+ resultType: DataType
+ )
+
+ val testCases = Seq(
+ ToNumberTestCase("123", "UTF8_BINARY", "999", 123, DecimalType(3, 0)),
+ ToNumberTestCase("1", "UTF8_BINARY_LCASE", "0.00", 1.00, DecimalType(3, 2)),
+ ToNumberTestCase("99,999", "UNICODE", "99,999", 99999, DecimalType(5, 0)),
+ ToNumberTestCase("$14.99", "UNICODE_CI", "$99.99", 14.99, DecimalType(4, 2))
+ )
+
+ // Supported collations (ToNumber)
+ testCases.foreach(t => {
+ val query =
+ s"""
+ |select to_number('${t.input}', '${t.format}')
+ |""".stripMargin
+ // Result & data type
+ withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
+ val testQuery = sql(query)
+ checkAnswer(testQuery, Row(t.result))
+ assert(testQuery.schema.fields.head.dataType.sameType(t.resultType))
+ }
+ })
+
+ // Supported collations (TryToNumber)
+ testCases.foreach(t => {
+ val query =
+ s"""
+ |select try_to_number('${t.input}', '${t.format}')
+ |""".stripMargin
+ // Result & data type
+ withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
+ val testQuery = sql(query)
+ checkAnswer(testQuery, Row(t.result))
+ assert(testQuery.schema.fields.head.dataType.sameType(t.resultType))
+ }
+ })
+ }
+
+ test("Handle invalid number for ToNumber variant expression with collation") {
+ // to_number should throw an exception if the conversion fails
+ val number = "xx"
+ val query = s"SELECT to_number('$number', '999');"
+ withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
+ val e = intercept[SparkIllegalArgumentException] {
+ val testQuery = sql(query)
+ testQuery.collect()
+ }
+ assert(e.getErrorClass === "INVALID_FORMAT.MISMATCH_INPUT")
+ }
+ }
+
+ test("Handle invalid number for TryToNumber variant expression with collation") {
+ // try_to_number shouldn't throw an exception if the conversion fails
+ val number = "xx"
+ val query = s"SELECT try_to_number('$number', '999');"
+ withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
+ val testQuery = sql(query)
+ checkAnswer(testQuery, Row(null))
+ }
+ }
+
+ test("Support ToChar expression with collation") {
+ case class ToCharTestCase(
+ input: Int,
+ collationName: String,
+ format: String,
+ result: String
+ )
+
+ val testCases = Seq(
+ ToCharTestCase(12, "UTF8_BINARY", "999", " 12"),
+ ToCharTestCase(34, "UTF8_BINARY_LCASE", "000D00", "034.00"),
+ ToCharTestCase(56, "UNICODE", "$99.99", "$56.00"),
+ ToCharTestCase(78, "UNICODE_CI", "99D9S", "78.0+")
+ )
+
+ // Supported collations
+ testCases.foreach(t => {
+ val query =
+ s"""
+ |select to_char(${t.input}, '${t.format}')
+ |""".stripMargin
+ // Result & data type
+ withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
+ val testQuery = sql(query)
+ checkAnswer(testQuery, Row(t.result))
+ val dataType = StringType(t.collationName)
+ assert(testQuery.schema.fields.head.dataType.sameType(dataType))
+ }
+ })
+ }
+
test("Support StringToMap expression with collation") {
// Supported collations
case class StringToMapTestCase[R](t: String, p: String, k: String, c: String, result: R)