[SPARK-48186][SQL] Add support for AbstractMapType
### What changes were proposed in this pull request?
Addition of an abstract MapType (similar to abstract ArrayType in sql internal types) which accepts `StringTypeCollated` as `keyType` & `valueType`. Apart from extending this interface for all Spark functions, this PR also introduces collation awareness for json expression: schema_of_json.
### Why are the changes needed?
This is needed in order to enable collation support for functions that use collated maps.
### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use collated strings within arguments for json function: schema_of_json.
### How was this patch tested?
E2e sql tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #46458 from uros-db/abstract-map.
Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala
new file mode 100644
index 0000000..62f422f
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.types
+
+import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType}
+
+
+/**
+ * Use AbstractMapType(AbstractDataType, AbstractDataType)
+ * for defining expected types for expression parameters.
+ */
+case class AbstractMapType(
+ keyType: AbstractDataType,
+ valueType: AbstractDataType
+ ) extends AbstractDataType {
+
+ override private[sql] def defaultConcreteType: DataType =
+ MapType(keyType.defaultConcreteType, valueType.defaultConcreteType, valueContainsNull = true)
+
+ override private[sql] def acceptsType(other: DataType): Boolean = {
+ other.isInstanceOf[MapType] &&
+ keyType.acceptsType(other.asInstanceOf[MapType].keyType) &&
+ valueType.acceptsType(other.asInstanceOf[MapType].valueType)
+ }
+
+ override private[spark] def simpleString: String =
+ s"map<${keyType.simpleString}, ${valueType.simpleString}>"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 258bc0e..fde2093 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -28,6 +28,7 @@
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
+import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation}
import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType}
import org.apache.spark.unsafe.types.UTF8String
@@ -57,7 +58,7 @@
def convertToMapData(exp: Expression): Map[String, String] = exp match {
case m: CreateMap
- if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) =>
+ if AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation).acceptsType(m.dataType) =>
val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData]
ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) =>
key.toString -> value.toString
@@ -77,7 +78,7 @@
columnNameOfCorruptRecord: String): Unit = {
schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
val f = schema(corruptFieldIndex)
- if (f.dataType != StringType || !f.nullable) {
+ if (!f.dataType.isInstanceOf[StringType] || !f.nullable) {
throw QueryCompilationErrors.invalidFieldTypeForCorruptRecordError()
}
}
@@ -110,7 +111,7 @@
*/
def checkJsonSchema(schema: DataType): TypeCheckResult = {
val isInvalid = schema.existsRecursively {
- case MapType(keyType, _, _) if keyType != StringType => true
+ case MapType(keyType, _, _) if !keyType.isInstanceOf[StringType] => true
case _ => false
}
if (isInvalid) {
@@ -133,7 +134,7 @@
def checkXmlSchema(schema: DataType): TypeCheckResult = {
val isInvalid = schema.existsRecursively {
// XML field names must be StringType
- case MapType(keyType, _, _) if keyType != StringType => true
+ case MapType(keyType, _, _) if !keyType.isInstanceOf[StringType] => true
case _ => false
}
if (isInvalid) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 8258bb3..7005d66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -875,7 +875,7 @@
child = child,
options = ExprUtils.convertToMapData(options))
- override def dataType: DataType = StringType
+ override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = false
@@ -921,7 +921,8 @@
.map(ArrayType(_, containsNull = at.containsNull))
.getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull))
case other: DataType =>
- jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse(StringType)
+ jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse(
+ SQLConf.get.defaultStringType)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 19f34ec..530a776 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -658,6 +658,40 @@
})
}
+ test("Support SchemaOfJson json expression with collation") {
+ case class SchemaOfJsonTestCase(
+ input: String,
+ collationName: String,
+ result: Row
+ )
+
+ val testCases = Seq(
+ SchemaOfJsonTestCase("'[{\"col\":0}]'",
+ "UTF8_BINARY", Row("ARRAY<STRUCT<col: BIGINT>>")),
+ SchemaOfJsonTestCase("'[{\"col\":01}]', map('allowNumericLeadingZeros', 'true')",
+ "UTF8_BINARY_LCASE", Row("ARRAY<STRUCT<col: BIGINT>>")),
+ SchemaOfJsonTestCase("'[]'",
+ "UNICODE", Row("ARRAY<STRING>")),
+ SchemaOfJsonTestCase("''",
+ "UNICODE_CI", Row("STRING"))
+ )
+
+ // Supported collations
+ testCases.foreach(t => {
+ val query =
+ s"""
+ |SELECT schema_of_json(${t.input})
+ |""".stripMargin
+ // Result & data type
+ withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
+ val testQuery = sql(query)
+ checkAnswer(testQuery, t.result)
+ val dataType = StringType(t.collationName)
+ assert(testQuery.schema.fields.head.dataType.sameType(dataType))
+ }
+ })
+ }
+
test("Support StringToMap expression with collation") {
// Supported collations
case class StringToMapTestCase[R](t: String, p: String, k: String, c: String, result: R)