/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SPARK_DOC_ROOT, SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils.ordinalNumber
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}

class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {

/**
* Runs the given testFunc once for each integral data type (byte, short, int and long).
*
* @param testFunc a test function that accepts a conversion function to convert an integer
* into another data type.
*/
private def testIntegralDataTypes(testFunc: (Int => Any) => Unit): Unit = {
testFunc(_.toByte)
testFunc(_.toShort)
testFunc(identity)
testFunc(_.toLong)
}
test("GetArrayItem") {
val typeA = ArrayType(StringType)
val array = Literal.create(Seq("a", "b"), typeA)
testIntegralDataTypes { convert =>
checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b")
}
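// A null array, a null ordinal, or both should evaluate to null.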
val nullArray = Literal.create(null, typeA)
val nullInt = Literal.create(null, IntegerType)
checkEvaluation(GetArrayItem(nullArray, Literal(1)), null)
checkEvaluation(GetArrayItem(array, nullInt), null)
checkEvaluation(GetArrayItem(nullArray, nullInt), null)
val nonNullArray = Literal.create(Seq(1), ArrayType(IntegerType, false))
checkEvaluation(GetArrayItem(nonNullArray, Literal(0)), 1)
val nestedArray = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType)))
checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
}
test("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val array = Literal.create(Seq("a", "b"), ArrayType(StringType))
if (ansiEnabled) {
checkExceptionInExpression[Exception](
GetArrayItem(array, Literal(5)),
"The index 5 is out of bounds. The array has 2 elements."
)
checkExceptionInExpression[Exception](
GetArrayItem(array, Literal(-1)),
"The index -1 is out of bounds. The array has 2 elements."
)
} else {
checkEvaluation(GetArrayItem(array, Literal(5)), null)
checkEvaluation(GetArrayItem(array, Literal(-1)), null)
}
}
}
}
test("SPARK-40066: GetMapValue returns null on invalid map value access") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
val map = Literal.create(Map(1 -> "a", 2 -> "b"), MapType(IntegerType, StringType))
checkEvaluation(GetMapValue(map, Literal(5)), null)
}
}
}
test("SPARK-26637 handles GetArrayItem nullability correctly when input array size is constant") {
// CreateArray case
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!GetArrayItem(array, Literal(0)).nullable)
assert(GetArrayItem(array, Literal(1)).nullable)
assert(!GetArrayItem(array, Subtract(Literal(2), Literal(2))).nullable)
assert(GetArrayItem(array, AttributeReference("ordinal", IntegerType)()).nullable)
// GetArrayStructFields case
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!GetArrayItem(stArray1, Literal(0)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(GetArrayItem(stArray2, Literal(0)).nullable)
val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!GetArrayItem(stArray3, Literal(0)).nullable)
assert(GetArrayItem(stArray3, Literal(1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(GetArrayItem(stArray4, Literal(0)).nullable)
assert(GetArrayItem(stArray4, Literal(1)).nullable)
}
test("GetMapValue") {
val typeM = MapType(StringType, StringType)
val map = Literal.create(Map("a" -> "b"), typeM)
val nullMap = Literal.create(null, typeM)
val nullString = Literal.create(null, StringType)
checkEvaluation(GetMapValue(map, Literal("a")), "b")
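// A null map, a null key, or both should evaluate to null.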
checkEvaluation(GetMapValue(map, nullString), null)
checkEvaluation(GetMapValue(nullMap, nullString), null)
checkEvaluation(GetMapValue(map, nullString), null)
val nonNullMap = Literal.create(Map("a" -> 1), MapType(StringType, IntegerType, false))
checkEvaluation(GetMapValue(nonNullMap, Literal("a")), 1)
val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), MapType(StringType, typeM))
checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c"))
}
test("GetStructField") {
val typeS = StructType(StructField("a", IntegerType) :: Nil)
val struct = Literal.create(create_row(1), typeS)
val nullStruct = Literal.create(null, typeS)
def getStructField(expr: Expression, fieldName: String): GetStructField = {
expr.dataType match {
case StructType(fields) =>
val index = fields.indexWhere(_.name == fieldName)
GetStructField(expr, index)
}
}
checkEvaluation(getStructField(struct, "a"), 1)
checkEvaluation(getStructField(nullStruct, "a"), null)
val nestedStruct = Literal.create(create_row(create_row(1)),
StructType(StructField("a", typeS) :: Nil))
checkEvaluation(getStructField(nestedStruct, "a"), create_row(1))
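// The result is nullable unless both the extracted field and the input struct are non-nullable.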
val typeS_fieldNotNullable = StructType(StructField("a", IntegerType, false) :: Nil)
val struct_fieldNotNullable = Literal.create(create_row(1), typeS_fieldNotNullable)
val nullStruct_fieldNotNullable = Literal.create(null, typeS_fieldNotNullable)
assert(getStructField(struct_fieldNotNullable, "a").nullable === false)
assert(getStructField(struct, "a").nullable)
assert(getStructField(nullStruct_fieldNotNullable, "a").nullable)
assert(getStructField(nullStruct, "a").nullable)
}
test("GetArrayStructFields") {
// test 4 types: struct field nullability X array element nullability
val type1 = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val type2 = ArrayType(StructType(StructField("a", IntegerType, nullable = false) :: Nil))
val type3 = ArrayType(StructType(StructField("a", IntegerType) :: Nil), containsNull = false)
val type4 = ArrayType(
StructType(StructField("a", IntegerType, nullable = false) :: Nil), containsNull = false)
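// input1..input4 cover: fully non-null data, a null struct field, a null array element, and a null array.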
val input1 = Literal.create(Seq(create_row(1)), type4)
val input2 = Literal.create(Seq(create_row(null)), type3)
val input3 = Literal.create(Seq(null), type2)
val input4 = Literal.create(null, type1)
def getArrayStructFields(expr: Expression, fieldName: String): Expression = {
ExtractValue.apply(expr, Literal.create(fieldName, StringType), _ == _)
}
checkEvaluation(getArrayStructFields(input1, "a"), Seq(1))
checkEvaluation(getArrayStructFields(input2, "a"), Seq(null))
checkEvaluation(getArrayStructFields(input3, "a"), Seq(null))
checkEvaluation(getArrayStructFields(input4, "a"), null)
}
test("SPARK-32167: nullability of GetArrayStructFields") {
val resolver = SQLConf.get.resolver
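// The extracted field's containsNull should be true if either the array elements or the struct field can be null.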
val array1 = ArrayType(
new StructType().add("a", "int", nullable = true),
containsNull = false)
val data1 = Literal.create(Seq(Row(null)), array1)
val get1 = ExtractValue(data1, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
assert(get1.containsNull)
val array2 = ArrayType(
new StructType().add("a", "int", nullable = false),
containsNull = true)
val data2 = Literal.create(Seq(null), array2)
val get2 = ExtractValue(data2, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
assert(get2.containsNull)
val array3 = ArrayType(
new StructType().add("a", "int", nullable = false),
containsNull = false)
val data3 = Literal.create(Seq(Row(1)), array3)
val get3 = ExtractValue(data3, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
assert(!get3.containsNull)
}
test("CreateArray") {
val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
val byteSeq = intSeq.map(_.toByte)
val strSeq = intSeq.map(_.toString)
checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow)
checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow)
checkEvaluation(CreateArray(byteSeq.map(Literal(_))), byteSeq, EmptyRow)
checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow)
val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType)
val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType)
val byteWithNull = byteSeq.map(Literal(_)) :+ Literal.create(null, ByteType)
val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType)
checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
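// For an array of arrays, the element type is the common type of the children, so containsNull is true
// here because the second child array allows null elements.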
val array = CreateArray(Seq(
Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)),
Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true))))
assert(array.dataType ===
ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false))
checkEvaluation(array, Seq(intSeq, intSeq :+ null))
}
test("CreateMap") {
def interlace(keys: Seq[Literal], values: Seq[Literal]): Seq[Literal] = {
keys.zip(values).flatMap { case (k, v) => Seq(k, v) }
}
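// interlace(keys, values) builds the alternating key1, value1, key2, value2, ... argument list that CreateMap expects.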
val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
val strSeq = intSeq.map(_.toString)
checkEvaluation(CreateMap(Nil), Map.empty)
checkEvaluation(
CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))),
create_map(intSeq, longSeq))
checkEvaluation(
CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))),
create_map(strSeq, longSeq))
checkEvaluation(
CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))),
create_map(longSeq, strSeq))
val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)
checkEvaluation(
CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)),
create_map(intSeq, strWithNull.map(_.value)))
// Map key can't be null
checkErrorInExpression[SparkRuntimeException](
CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
"NULL_MAP_KEY")
checkErrorInExpression[SparkRuntimeException](
CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))),
errorClass = "DUPLICATED_MAP_KEY",
parameters = Map(
"key" -> "1",
"mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"")
)
withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
// Duplicated map keys are removed according to the last-wins policy: the value of the last occurrence is kept.
checkEvaluation(
CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))),
create_map(1 -> 3))
}
// ArrayType map key and value
val map = CreateMap(Seq(
Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)),
Literal.create(strSeq, ArrayType(StringType, containsNull = false)),
Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)),
Literal.create(strSeq :+ null, ArrayType(StringType, containsNull = true))))
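// Key and value types are widened across the children; valueContainsNull is false because none of the
// value expressions can evaluate to null.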
assert(map.dataType ===
MapType(
ArrayType(IntegerType, containsNull = true),
ArrayType(StringType, containsNull = true),
valueContainsNull = false))
checkEvaluation(map, create_map(intSeq -> strSeq, (intSeq :+ null) -> (strSeq :+ null)))
// map key can't be map
val map2 = CreateMap(Seq(
Literal.create(create_map(1 -> 1), MapType(IntegerType, IntegerType)),
Literal(1)
))
map2.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key")
case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) =>
assert(errorSubClass == "INVALID_MAP_KEY_TYPE")
assert(messageParameters === Map("keyType" -> "\"MAP<INT, INT>\""))
}
// expects a positive even number of arguments
val map3 = CreateMap(Seq(Literal(1), Literal(2), Literal(3)))
checkError(
exception = intercept[AnalysisException] {
map3.checkInputDataTypes()
},
errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
parameters = Map(
"functionName" -> "`map`",
"expectedNum" -> "2n (n > 0)",
"actualNum" -> "3",
"docroot" -> SPARK_DOC_ROOT)
)
// The given keys of function map should all be the same type
val map4 = CreateMap(Seq(Literal(1), Literal(2), Literal('a'), Literal(3)))
assert(map4.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "CREATE_MAP_KEY_DIFF_TYPES",
messageParameters = Map(
"functionName" -> "`map`",
"dataType" -> "[\"INT\", \"STRING\"]")
)
)
// The given values of function map should all be the same type
val map5 = CreateMap(Seq(Literal(1), Literal(2), Literal(3), Literal('a')))
assert(map5.checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "CREATE_MAP_VALUE_DIFF_TYPES",
messageParameters = Map(
"functionName" -> "`map`",
"dataType" -> "[\"INT\", \"STRING\"]")
)
)
// map key can't be variant
val map6 = CreateMap(Seq(
Literal.create(new VariantVal(Array[Byte](), Array[Byte]())),
Literal.create(1)
))
map6.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as a part of map key")
case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) =>
assert(errorSubClass == "INVALID_MAP_KEY_TYPE")
assert(messageParameters === Map("keyType" -> "\"VARIANT\""))
}
// map key can't contain variant
val map7 = CreateMap(
Seq(
CreateStruct(
Seq(Literal.create(1), Literal.create(new VariantVal(Array[Byte](), Array[Byte]())))
),
Literal.create(1)
)
)
map7.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as a part of map key")
case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) =>
assert(errorSubClass == "INVALID_MAP_KEY_TYPE")
assert(
messageParameters === Map(
"keyType" -> "\"STRUCT<col1: INT NOT NULL, col2: VARIANT NOT NULL>\""
)
)
}
}
test("MapFromArrays") {
val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
val strSeq = intSeq.map(_.toString)
val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25)
val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25)
val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_))
val intArray = Literal.create(intSeq, ArrayType(IntegerType, false))
val longArray = Literal.create(longSeq, ArrayType(LongType, false))
val strArray = Literal.create(strSeq, ArrayType(StringType, false))
val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true))
val intWithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true))
val longWithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true))
val nullArray = Literal.create(null, ArrayType(StringType, false))
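// MapFromArrays pairs keys and values positionally: the i-th key maps to the i-th value.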
checkEvaluation(MapFromArrays(intArray, longArray), create_map(intSeq, longSeq))
checkEvaluation(MapFromArrays(intArray, strArray), create_map(intSeq, strSeq))
checkEvaluation(MapFromArrays(integerArray, strArray), create_map(integerSeq, strSeq))
checkEvaluation(
MapFromArrays(strArray, intWithNullArray), create_map(strSeq, intWithNullSeq))
checkEvaluation(
MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq))
checkEvaluation(
MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq))
checkEvaluation(MapFromArrays(nullArray, nullArray), null)
// Map key can't be null
checkErrorInExpression[SparkRuntimeException](
MapFromArrays(intWithNullArray, strArray),
"NULL_MAP_KEY")
checkErrorInExpression[SparkRuntimeException](
MapFromArrays(
Literal.create(Seq(1, 1), ArrayType(IntegerType)),
Literal.create(Seq(2, 3), ArrayType(IntegerType))),
errorClass = "DUPLICATED_MAP_KEY",
parameters = Map(
"key" -> "1",
"mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"")
)
withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
// Duplicated map keys are removed according to the last-wins policy: the value of the last occurrence is kept.
checkEvaluation(
MapFromArrays(
Literal.create(Seq(1, 1), ArrayType(IntegerType)),
Literal.create(Seq(2, 3), ArrayType(IntegerType))),
create_map(1 -> 3))
}
// map key can't be map
val arrayOfMap = Seq(create_map(1 -> "a", 2 -> "b"))
val map = MapFromArrays(
Literal.create(arrayOfMap, ArrayType(MapType(IntegerType, StringType))),
Literal.create(Seq(1), ArrayType(IntegerType)))
map.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key")
case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) =>
assert(errorSubClass == "INVALID_MAP_KEY_TYPE")
assert(messageParameters === Map("keyType" -> "\"MAP<INT, STRING>\""))
}
}
test("CreateStruct") {
val row = create_row(1, 2, 3)
val c1 = $"a".int.at(0)
val c3 = $"c".int.at(2)
checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
checkEvaluation(CreateStruct(Literal.create(null, LongType) :: Nil), create_row(null))
}
test("CreateNamedStruct") {
val row = create_row(1, 2, 3)
val c1 = $"a".int.at(0)
val c3 = $"c".int.at(2)
checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), create_row(1, 3), row)
checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")),
create_row(1, UTF8String.fromString("y")), row)
checkEvaluation(CreateNamedStruct(Seq("a", "x", "b", 2.0)),
create_row(UTF8String.fromString("x"), 2.0))
checkEvaluation(CreateNamedStruct(Seq("a", Literal.create(null, IntegerType))),
create_row(null))
// expects a positive even number of arguments
val namedStruct1 = CreateNamedStruct(Seq(Literal(1), Literal(2), Literal(3)))
checkError(
exception = intercept[AnalysisException] {
namedStruct1.checkInputDataTypes()
},
errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
parameters = Map(
"functionName" -> "`named_struct`",
"expectedNum" -> "2n (n > 0)",
"actualNum" -> "3",
"docroot" -> SPARK_DOC_ROOT)
)
}
test("test dsl for complex type") {
def quickResolve(u: UnresolvedExtractValue): Expression = {
ExtractValue(u.child, u.extraction, _ == _)
}
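// quickResolve resolves the extraction eagerly with an exact-match resolver, mimicking what the analyzer
// does for UnresolvedExtractValue.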
checkEvaluation(quickResolve(Symbol("c")
.map(MapType(StringType, StringType)).at(0).getItem("a")),
"b", create_row(Map("a" -> "b")))
checkEvaluation(quickResolve($"c".array(StringType).at(0).getItem(1)),
"b", create_row(Seq("a", "b")))
checkEvaluation(quickResolve($"c".struct($"a".int).at(0).getField("a")),
1, create_row(create_row(1)))
}
test("ensure to preserve metadata") {
val metadata = new MetadataBuilder()
.putString("key", "value")
.build()
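// Metadata attached to an input attribute should carry over to the corresponding struct field, while
// fields built from attributes without metadata stay empty.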
def checkMetadata(expr: Expression): Unit = {
assert(expr.dataType.asInstanceOf[StructType]("a").metadata === metadata)
assert(expr.dataType.asInstanceOf[StructType]("b").metadata === Metadata.empty)
}
val a = AttributeReference("a", IntegerType, metadata = metadata)()
val b = AttributeReference("b", IntegerType)()
checkMetadata(CreateStruct(Seq(a, b)))
checkMetadata(CreateNamedStruct(Seq("a", a, "b", b)))
}
test("StringToMap") {
val expectedDataType = MapType(StringType, StringType, valueContainsNull = true)
assert(new StringToMap("").dataType === expectedDataType)
val s0 = Literal("a:1,b:2,c:3")
val m0 = Map("a" -> "1", "b" -> "2", "c" -> "3")
checkEvaluation(new StringToMap(s0), m0)
val s1 = Literal("a: ,b:2")
val m1 = Map("a" -> " ", "b" -> "2")
checkEvaluation(new StringToMap(s1), m1)
val s2 = Literal("a=1,b=2,c=3")
val m2 = Map("a" -> "1", "b" -> "2", "c" -> "3")
checkEvaluation(StringToMap(s2, Literal(","), Literal("=")), m2)
val s3 = Literal("")
val m3 = Map[String, String]("" -> null)
checkEvaluation(StringToMap(s3, Literal(","), Literal("=")), m3)
val s4 = Literal("a:1_b:2_c:3")
val m4 = Map("a" -> "1", "b" -> "2", "c" -> "3")
checkEvaluation(new StringToMap(s4, Literal("_")), m4)
val s5 = Literal("a")
val m5 = Map("a" -> null)
checkEvaluation(new StringToMap(s5), m5)
val s6 = Literal("a=1&b=2&c=3")
val m6 = Map("a" -> "1", "b" -> "2", "c" -> "3")
checkEvaluation(StringToMap(s6, NonFoldableLiteral("&"), NonFoldableLiteral("=")), m6)
checkErrorInExpression[SparkRuntimeException](
new StringToMap(Literal("a:1,b:2,a:3")),
errorClass = "DUPLICATED_MAP_KEY",
parameters = Map(
"key" -> "a",
"mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"")
)
withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
// Duplicated map keys are removed according to the last-wins policy: the value of the last occurrence is kept.
checkEvaluation(
new StringToMap(Literal("a:1,b:2,a:3")),
create_map("a" -> "3", "b" -> "2"))
}
// argument type checking
assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess)
assert(new StringToMap(Literal(null)).checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(0),
"requiredType" -> "\"STRING\"",
"inputSql" -> "\"NULL\"",
"inputType" -> "\"VOID\""
)
)
)
assert(new StringToMap(Literal("a:1,b:2,c:3"), Literal(null)).checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(1),
"requiredType" -> "\"STRING\"",
"inputSql" -> "\"NULL\"",
"inputType" -> "\"VOID\""
)
)
)
assert(StringToMap(Literal("a:1,b:2,c:3"), Literal(null),
Literal(null)).checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(1),
"requiredType" -> "\"STRING\"",
"inputSql" -> "\"NULL\"",
"inputType" -> "\"VOID\""
)
)
)
assert(new StringToMap(Literal(null), Literal(null)).checkInputDataTypes() ==
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(0),
"requiredType" -> "\"STRING\"",
"inputSql" -> "\"NULL\"",
"inputType" -> "\"VOID\""
)
)
)
}
test("SPARK-22693: CreateNamedStruct should not use global variables") {
val ctx = new CodegenContext
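// Generating code for the struct should not register any global mutable state in the codegen context.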
CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}
test("SPARK-33338: semanticEquals should handle static GetMapValue correctly") {
val keys = new Array[UTF8String](1)
val values = new Array[UTF8String](1)
keys(0) = UTF8String.fromString("key")
values(0) = UTF8String.fromString("value")
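// GetMapValue expressions built over two structurally equal map literals should be semantically equal.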
val d1 = new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
val d2 = new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
val m1 = GetMapValue(Literal.create(d1, MapType(StringType, StringType)), Literal("a"))
val m2 = GetMapValue(Literal.create(d2, MapType(StringType, StringType)), Literal("a"))
assert(m1.semanticEquals(m2))
}
test("SPARK-40315: Literals of ArrayBasedMapData should have deterministic hashCode.") {
val keys = new Array[UTF8String](1)
val values1 = new Array[UTF8String](1)
val values2 = new Array[UTF8String](1)
keys(0) = UTF8String.fromString("key")
values1(0) = UTF8String.fromString("value1")
values2(0) = UTF8String.fromString("value2")
val d1 = new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values1))
val d2 = new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values1))
val d3 = new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values2))
val m1 = Literal.create(d1, MapType(StringType, StringType))
val m2 = Literal.create(d2, MapType(StringType, StringType))
val m3 = Literal.create(d3, MapType(StringType, StringType))
// If two Literals of ArrayBasedMapData have the same elements, we expect them to be equal and
// to have the same hashCode().
assert(m1 == m2)
assert(m1.hashCode() == m2.hashCode())
// If two Literals of ArrayBasedMapData have different elements, we expect them not to be equal
// and to have different hashCode().
assert(m1 != m3)
assert(m1.hashCode() != m3.hashCode())
}
}