[SPARK-48224][SQL] Disallow map keys from being of variant type
### What changes were proposed in this pull request?
This PR disallows map keys from being of variant type. Therefore, SQL statements like `select map(parse_json('{"a": 1}'), 1)`, which previously worked, will now throw an exception.
### Why are the changes needed?
Allowing variant as the key type of a map can result in undefined behavior, since this usage has never been tested or officially supported.
### Does this PR introduce _any_ user-facing change?
Yes. Previously, users could create maps with variant keys; this PR disallows that usage, so such statements will now fail with a type-check error.
### How was this patch tested?
Unit tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #46516 from harshmotw-db/map_variant_key.
Authored-by: Harsh Motwani <harsh.motwani@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index d2c708b..a0d578c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -58,7 +58,7 @@
}
def checkForMapKeyType(keyType: DataType): TypeCheckResult = {
- if (keyType.existsRecursively(_.isInstanceOf[MapType])) {
+ if (keyType.existsRecursively(dt => dt.isInstanceOf[MapType] || dt.isInstanceOf[VariantType])) {
DataTypeMismatch(
errorSubClass = "INVALID_MAP_KEY_TYPE",
messageParameters = Map(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 5f135e4..497b335 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -28,7 +28,7 @@
import org.apache.spark.sql.catalyst.util.TypeUtils.ordinalNumber
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -359,6 +359,38 @@
)
}
+ // map key can't be variant
+ val map6 = CreateMap(Seq(
+ Literal.create(new VariantVal(Array[Byte](), Array[Byte]())),
+ Literal.create(1)
+ ))
+ map6.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as a part of map key")
+ case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) =>
+ assert(errorSubClass == "INVALID_MAP_KEY_TYPE")
+ assert(messageParameters === Map("keyType" -> "\"VARIANT\""))
+ }
+
+ // map key can't contain variant
+ val map7 = CreateMap(
+ Seq(
+ CreateStruct(
+ Seq(Literal.create(1), Literal.create(new VariantVal(Array[Byte](), Array[Byte]())))
+ ),
+ Literal.create(1)
+ )
+ )
+ map7.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as a part of map key")
+ case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) =>
+ assert(errorSubClass == "INVALID_MAP_KEY_TYPE")
+ assert(
+ messageParameters === Map(
+ "keyType" -> "\"STRUCT<col1: INT NOT NULL, col2: VARIANT NOT NULL>\""
+ )
+ )
+ }
+
test("MapFromArrays") {
val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)