blob: 73abf8074e8c2ac4c4b3e365a1d8eb27229a4238 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.variant
import java.time.{LocalDateTime, ZoneId, ZoneOffset}
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.types.variant.VariantUtil._
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
  // Zero-extend each byte in the array with the appropriate number of bytes.
  // Used to manually construct variant binary values with a given offset size.
  // E.g. padded(Array(1,2,3), 3) will produce Array(1,0,0,2,0,0,3,0,0).
  private def padded(a: Array[Byte], size: Int): Array[Byte] = {
    require(size >= 1, s"size must be at least 1, got $size")
    // Every byte gets the same trailing zeros, so build the padding once.
    val padding = List.fill(size - 1)(0.toByte)
    a.flatMap(b => b :: padding)
  }

  // Evaluates `to_json` on a variant built from the raw `value` and `metadata` bytes and checks
  // that the result equals `expectedJson`. Shared by the tests that hand-construct valid binaries.
  private def checkVariantToJson(
      expectedJson: String,
      value: Array[Byte],
      metadata: Array[Byte]): Unit = {
    checkEvaluation(
      StructsToJson(Map.empty, Literal(new VariantVal(value, metadata))),
      expectedJson
    )
  }

  test("to_json malformed") {
    def check(value: Array[Byte], metadata: Array[Byte],
        errorClass: String = "MALFORMED_VARIANT"): Unit = {
      checkErrorInExpression[SparkRuntimeException](
        ResolveTimeZone.resolveTimeZones(
          StructsToJson(Map.empty, Literal(new VariantVal(value, metadata)))),
        errorClass
      )
    }
    val emptyMetadata = Array[Byte](VERSION, 0, 0)
    // INT8 only has 7 byte content.
    check(Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0), emptyMetadata)
    // DECIMAL16 only has 15 byte content.
    check(Array(primitiveHeader(DECIMAL16)) ++ Array.fill(16)(0.toByte), emptyMetadata)
    // 1e38 has a precision of 39. Even if it still fits into 16 bytes, it is not a valid decimal.
    check(Array[Byte](primitiveHeader(DECIMAL16), 0) ++
      BigDecimal(1e38).toBigInt.toByteArray.reverse, emptyMetadata)
    // Short string content too short.
    check(Array(shortStrHeader(2), 'x'), emptyMetadata)
    // Long string length too short (requires 4 bytes).
    check(Array(primitiveHeader(LONG_STR), 0, 0, 0), emptyMetadata)
    // Long string content too short.
    check(Array(primitiveHeader(LONG_STR), 1, 0, 0, 0), emptyMetadata)
    // Size is 1 but no content.
    check(Array(arrayHeader(false, 1),
      /* size */ 1,
      /* offset list */ 0), emptyMetadata)
    // Requires a 4-byte size, but the actual size only has one byte.
    check(Array(arrayHeader(true, 1),
      /* size */ 0,
      /* offset list */ 0), emptyMetadata)
    // Offset out of bound.
    check(Array(arrayHeader(false, 1),
      /* size */ 1,
      /* offset list */ 1, 1), emptyMetadata)
    // Id out of bound.
    check(Array(objectHeader(false, 1, 1),
      /* size */ 1,
      /* id list */ 0,
      /* offset list */ 0, 2,
      /* field data */ primitiveHeader(INT1), 1), emptyMetadata)
    // Variant version is not 1.
    check(Array(primitiveHeader(INT1), 0), Array[Byte](3, 0, 0))
    check(Array(primitiveHeader(INT1), 0), Array[Byte](2, 0, 0))
    // Construct binary values that are over 1 << 24 bytes, but otherwise valid.
    val bigVersion = Array[Byte]((VERSION | (3 << 6)).toByte)
    val a = Array.fill(1 << 24)('a'.toByte)
    val hugeMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1) ++
      a ++ Array[Byte]('b')
    check(Array(primitiveHeader(TRUE)), hugeMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT")
    // The keys are 'aaa....' and 'b'. Values are "yyy..." and 'true'.
    val y = Array.fill(1 << 24)('y'.toByte)
    val hugeObject = Array[Byte](objectHeader(true, 4, 4)) ++
      /* size */ padded(Array(2), 4) ++
      /* id list */ padded(Array(0, 1), 4) ++
      // Second value starts at offset 5 + (1 << 24), which is `5001` little-endian. The last value
      // is 1 byte, so the one-past-the-end value is `6001`.
      /* offset list */ Array[Byte](0, 0, 0, 0, 5, 0, 0, 1, 6, 0, 0, 1) ++
      /* field data */ Array[Byte](primitiveHeader(LONG_STR), 0, 0, 0, 1) ++ y ++ Array[Byte](
        primitiveHeader(TRUE)
      )
    val smallMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0) ++
      Array[Byte]('a', 'b')
    check(hugeObject, smallMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT")
    check(hugeObject, hugeMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT")
  }

  // Test valid forms of Variant that our writer would never produce.
  test("to_json valid input") {
    // Some valid metadata formats. Check that they aren't rejected.
    // Sorted string bit is set, and can be ignored.
    val emptyMetadata2 = Array[Byte](VERSION | 1 << 4, 0, 0)
    // Bit 5 is not defined in the spec, and can be ignored.
    val emptyMetadata3 = Array[Byte](VERSION | 1 << 5, 0, 0)
    // Can specify 3 bytes per size/offset, even if they aren't needed.
    val header = (VERSION | (2 << 6)).toByte
    val emptyMetadata4 = Array[Byte](header, 0, 0, 0, 0, 0, 0)
    checkVariantToJson("true", Array(primitiveHeader(TRUE)), emptyMetadata2)
    checkVariantToJson("true", Array(primitiveHeader(TRUE)), emptyMetadata3)
    checkVariantToJson("true", Array(primitiveHeader(TRUE)), emptyMetadata4)
  }

  // Test StructsToJson with manually constructed input that uses up to 4 bytes for offsets and
  // sizes. We never produce 4-byte offsets, since they're only needed for >16 MiB values, which we
  // error out on, but the reader should be able to handle them if some other writer decides to use
  // them for smaller values.
  test("to_json with large offsets and sizes") {
    for {
      offsetSize <- 1 to 4
      idSize <- 1 to 4
      metadataSize <- 1 to 4
      largeSize <- Seq(false, true)
    } {
      // Test array
      val version = Array[Byte]((VERSION | ((metadataSize - 1) << 6)).toByte)
      val emptyMetadata = version ++ padded(Array(0, 0), metadataSize)
      // Construct a binary with the given sizes. Regardless, to_json should produce the same
      // result.
      val arrayValue = Array[Byte](arrayHeader(largeSize, offsetSize)) ++
        /* size */ padded(Array(3), if (largeSize) 4 else 1) ++
        /* offset list */ padded(Array(0, 1, 4, 5), offsetSize) ++
        Array[Byte](/* values */ primitiveHeader(FALSE),
          primitiveHeader(INT2), 2, 1, primitiveHeader(NULL))
      checkVariantToJson("[false,258,null]", arrayValue, emptyMetadata)
      // Test object
      val metadata = version ++
        padded(Array(3, 0, 1, 2, 3), metadataSize) ++
        Array[Byte]('a', 'b', 'c')
      val objectValue = Array[Byte](objectHeader(largeSize, idSize, offsetSize)) ++
        /* size */ padded(Array(3), if (largeSize) 4 else 1) ++
        /* id list */ padded(Array(0, 1, 2), idSize) ++
        /* offset list */ padded(Array(0, 2, 4, 6), offsetSize) ++
        /* field data */ Array[Byte](primitiveHeader(INT1), 1,
          primitiveHeader(INT1), 2, shortStrHeader(1), '3')
      checkVariantToJson("""{"a":1,"b":2,"c":"3"}""", objectValue, metadata)
    }
  }

  test("to_json large binary") {
    // Create a binary that uses the max 1 << 24 bytes for both metadata and value.
    val bigVersion = Array[Byte]((VERSION | (2 << 6)).toByte)
    // Create a single huge value, followed by a one-byte string. We'll have 1 header byte, plus 12
    // bytes for size and offsets, plus 1 byte for the final value, so the large value is 1 << 24 -
    // 14 bytes, or (-14, -1, -1) as a signed little-endian value.
    val aSize = (1 << 24) - 14
    val a = Array.fill(aSize)('a'.toByte)
    val hugeMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, -14, -1, -1, -13, -1, -1) ++
      a ++ Array[Byte]('b')
    // Validate metadata in isolation.
    checkVariantToJson("true", Array(primitiveHeader(TRUE)), hugeMetadata)
    // The object will contain a large string, and the following bytes:
    // - object header and size: 1+4 bytes
    // - ID list: 6 bytes
    // - offset list: 9 bytes
    // - field headers and string length: 6 bytes
    // In order to get the full binary to 1 << 24, the large string is (1 << 24) - 26 bytes. As a
    // signed little-endian value, this is (-26, -1, -1).
    val ySize = (1 << 24) - 26
    val y = Array.fill(ySize)('y'.toByte)
    val hugeObject = Array[Byte](objectHeader(true, 3, 3)) ++
      /* size */ padded(Array(2), 4) ++
      /* id list */ padded(Array(0, 1), 3) ++
      // Second offset is (-26,-1,-1), plus 5 bytes for string header, so (-21,-1,-1)
      /* offset list */ Array[Byte](0, 0, 0, -21, -1, -1, -20, -1, -1) ++
      /* field data */ Array[Byte](primitiveHeader(LONG_STR), -26, -1, -1, 0) ++ y ++ Array[Byte](
        primitiveHeader(TRUE)
      )
    // Same as hugeObject, but with a short string.
    val smallObject = Array[Byte](objectHeader(false, 1, 1)) ++
      /* size */ Array[Byte](2) ++
      /* id list */ Array[Byte](0, 1) ++
      /* offset list */ Array[Byte](0, 6, 7) ++
      /* field data */ Array[Byte](primitiveHeader(LONG_STR), 1, 0, 0, 0, 'y',
        primitiveHeader(TRUE))
    val smallMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0) ++
      Array[Byte]('a', 'b')
    // Check all combinations of large/small value and metadata.
    val expectedResult1 =
      s"""{"${a.map(_.toChar).mkString}":"${y.map(_.toChar).mkString}","b":true}"""
    checkVariantToJson(expectedResult1, hugeObject, hugeMetadata)
    val expectedResult2 =
      s"""{"${a.map(_.toChar).mkString}":"y","b":true}"""
    checkVariantToJson(expectedResult2, smallObject, hugeMetadata)
    val expectedResult3 =
      s"""{"a":"${y.map(_.toChar).mkString}","b":true}"""
    checkVariantToJson(expectedResult3, hugeObject, smallMetadata)
    val expectedResult4 =
      s"""{"a":"y","b":true}"""
    checkVariantToJson(expectedResult4, smallObject, smallMetadata)
  }

  test("is_variant_null invalid input") {
    checkErrorInExpression[SparkRuntimeException](
      IsVariantNull(Literal(new VariantVal(Array(), Array(1, 2, 3)))),
      "MALFORMED_VARIANT"
    )
  }

  private def parseJson(input: String): VariantVal =
    VariantExpressionEvalUtils.parseJson(UTF8String.fromString(input))

  private def variantGet(input: String, path: String, dataType: DataType): VariantGet =
    VariantGet(Literal(parseJson(input)), Literal(path), dataType, failOnError = true)

  private def tryVariantGet(input: String, path: String, dataType: DataType): VariantGet =
    VariantGet(Literal(parseJson(input)), Literal(path), dataType, failOnError = false)

  // Checks that `variant_get`, `variant_get` applied to a variant-typed extraction of the same
  // path, and `try_variant_get` all evaluate to `output`.
  private def testVariantGet(input: String, path: String, dataType: DataType, output: Any): Unit = {
    checkEvaluation(variantGet(input, path, dataType), output)
    checkEvaluation(
      VariantGet(variantGet(input, path, VariantType), Literal("$"), dataType, failOnError = true),
      output
    )
    checkEvaluation(tryVariantGet(input, path, dataType), output)
  }

  // If an individual element cannot be cast to the target type, `variant_get` will return an error
  // and `try_variant_get` will only set that element to be null.
  // When `parameters` is not given, the expected error parameters are the whole `input` and
  // `dataType`.
  private def testInvalidVariantGet(
      input: String,
      path: String,
      dataType: DataType,
      parameters: Map[String, String] = null,
      tryOutput: Any = null): Unit = {
    checkErrorInExpression[SparkRuntimeException](
      variantGet(input, path, dataType),
      "INVALID_VARIANT_CAST",
      Option(parameters).getOrElse(
        Map("value" -> input, "dataType" -> ("\"" + dataType.sql + "\"")))
    )
    checkEvaluation(tryVariantGet(input, path, dataType), tryOutput)
  }

  test("variant_get cast") {
    // Source type is string.
    testVariantGet("\"true\"", "$", BooleanType, true)
    testVariantGet("\"false\"", "$", BooleanType, false)
    testVariantGet("\" t \"", "$", BooleanType, true)
    testInvalidVariantGet("\"true\"", "$", IntegerType)
    testVariantGet("\"1\"", "$", IntegerType, 1)
    testVariantGet("\"9223372036854775807\"", "$", LongType, 9223372036854775807L)
    testVariantGet("\"-0.0\"", "$", DoubleType, -0.0)
    testVariantGet("\"inf\"", "$", DoubleType, Double.PositiveInfinity)
    testVariantGet("\"-inf\"", "$", DoubleType, Double.NegativeInfinity)
    testVariantGet("\"nan\"", "$", DoubleType, Double.NaN)
    testVariantGet("\"12.34\"", "$", FloatType, 12.34f)
    testVariantGet("\"12.34\"", "$", DecimalType(9, 4), Decimal(12.34))
    testVariantGet("\"1970-01-01\"", "$", DateType, 0)
    testVariantGet("\"1970-03-01\"", "$", DateType, 59)
    // Source type is boolean.
    testVariantGet("true", "$", BooleanType, true)
    testVariantGet("false", "$", BooleanType, false)
    testVariantGet("true", "$", ByteType, 1.toByte)
    testVariantGet("true", "$", DoubleType, 1.0)
    testVariantGet("true", "$", DecimalType(18, 17), Decimal(1))
    testInvalidVariantGet("true", "$", DecimalType(18, 18))
    testVariantGet("false", "$", DecimalType(18, 18), Decimal(0))
    // Source type is integer.
    testVariantGet("1", "$", BooleanType, true)
    testVariantGet("0", "$", BooleanType, false)
    testInvalidVariantGet("1", "$", BinaryType)
    testVariantGet("127", "$", ByteType, 127.toByte)
    testInvalidVariantGet("128", "$", ByteType)
    testVariantGet("-32768", "$", ShortType, (-32768).toShort)
    testInvalidVariantGet("-32769", "$", ShortType)
    testVariantGet("2147483647", "$", IntegerType, 2147483647)
    testInvalidVariantGet("2147483648", "$", IntegerType)
    testVariantGet("9223372036854775807", "$", LongType, 9223372036854775807L)
    testVariantGet("-9223372036854775808", "$", LongType, -9223372036854775808L)
    testVariantGet("2147483647", "$", FloatType, 2147483647.0f)
    testVariantGet("2147483647", "$", DoubleType, 2147483647.0d)
    testVariantGet("1", "$", DecimalType(9, 4), Decimal(1))
    testVariantGet("99999999", "$", DecimalType(38, 30), Decimal(99999999))
    testInvalidVariantGet("100000000", "$", DecimalType(38, 30))
    testInvalidVariantGet("12345", "$", DecimalType(6, 3))
    testVariantGet("-1", "$", TimestampType, -1000000L)
    testVariantGet("9223372036854", "$", TimestampType, 9223372036854000000L)
    testInvalidVariantGet("9223372036855", "$", TimestampType)
    testInvalidVariantGet("0", "$", TimestampNTZType)
    // Source type is double. Always use scientific notation to avoid decimal.
    testVariantGet("1E0", "$", BooleanType, true)
    testVariantGet("0E0", "$", BooleanType, false)
    testVariantGet("-0E0", "$", BooleanType, false)
    testVariantGet("127E0", "$", ByteType, 127.toByte)
    testInvalidVariantGet(
      "128E0",
      "$",
      ByteType,
      Map("value" -> "128.0", "dataType" -> "\"TINYINT\"")
    )
    testVariantGet("-9.223372036854776E18", "$", LongType, Long.MinValue)
    testInvalidVariantGet("-9.223372036854778E18", "$", LongType)
    testVariantGet("1E308", "$", FloatType, Float.PositiveInfinity)
    testVariantGet("12345E-4", "$", DecimalType(5, 2), Decimal(1.23))
    testVariantGet("9999999999E-2", "$", DecimalType(38, 30), Decimal(99999999.99))
    testInvalidVariantGet(
      "100000000E0",
      "$",
      DecimalType(38, 30),
      Map("value" -> "1.0E8", "dataType" -> "\"DECIMAL(38,30)\"")
    )
    testVariantGet("9223372036854.5E0", "$", TimestampType, 9223372036854500352L)
    testInvalidVariantGet(
      "9223372036855E0",
      "$",
      TimestampType,
      Map("value" -> "9.223372036855E12", "dataType" -> "\"TIMESTAMP\"")
    )
    // Source type is decimal.
    testVariantGet("1.0", "$", BooleanType, true)
    testVariantGet("0.0", "$", BooleanType, false)
    testVariantGet("-0.0", "$", BooleanType, false)
    testVariantGet("2147483647.999", "$", IntegerType, 2147483647)
    testInvalidVariantGet("9223372036854775808", "$", LongType)
    testVariantGet("-9223372036854775808.0", "$", LongType, -9223372036854775808L)
    testVariantGet("123.0", "$", DecimalType(6, 3), Decimal(123000, 6, 3))
    testVariantGet("1.14", "$", DecimalType(2, 1), Decimal(11, 2, 1))
    testVariantGet("1.15", "$", DecimalType(2, 1), Decimal(12, 2, 1))
    testVariantGet(
      "0.0000000009999999994",
      "$",
      DecimalType(18, 18),
      Decimal("0.000000000999999999")
    )
    testVariantGet("0.0000000009999999995", "$", DecimalType(18, 18), Decimal("0.000000001"))
    testInvalidVariantGet("9.5", "$", DecimalType(1, 0))
    testVariantGet("9999999999999999999.9999999999999999999", "$", FloatType, 1e19f)
    testVariantGet("9999999999999999999.9999999999999999999", "$", DoubleType, 1e19)
    testVariantGet(
      "9999999999999999999.9999999999999999999",
      "$",
      StringType,
      "9999999999999999999.9999999999999999999"
    )
    // Input doesn't fit into decimal, use double instead, which causes a loss of precision.
    testVariantGet("9999999999999999999.99999999999999999999", "$", StringType, "1.0E19")
    // Input fits into `decimal(38, 38)`.
    testVariantGet(
      "0.99999999999999999999999999999999999999",
      "$",
      DecimalType(38, 38),
      Decimal("0.99999999999999999999999999999999999999")
    )
    testVariantGet("1.10", "$", StringType, "1.1")
    testVariantGet("-1.00", "$", StringType, "-1")
    // Test Decimal(N, 0).
    testVariantGet("-100000000000000000000", "$", StringType, "-100000000000000000000")
    testVariantGet(
      "99999999999999999999000000000000000000",
      "$",
      StringType,
      "99999999999999999999000000000000000000"
    )
    // Source type is null.
    testVariantGet("null", "$", BooleanType, null)
    testVariantGet("null", "$", IntegerType, null)
    testVariantGet("null", "$", DoubleType, null)
    testVariantGet("null", "$", DecimalType(18, 9), null)
    testVariantGet("null", "$", TimestampType, null)
    testVariantGet("null", "$", DateType, null)
  }

  test("variant_get path extraction") {
    // Test case adapted from `JsonExpressionsSuite`.
    val json =
      """
        |{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],
        |"basket":[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]],"book":[{"author":"Nigel Rees",
        |"title":"Sayings of the Century","category":"reference","price":8.95},
        |{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,
        |"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings",
        |"category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],
        |"price":22.99,"isbn":"0-395-19395-8"}],"bicycle":{"price":19.95,"color":"red"}},
        |"email":"amy@only_for_json_udf_test.net","owner":"amy","zip code":"94025",
        |"fb:testid":"1234"}
        |""".stripMargin
    testVariantGet(json, "$.store.bicycle", StringType, """{"color":"red","price":19.95}""")
    checkEvaluation(
      VariantGet(
        tryVariantGet(json, "$.store.bicycle", VariantType),
        Literal("$"),
        StringType,
        failOnError = true
      ),
      """{"color":"red","price":19.95}"""
    )
    testVariantGet(json, "$.store.bicycle.color", StringType, "red")
    testVariantGet(json, "$.store.bicycle.price", DoubleType, 19.95)
    testVariantGet(
      json,
      "$.store.book",
      StringType,
      """[{"author":"Nigel Rees","category":"reference","price":8.95,"title":
        |"Sayings of the Century"},{"author":"Herman Melville","category":"fiction","isbn":
        |"0-553-21311-3","price":8.99,"title":"Moby Dick"},{"author":"J. R. R. Tolkien","category":
        |"fiction","isbn":"0-395-19395-8","price":22.99,"reader":[{"age":25,"name":"bob"},{"age":26,
        |"name":"jack"}],"title":"The Lord of the Rings"}]""".stripMargin.replace("\n", "")
    )
    testVariantGet(
      json,
      "$.store.book[0]",
      StringType,
      """{"author":"Nigel Rees","category":"reference","price":8.95,"title":
        |"Sayings of the Century"}""".stripMargin.replace("\n", "")
    )
    testVariantGet(json, "$.store.book[0].category", StringType, "reference")
    testVariantGet(json, "$.store.book[1].price", DoubleType, 8.99)
    testVariantGet(json, "$.store.book[2].reader[0].name", StringType, "bob")
    testVariantGet(json, "$.store.book[2].reader[1].age", IntegerType, 26)
    testVariantGet(json, "$.store.basket[0][1]", IntegerType, 2)
    testVariantGet(json, "$.store.basket[0][2]", StringType, """{"a":"x","b":"y"}""")
    testVariantGet(json, "$.zip code", IntegerType, 94025)
    testVariantGet(json, "$.fb:testid", IntegerType, 1234)
    testVariantGet(
      json,
      "$.store.fruit",
      DataType.fromDDL("array<struct<weight int, type string>>"),
      Array(Row(8, "apple"), Row(9, "pear"))
    )
    testVariantGet(
      json,
      "$.store.book[0]",
      DataType.fromDDL("struct<author string, title string, category string, price decimal(4, 2)>"),
      Row("Nigel Rees", "Sayings of the Century", "reference", Decimal(8.95))
    )
  }

  test("variant_get negative") {
    testVariantGet("""{"a": 1}""", "$[0]", IntegerType, null)
    testVariantGet("""{"a": 1}""", "$.A", IntegerType, null)
    testVariantGet("[1]", "$.a", IntegerType, null)
    testVariantGet("[1]", "$[1]", IntegerType, null)
    testVariantGet("1", "$.a", IntegerType, null)
    testVariantGet("1", "$[0]", IntegerType, null)
    testInvalidVariantGet(
      """{"a": 1}""",
      "$",
      IntegerType,
      Map("value" -> "{\"a\":1}", "dataType" -> "\"INT\"")
    )
    testInvalidVariantGet("[1]", "$", IntegerType)
  }

  test("variant_get large") {
    val numKeys = 256
    // An array whose i-th element is i; one past the end must be absent.
    val arrayJson = (0 until numKeys).map(_.toString).mkString("[", ",", "]")
    for (i <- 0 until numKeys) {
      testVariantGet(arrayJson, "$[" + i + "]", IntegerType, i)
    }
    testVariantGet(arrayJson, "$[" + numKeys + "]", IntegerType, null)
    // An object mapping key "i" to i; a key that was never written must be absent.
    val objectJson = (0 until numKeys).map(i => s""""$i": $i""").mkString("{", ",", "}")
    for (i <- 0 until numKeys) {
      testVariantGet(objectJson, "$." + i, IntegerType, i)
    }
    testVariantGet(objectJson, "$." + numKeys, IntegerType, null)
  }

  test("variant_get timestamp") {
    // Interprets `time` in `zoneId` and returns microseconds since the epoch.
    def toMicros(time: LocalDateTime, zoneId: ZoneId): Long = {
      val instant = time.atZone(zoneId).toInstant
      instant.getEpochSecond * MICROS_PER_SECOND + instant.getNano / NANOS_PER_MICROS
    }
    val input = "\"2026-04-05 5:16:07\""
    val expected = LocalDateTime.of(2026, 4, 5, 5, 16, 7, 0)
    DateTimeTestUtils.outstandingZoneIds.foreach { zid =>
      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> zid.getId) {
        // TIMESTAMP is interpreted in the session time zone; TIMESTAMP_NTZ ignores it.
        testVariantGet(input, "$", TimestampType, toMicros(expected, zid))
        testVariantGet(input, "$", TimestampNTZType, toMicros(expected, ZoneOffset.UTC))
      }
    }
  }

  test("variant_get overflow") {
    for (ansi <- Seq(false, true)) {
      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi.toString) {
        // `variant_get` is not affected by the ANSI flag. It doesn't have the LEGACY mode.
        testInvalidVariantGet(
          """{"a": 2147483648}""",
          "$.a",
          IntegerType,
          Map("value" -> "2147483648", "dataType" -> "\"INT\"")
        )
      }
    }
  }

  test("variant_get nested") {
    testVariantGet("null", "$", DataType.fromDDL("a int"), null)
    testVariantGet("{}", "$", DataType.fromDDL("a int"), Row(null))
    testVariantGet("""{"a": 1}""", "$", DataType.fromDDL("a int"), Row(1))
    testInvalidVariantGet("1", "$", DataType.fromDDL("a int"))
    testVariantGet("""{"a": 1, "b": "2"}""", "$", DataType.fromDDL("a int, b string"), Row(1, "2"))
    testVariantGet("""{"a": 1, "b": "2"}""", "$", DataType.fromDDL("a string, b int"), Row("1", 2))
    testVariantGet("""{"b": "2", "a": 1}""", "$", DataType.fromDDL("a string, b int"), Row("1", 2))
    testVariantGet(
      """{"a": 1, "d": 2, "c": 3}""",
      "$",
      DataType.fromDDL("a int, b int, c int"),
      Row(1, null, 3)
    )
    testInvalidVariantGet(
      """{"a": 1, "b": "2"}""",
      "$",
      DataType.fromDDL("a int, b boolean"),
      Map("value" -> "\"2\"", "dataType" -> "\"BOOLEAN\""),
      Row(1, null)
    )
    testVariantGet("null", "$", DataType.fromDDL("array<int>"), null)
    testVariantGet("[]", "$", DataType.fromDDL("array<int>"), Array())
    testInvalidVariantGet("{}", "$", DataType.fromDDL("array<int>"))
    testVariantGet(
      """[1, 2, 3, null, "4", 5.0]""",
      "$",
      DataType.fromDDL("array<int>"),
      Array(1, 2, 3, null, 4, 5)
    )
    testVariantGet(
      """[1, 2, 3, null, "4", 5.0]""",
      "$",
      DataType.fromDDL("array<string>"),
      Array("1", "2", "3", null, "4", "5")
    )
    testVariantGet(
      """[[1], [2, 3], [4, 5, 6], [7, 8, 9, 10]]""",
      "$",
      DataType.fromDDL("array<array<int>>"),
      Array(Array(1), Array(2, 3), Array(4, 5, 6), Array(7, 8, 9, 10))
    )
    testInvalidVariantGet(
      """[1, 2, 3, "hello"]""",
      "$",
      DataType.fromDDL("array<int>"),
      Map("value" -> "\"hello\"", "dataType" -> "\"INT\""),
      Array(1, 2, 3, null)
    )
    testVariantGet("null", "$", DataType.fromDDL("map<string, int>"), null)
    testVariantGet("{}", "$", DataType.fromDDL("map<string, int>"), Map())
    testInvalidVariantGet("[]", "$", DataType.fromDDL("map<string, int>"))
    testVariantGet(
      """{"a": 1, "b": "2", "c": null}""",
      "$",
      DataType.fromDDL("map<string, int>"),
      Map("a" -> 1, "b" -> 2, "c" -> null)
    )
    testVariantGet(
      """{"a": {}, "b": {"c": "d"}, "e": {"f": "g"}}""",
      "$",
      DataType.fromDDL("map<string, map<string, string>>"),
      Map("a" -> Map(), "b" -> Map("c" -> "d"), "e" -> Map("f" -> "g"))
    )
    testInvalidVariantGet(
      """{"a": 1, "b": "2", "c": {}}""",
      "$",
      DataType.fromDDL("map<string, int>"),
      Map("value" -> "{}", "dataType" -> "\"INT\""),
      Map("a" -> 1, "b" -> 2, "c" -> null)
    )
    testVariantGet(
      """[{"a": 1}, {"b": 2}, null, {}]""",
      "$",
      DataType.fromDDL("array<struct<a int, b int>>"),
      Array(Row(1, null), Row(null, 2), null, Row(null, null))
    )
    testVariantGet(
      """[{"a": 1}, {"b": 2}, null, {}]""",
      "$",
      DataType.fromDDL("array<map<string, int>>"),
      Array(Map("a" -> 1), Map("b" -> 2), null, Map())
    )
  }

  test("variant_get path") {
    def checkInvalidPath(path: String): Unit = {
      checkErrorInExpression[SparkRuntimeException](
        variantGet("0", path, IntegerType),
        "INVALID_VARIANT_GET_PATH",
        Map("path" -> path, "functionName" -> "`variant_get`")
      )
    }
    testVariantGet("""{"1": {"2": {"3": [4]}}}""", "$.1.2.3[0]", IntegerType, 4)
    testVariantGet("""{"1": {"2": {"3": [4]}}}""", "$.1.2.3['0']", IntegerType, null)
    // scalastyle:off nonascii
    testVariantGet("""{"你好": {"世界": "hello"}}""", """$['你好']["世界"]""", StringType, "hello")
    // scalastyle:on nonascii
    checkInvalidPath("")
    checkInvalidPath(".a")
    checkInvalidPath("$1")
    checkInvalidPath("$[-1]")
    checkInvalidPath("""$['"]""")
  }

  test("cast from variant") {
    // We do not test too many type combinations, as the cast implementation is mostly the same as
    // variant_get.
    def checkCast(input: Any, dataType: DataType, output: Any): Unit = {
      for (mode <- Seq(EvalMode.LEGACY, EvalMode.ANSI, EvalMode.TRY)) {
        checkEvaluation(Cast(Literal(input), dataType, evalMode = mode), output)
      }
    }
    def checkInvalidCast(input: Any, dataType: DataType, tryOutput: Any): Unit = {
      // Casting from variant is not affected by the ANSI flag.
      for (mode <- Seq(EvalMode.LEGACY, EvalMode.ANSI)) {
        checkExceptionInExpression[SparkRuntimeException](
          Cast(Literal(input), dataType, evalMode = mode),
          "INVALID_VARIANT_CAST"
        )
      }
      checkEvaluation(Cast(Literal(input), dataType, evalMode = EvalMode.TRY), tryOutput)
    }
    checkCast(parseJson("1"), StringType, "1")
    // Other to-string casts never produce NULL when the input is not NULL, but variant-to-string
    // cast can produce NULL when the input is a variant null (not NULL).
    checkCast(parseJson("null"), StringType, null)
    checkCast(parseJson("\"1\""), IntegerType, 1)
    checkInvalidCast(parseJson("2147483648"), IntegerType, null)
    checkInvalidCast(parseJson("[2147483648, 1]"), ArrayType(IntegerType), Array(null, 1))
    checkCast(Array(null, parseJson("true")), ArrayType(BooleanType), Array(null, true))
    checkCast(
      Array(null, parseJson("false"), parseJson("null")),
      ArrayType(StringType),
      Array(null, "false", null)
    )
    checkCast(Array(parseJson("[1]")), ArrayType(ArrayType(IntegerType)), Array(Array(1)))
    checkInvalidCast(
      Array(parseJson("\"hello\""), null, parseJson("\"1\"")),
      ArrayType(IntegerType),
      Array(null, null, 1)
    )
  }

  test("atomic types that are not produced by parse_json") {
    // Dictionary size is `0` for value 0. An empty dictionary contains one offset `0` for the
    // one-past-the-end position (i.e. the sum of all string lengths).
    val emptyMetadata = Array[Byte](VERSION, 0, 0)
    def checkToJson(value: Array[Byte], expected: String): Unit = {
      val input = Literal(new VariantVal(value, emptyMetadata))
      checkEvaluation(StructsToJson(Map.empty, input), expected)
    }
    def checkCast(value: Array[Byte], dataType: DataType, expected: Any): Unit = {
      val input = Literal(new VariantVal(value, emptyMetadata))
      checkEvaluation(Cast(input, dataType, evalMode = EvalMode.ANSI), expected)
    }
    checkToJson(Array(primitiveHeader(DATE), 0, 0, 0, 0), "\"1970-01-01\"")
    checkToJson(Array(primitiveHeader(DATE), -1, -1, -1, 127), "\"+5881580-07-11\"")
    checkToJson(Array(primitiveHeader(DATE), 0, 0, 0, -128), "\"-5877641-06-23\"")
    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
      checkCast(Array(primitiveHeader(DATE), 0, 0, 0, 0), TimestampType, 0L)
      checkCast(Array(primitiveHeader(DATE), 1, 0, 0, 0), TimestampType, MICROS_PER_DAY)
    }
    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
      checkCast(Array(primitiveHeader(DATE), 0, 0, 0, 0), TimestampType, 8 * MICROS_PER_HOUR)
      checkCast(Array(primitiveHeader(DATE), 1, 0, 0, 0), TimestampType,
        MICROS_PER_DAY + 8 * MICROS_PER_HOUR)
    }
    // Fixed-width little-endian encoders for variant primitive payloads. A ByteBuffer always
    // yields exactly 8 (or 4) bytes with correct sign extension. `BigInt.toByteArray` returns a
    // minimal two's-complement form instead, which would need 0xFF (not 0x00) padding for short
    // negative values.
    import java.nio.{ByteBuffer, ByteOrder}
    def littleEndianLong(value: Long): Array[Byte] =
      ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(value).array()
    def littleEndianFloat(value: Float): Array[Byte] =
      ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(value).array()
    val time1 = littleEndianLong(0)
    // In America/Los_Angeles timezone, timestamp value `skippedTime` is 2011-03-13 03:00:00.
    // The next second of 2011-03-13 01:59:59 jumps to 2011-03-13 03:00:00.
    val skippedTime = 1300010400000000L
    val time2 = littleEndianLong(skippedTime)
    val time3 = littleEndianLong(skippedTime - 1)
    val time4 = littleEndianLong(Long.MinValue)
    val time5 = littleEndianLong(Long.MaxValue)
    val time6 = littleEndianLong(-62198755200000000L)
    val timestampHeader = Array(primitiveHeader(TIMESTAMP))
    val timestampNtzHeader = Array(primitiveHeader(TIMESTAMP_NTZ))
    for (timeZone <- Seq("UTC", "America/Los_Angeles")) {
      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
        checkToJson(timestampNtzHeader ++ time1, "\"1970-01-01 00:00:00\"")
        checkToJson(timestampNtzHeader ++ time2, "\"2011-03-13 10:00:00\"")
        checkToJson(timestampNtzHeader ++ time3, "\"2011-03-13 09:59:59.999999\"")
        checkToJson(timestampNtzHeader ++ time4, "\"-290308-12-21 19:59:05.224192\"")
        checkToJson(timestampNtzHeader ++ time5, "\"+294247-01-10 04:00:54.775807\"")
        checkToJson(timestampNtzHeader ++ time6, "\"-0001-01-01 00:00:00\"")
        checkCast(timestampNtzHeader ++ time1, DateType, 0)
        checkCast(timestampNtzHeader ++ time2, DateType, 15046)
        checkCast(timestampNtzHeader ++ time3, DateType, 15046)
        checkCast(timestampNtzHeader ++ time4, DateType, -106751992)
        checkCast(timestampNtzHeader ++ time5, DateType, 106751991)
        checkCast(timestampNtzHeader ++ time6, DateType, -719893)
      }
    }
    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
      checkToJson(timestampHeader ++ time1, "\"1970-01-01 00:00:00+00:00\"")
      checkToJson(timestampHeader ++ time2, "\"2011-03-13 10:00:00+00:00\"")
      checkToJson(timestampHeader ++ time3, "\"2011-03-13 09:59:59.999999+00:00\"")
      checkToJson(timestampHeader ++ time4, "\"-290308-12-21 19:59:05.224192+00:00\"")
      checkToJson(timestampHeader ++ time5, "\"+294247-01-10 04:00:54.775807+00:00\"")
      checkToJson(timestampHeader ++ time6, "\"-0001-01-01 00:00:00+00:00\"")
      checkCast(timestampHeader ++ time1, DateType, 0)
      checkCast(timestampHeader ++ time2, DateType, 15046)
      checkCast(timestampHeader ++ time3, DateType, 15046)
      checkCast(timestampHeader ++ time4, DateType, -106751992)
      checkCast(timestampHeader ++ time5, DateType, 106751991)
      checkCast(timestampHeader ++ time6, DateType, -719893)
    }
    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
      checkToJson(timestampHeader ++ time1, "\"1969-12-31 16:00:00-08:00\"")
      checkToJson(timestampHeader ++ time2, "\"2011-03-13 03:00:00-07:00\"")
      checkToJson(timestampHeader ++ time3, "\"2011-03-13 01:59:59.999999-08:00\"")
      checkToJson(timestampHeader ++ time4, "\"-290308-12-21 12:06:07.224192-07:52\"")
      checkToJson(timestampHeader ++ time5, "\"+294247-01-09 20:00:54.775807-08:00\"")
      checkToJson(timestampHeader ++ time6, "\"-0002-12-31 16:07:02-07:52\"")
      checkCast(timestampHeader ++ time1, DateType, -1)
      checkCast(timestampHeader ++ time2, DateType, 15046)
      checkCast(timestampHeader ++ time3, DateType, 15046)
      checkCast(timestampHeader ++ time4, DateType, -106751992)
      checkCast(timestampHeader ++ time5, DateType, 106751990)
      checkCast(timestampHeader ++ time6, DateType, -719894)
    }
    checkToJson(Array(primitiveHeader(FLOAT)) ++ littleEndianFloat(1.23F), "1.23")
    checkToJson(Array(primitiveHeader(FLOAT)) ++ littleEndianFloat(-0.0F), "-0.0")
    // Note: 1.23F.toDouble != 1.23.
    checkCast(Array(primitiveHeader(FLOAT)) ++ littleEndianFloat(1.23F), DoubleType,
      1.23F.toDouble)
    checkToJson(Array(primitiveHeader(BINARY), 0, 0, 0, 0), "\"\"")
    checkToJson(Array(primitiveHeader(BINARY), 1, 0, 0, 0, 1), "\"AQ==\"")
    checkToJson(Array(primitiveHeader(BINARY), 2, 0, 0, 0, 1, 2), "\"AQI=\"")
    checkToJson(Array(primitiveHeader(BINARY), 3, 0, 0, 0, 1, 2, 3), "\"AQID\"")
    checkCast(Array(primitiveHeader(BINARY), 3, 0, 0, 0, 1, 2, 3), StringType,
      "\u0001\u0002\u0003")
    checkCast(Array(primitiveHeader(BINARY), 5, 0, 0, 0, 72, 101, 108, 108, 111), StringType,
      "Hello")
  }

  test("SPARK-48150: ParseJson expression nullability") {
    assert(!ParseJson(Literal("["), failOnError = true).replacement.nullable)
    assert(ParseJson(Literal("["), failOnError = false).replacement.nullable)
    checkEvaluation(
      ParseJson(Literal("["), failOnError = false).replacement,
      null
    )
  }

  test("cast to variant") {
    def check[T : TypeTag](input: T, expectedJson: String): Unit = {
      val cast = Cast(Literal.create(input), VariantType, evalMode = EvalMode.ANSI)
      checkEvaluation(StructsToJson(Map.empty, cast), expectedJson)
    }
    check(null.asInstanceOf[String], null)
    // The following tests cover all allowed scalar types.
    for (input <- Seq[Any](false, true, 0.toByte, 1.toShort, 2, 3L, 4.0F, 5.0D)) {
      check(input, input.toString)
    }
    for (precision <- Seq(9, 18, 38)) {
      val input = BigDecimal("9" * precision)
      check(Literal.create(input, DecimalType(precision, 0)), input.toString)
    }
    check("", "\"\"")
    check("x" * 128, "\"" + ("x" * 128) + "\"")
    check(Array[Byte](1, 2, 3), "\"AQID\"")
    check(Literal(0, DateType), "\"1970-01-01\"")
    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
      check(Literal(0L, TimestampType), "\"1970-01-01 00:00:00+00:00\"")
      check(Literal(0L, TimestampNTZType), "\"1970-01-01 00:00:00\"")
    }
    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
      check(Literal(0L, TimestampType), "\"1969-12-31 16:00:00-08:00\"")
      check(Literal(0L, TimestampNTZType), "\"1970-01-01 00:00:00\"")
    }
    check(Array(null, "a", "b", "c"), """[null,"a","b","c"]""")
    check(Map("z" -> 1, "y" -> 2, "x" -> 3), """{"x":3,"y":2,"z":1}""")
    check(Array(parseJson("""{"a": 1,"b": [1, 2, 3]}"""),
      parseJson("""{"c": true,"d": {"e": "str"}}""")),
      """[{"a":1,"b":[1,2,3]},{"c":true,"d":{"e":"str"}}]""")
    val struct = Literal.create(
      Row(
        Seq("123", "true", "f"),
        Map("a" -> "123", "b" -> "true", "c" -> "f"),
        Row(0)),
      StructType.fromDDL("c ARRAY<STRING>,b MAP<STRING, STRING>,a STRUCT<i: INT>"))
    check(struct, """{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"]}""")
  }
}