[SPARK-48067][SQL] Fix variant default columns
### What changes were proposed in this pull request?
Changes the `sql` representation of a variant literal to `parse_json(variant.toJson)`, since there is no other SQL-parseable representation of a variant value.
This allows variant default columns to work, because default columns store the default value as the literal's SQL string representation in the metadata of the schema's struct fields.
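A minimal sketch of the round-trip this enables, using only APIs that appear in the patch below:
```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.types.VariantType
import org.apache.spark.unsafe.types.UTF8String

// Build a variant value from JSON and wrap it in a literal.
val v = VariantExpressionEvalUtils.parseJson(UTF8String.fromString("""{"k": "v"}"""))
val lit = Literal.create(v, VariantType)

// After this change, `sql` renders a PARSE_JSON call over the variant's JSON,
// e.g. PARSE_JSON('{"k":"v"}'), which the parser can turn back into an
// equivalent variant. Previously there was no parseable form to store.
println(lit.sql)
```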
### Why are the changes needed?
Previously, we could not declare a variant default column such as:
```sql
create table t(
  v6 variant default parse_json('{"k": "v"}')
)
```
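With this change, the statement above succeeds and `DEFAULT` inserts resolve through the stored `PARSE_JSON(...)` string. An illustrative end-to-end sketch (the table and column names are arbitrary):
```scala
// Illustrative only: exercises the fixed path from a SparkSession.
spark.sql("create table t(v variant default parse_json('{\"k\": \"v\"}')) using parquet")
spark.sql("insert into t values (DEFAULT)")
// The default round-trips: to_json(v) should print {"k":"v"}.
spark.sql("select to_json(v) from t").show()
```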
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added unit tests in `VariantSuite`: one verifying that variant default columns can be declared and populated with `DEFAULT`, and a parameterized set verifying that a variant literal's `sql` representation recreates an equivalent variant.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #46312 from richardc-db/fix_variant_default_cols.
Authored-by: Richard Chen <r.chen@databricks.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 0fad3eff..4cffc7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -42,6 +42,7 @@
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.trees.TreePattern.{LITERAL, NULL_LITERAL, TRUE_OR_FALSE_LITERAL}
import org.apache.spark.sql.catalyst.types._
@@ -204,6 +205,8 @@
create(new GenericInternalRow(
struct.fields.map(f => default(f.dataType).value)), struct)
case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt)
+ case VariantType =>
+ create(VariantExpressionEvalUtils.castToVariant(0, IntegerType), VariantType)
case other =>
throw QueryExecutionErrors.noDefaultForDataTypeError(dataType)
}
@@ -549,6 +552,7 @@
s"${Literal(kv._1, mapType.keyType).sql}, ${Literal(kv._2, mapType.valueType).sql}"
}
s"MAP(${keysAndValues.mkString(", ")})"
+ case (v: VariantVal, variantType: VariantType) => s"PARSE_JSON('${v.toJson(timeZoneId)}')"
case _ => value.toString
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index 19e5f9b..caab98b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -26,15 +26,17 @@
import scala.util.Random
import org.apache.spark.SparkRuntimeException
-import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
+import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, ExpressionEvalHelper, Literal}
+import org.apache.spark.sql.catalyst.expressions.variant.{VariantExpressionEvalUtils, VariantGet}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.VariantVal
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
import org.apache.spark.util.ArrayImplicits._
-class VariantSuite extends QueryTest with SharedSparkSession {
+class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper {
import testImplicits._
test("basic tests") {
@@ -445,4 +447,141 @@
}
}
}
+
+ test("SPARK-48067: default variant columns works") {
+ withTable("t") {
+ sql("""create table t(
+ v1 variant default null,
+ v2 variant default parse_json(null),
+ v3 variant default cast(null as variant),
+ v4 variant default parse_json('1'),
+ v5 variant default parse_json('1'),
+ v6 variant default parse_json('{\"k\": \"v\"}'),
+ v7 variant default cast(5 as int),
+ v8 variant default cast('hello' as string),
+ v9 variant default parse_json(to_json(parse_json('{\"k\": \"v\"}')))
+ ) using parquet""")
+ sql("""insert into t values(DEFAULT, DEFAULT, DEFAULT, DEFAULT, DEFAULT, DEFAULT, DEFAULT,
+ DEFAULT, DEFAULT)""")
+
+ val expected = sql("""select
+ cast(null as variant) as v1,
+ parse_json(null) as v2,
+ cast(null as variant) as v3,
+ parse_json('1') as v4,
+ parse_json('1') as v5,
+ parse_json('{\"k\": \"v\"}') as v6,
+ cast(cast(5 as int) as variant) as v7,
+ cast('hello' as variant) as v8,
+ parse_json(to_json(parse_json('{\"k\": \"v\"}'))) as v9
+ """)
+ val actual = sql("select * from t")
+ checkAnswer(actual, expected.collect())
+ }
+ }
+
+ Seq(
+ (
+ "basic int parse json",
+ VariantExpressionEvalUtils.parseJson(UTF8String.fromString("1")),
+ VariantType
+ ),
+ (
+ "basic json parse json",
+ VariantExpressionEvalUtils.parseJson(UTF8String.fromString("{\"k\": \"v\"}")),
+ VariantType
+ ),
+ (
+ "basic null parse json",
+ VariantExpressionEvalUtils.parseJson(UTF8String.fromString("null")),
+ VariantType
+ ),
+ (
+ "basic null",
+ null,
+ VariantType
+ ),
+ (
+ "basic array",
+ new GenericArrayData(Array[Int](1, 2, 3, 4, 5)),
+ new ArrayType(IntegerType, false)
+ ),
+ (
+ "basic string",
+ UTF8String.fromString("literal string"),
+ StringType
+ ),
+ (
+ "basic timestamp",
+ 0L,
+ TimestampType
+ ),
+ (
+ "basic int",
+ 0,
+ IntegerType
+ ),
+ (
+ "basic struct",
+ Literal.default(new StructType().add("col0", StringType)).eval(),
+ new StructType().add("col0", StringType)
+ ),
+ (
+ "complex struct with child variant",
+ Literal.default(new StructType()
+ .add("col0", StringType)
+ .add("col1", new StructType().add("col0", VariantType))
+ .add("col2", VariantType)
+ .add("col3", new ArrayType(VariantType, false))
+ ).eval(),
+ new StructType()
+ .add("col0", StringType)
+ .add("col1", new StructType().add("col0", VariantType))
+ .add("col2", VariantType)
+ .add("col3", new ArrayType(VariantType, false))
+ ),
+ (
+ "basic array with null",
+ new GenericArrayData(Array[Any](1, 2, null)),
+ new ArrayType(IntegerType, true)
+ ),
+ (
+ "basic map with null",
+ new ArrayBasedMapData(
+ new GenericArrayData(Array[Any](UTF8String.fromString("k1"), UTF8String.fromString("k2"))),
+ new GenericArrayData(Array[Any](1, null))
+ ),
+ new MapType(StringType, IntegerType, true)
+ )
+ ).foreach { case (testName, value, dt) =>
+ test(s"SPARK-48067: Variant literal `sql` correctly recreates the variant - $testName") {
+ val l = Literal.create(
+ VariantExpressionEvalUtils.castToVariant(value, dt.asInstanceOf[DataType]), VariantType)
+ val jsonString = l.eval().asInstanceOf[VariantVal]
+ .toJson(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
+ val expectedSql = s"PARSE_JSON('$jsonString')"
+ assert(l.sql == expectedSql)
+ val valueFromLiteralSql =
+ spark.sql(s"select ${l.sql}").collect()(0).getAs[VariantVal](0)
+
+ // Cast the variants to their specified type to compare for logical equality.
+ // Currently, variant equality naively compares its value and metadata binaries. However,
+ // variant equality is more complex than this.
+ val castVariantExpr = VariantGet(
+ l,
+ Literal.create(UTF8String.fromString("$"), StringType),
+ dt,
+ true,
+ Some(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone).toString())
+ )
+ val sqlVariantExpr = VariantGet(
+ Literal.create(valueFromLiteralSql, VariantType),
+ Literal.create(UTF8String.fromString("$"), StringType),
+ dt,
+ true,
+ Some(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone).toString())
+ )
+ checkEvaluation(castVariantExpr, sqlVariantExpr.eval())
+ }
+ }
}