[SPARK-48158][SQL] Add collation support for XML expressions
### What changes were proposed in this pull request?
Introduce collation awareness for XML expressions: from_xml, schema_of_xml, to_xml.
### Why are the changes needed?
Add collation support for XML expressions in Spark.
### Does this PR introduce _any_ user-facing change?
Yes, users can now pass collated strings as arguments to the XML functions from_xml, schema_of_xml, and to_xml.
### How was this patch tested?
End-to-end SQL tests added to CollationSQLExpressionsSuite.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #46507 from uros-db/xml-expressions.
Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
index 415d55d..237d740 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
@@ -27,6 +27,7 @@
import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -140,7 +141,7 @@
converter(parser.parse(str))
}
- override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
+ override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil
override def sql: String = schema match {
case _: MapType => "entries"
@@ -178,7 +179,7 @@
child = child,
options = ExprUtils.convertToMapData(options))
- override def dataType: DataType = StringType
+ override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = false
@@ -226,7 +227,7 @@
.map(ArrayType(_, containsNull = at.containsNull))
.getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull))
case other: DataType =>
- xmlInferSchema.canonicalizeType(other).getOrElse(StringType)
+ xmlInferSchema.canonicalizeType(other).getOrElse(SQLConf.get.defaultStringType)
}
UTF8String.fromString(dataType.sql)
@@ -320,7 +321,7 @@
getAndReset()
}
- override def dataType: DataType = StringType
+ override def dataType: DataType = SQLConf.get.defaultStringType
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 2b63901..dd5703d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.text.SimpleDateFormat
+
import scala.collection.immutable.Seq
import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkRuntimeException}
@@ -860,6 +862,128 @@
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}
+  test("Support XmlToStructs xml expression with collation") {
+    // One from_xml scenario: the XML literal to parse, the session default collation, the
+    // schema argument, optional trailing SQL for an options map, the expected struct value,
+    // and the fields of the expected result type.
+    case class XmlToStructsTestCase(
+      input: String,
+      collationName: String,
+      schema: String,
+      options: String,
+      result: Row,
+      structFields: Seq[StructField]
+    )
+
+    // Expected value for the 'dd/MM/yyyy' scenario, parsed in the JVM default time zone.
+    val expectedTime =
+      new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 00:00:00.0")
+
+    val testCases = Seq(
+      XmlToStructsTestCase(
+        input = "<p><a>1</a></p>",
+        collationName = "UTF8_BINARY",
+        schema = "'a INT'",
+        options = "",
+        result = Row(1),
+        structFields = Seq(StructField("a", IntegerType, nullable = true))),
+      XmlToStructsTestCase(
+        input = "<p><A>true</A><B>0.8</B></p>",
+        collationName = "UTF8_BINARY_LCASE",
+        schema = "'A BOOLEAN, B DOUBLE'",
+        options = "",
+        result = Row(true, 0.8),
+        structFields = Seq(
+          StructField("A", BooleanType, nullable = true),
+          StructField("B", DoubleType, nullable = true))),
+      XmlToStructsTestCase(
+        input = "<p><s>Spark</s></p>",
+        collationName = "UNICODE",
+        schema = "'s STRING'",
+        options = "",
+        result = Row("Spark"),
+        structFields = Seq(StructField("s", StringType("UNICODE"), nullable = true))),
+      XmlToStructsTestCase(
+        input = "<p><time>26/08/2015</time></p>",
+        collationName = "UNICODE_CI",
+        schema = "'time Timestamp'",
+        options = ", map('timestampFormat', 'dd/MM/yyyy')",
+        result = Row(expectedTime),
+        structFields = Seq(StructField("time", TimestampType, nullable = true)))
+    )
+
+    // Run every scenario with its collation installed as the session default collation.
+    for (tc <- testCases) {
+      val query =
+        s"""
+           |select from_xml('${tc.input}', ${tc.schema} ${tc.options})
+           |""".stripMargin
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> tc.collationName) {
+        val df = sql(query)
+        // from_xml yields a single struct column, hence the expected row is nested.
+        checkAnswer(df, Row(tc.result))
+        // The result column type must match a struct built from the expected fields.
+        val expectedType = StructType(tc.structFields)
+        assert(df.schema.fields.head.dataType.sameType(expectedType))
+      }
+    }
+  }
+
+  test("Support SchemaOfXml xml expression with collation") {
+    // One schema_of_xml scenario: the XML literal, the session default collation, and the
+    // expected schema string returned by the query.
+    case class SchemaOfXmlTestCase(
+      input: String,
+      collationName: String,
+      result: String
+    )
+
+    val testCases = Seq(
+      SchemaOfXmlTestCase(
+        input = "<p><a>1</a></p>",
+        collationName = "UTF8_BINARY",
+        result = "STRUCT<a: BIGINT>"),
+      SchemaOfXmlTestCase(
+        input = "<p><A>true</A><B>0.8</B></p>",
+        collationName = "UTF8_BINARY_LCASE",
+        result = "STRUCT<A: BOOLEAN, B: DOUBLE>"),
+      SchemaOfXmlTestCase(
+        input = "<p></p>",
+        collationName = "UNICODE",
+        result = "STRUCT<>"),
+      SchemaOfXmlTestCase(
+        input = "<p><A>1</A><A>2</A><A>3</A></p>",
+        collationName = "UNICODE_CI",
+        result = "STRUCT<A: ARRAY<BIGINT>>")
+    )
+
+    // Run every scenario with its collation installed as the session default collation.
+    for (tc <- testCases) {
+      val query =
+        s"""
+           |select schema_of_xml('${tc.input}')
+           |""".stripMargin
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> tc.collationName) {
+        val df = sql(query)
+        checkAnswer(df, Row(tc.result))
+        // The returned string column must carry the session default collation.
+        val expectedType = StringType(tc.collationName)
+        assert(df.schema.fields.head.dataType.sameType(expectedType))
+      }
+    }
+  }
+
+  test("Support StructsToXml xml expression with collation") {
+    // One to_xml scenario: the SQL expression producing the struct, the session default
+    // collation, and the expected XML text.
+    case class StructsToXmlTestCase(
+      input: String,
+      collationName: String,
+      result: String
+    )
+
+    val testCases = Seq(
+      StructsToXmlTestCase(
+        input = "named_struct('a', 1, 'b', 2)",
+        collationName = "UTF8_BINARY",
+        result =
+          s"""<ROW>
+             | <a>1</a>
+             | <b>2</b>
+             |</ROW>""".stripMargin),
+      StructsToXmlTestCase(
+        input = "named_struct('A', true, 'B', 2.0)",
+        collationName = "UTF8_BINARY_LCASE",
+        result =
+          s"""<ROW>
+             | <A>true</A>
+             | <B>2.0</B>
+             |</ROW>""".stripMargin),
+      StructsToXmlTestCase(
+        input = "named_struct()",
+        collationName = "UNICODE",
+        result = "<ROW/>"),
+      StructsToXmlTestCase(
+        input = "named_struct('time', to_timestamp('2015-08-26'))",
+        collationName = "UNICODE_CI",
+        result =
+          s"""<ROW>
+             | <time>2015-08-26T00:00:00.000-07:00</time>
+             |</ROW>""".stripMargin)
+    )
+
+    // Run every scenario with its collation installed as the session default collation.
+    for (tc <- testCases) {
+      val query =
+        s"""
+           |select to_xml(${tc.input})
+           |""".stripMargin
+      // The timestamp scenario expects the offset -07:00, which only holds when the session
+      // time zone is US Pacific. Pin it here so the test does not depend on the time zone of
+      // the machine or on defaults set up elsewhere in the test harness.
+      withSQLConf(
+        SqlApiConf.DEFAULT_COLLATION -> tc.collationName,
+        "spark.sql.session.timeZone" -> "America/Los_Angeles") {
+        val df = sql(query)
+        checkAnswer(df, Row(tc.result))
+        // The returned string column must carry the session default collation.
+        val expectedType = StringType(tc.collationName)
+        assert(df.schema.fields.head.dataType.sameType(expectedType))
+      }
+    }
+  }
+
test("Support ParseJson & TryParseJson variant expressions with collation") {
case class ParseJsonTestCase(
input: String,