[SPARK-43018][SQL] Fix bug for INSERT commands with timestamp literals
### What changes were proposed in this pull request?
This PR fixes a correctness bug for INSERT commands with timestamp literals. The bug manifests when:
* An INSERT command includes a user-specified column list with fewer columns than the target table.
* The provided values include timestamp literals.
The bug was that the long integer values stored in the rows to represent these timestamp literals were assigned back to `UnresolvedInlineTable` rows without their timestamp type. The analyzer then inserted an implicit cast from `LongType` to `TimestampType`, which interprets the long as seconds rather than the microseconds Spark uses for internal timestamp storage, so the value changed during execution.
This PR fixes the bug by keeping the inline data in a `LocalRelation` and propagating the target table's column types, including the timestamp type, directly to its output attributes instead of converting the rows back into an `UnresolvedInlineTable`.
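As a concrete illustration (a minimal sketch mirroring the new end-to-end test added below, not additional patch code), before this fix the inserted timestamp came back corrupted because the stored microsecond value was re-cast as if it were seconds:
```scala
// Minimal reproduction sketch; the table and column names follow the new test suite below.
// Before this fix, the implicit LongType -> TimestampType cast corrupted the inserted value.
spark.sql("CREATE TABLE t(id INT, ts TIMESTAMP) USING PARQUET")
spark.sql("INSERT INTO t (ts) VALUES (TIMESTAMP'2020-12-31')")
spark.sql("SELECT * FROM t").show()
// Expected (and now actual) result: [null, 2020-12-31 00:00:00]
```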
### Why are the changes needed?
This PR fixes a correctness bug.
### Does this PR introduce _any_ user-facing change?
Yes, this PR fixes a correctness bug.
### How was this patch tested?
This PR adds a new unit test suite, `ResolveDefaultColumnsSuite`, and updates the affected error expectations in `DataSourceV2SQLSuite`.
Closes #40652 from dtenedor/assign-correct-insert-types.
Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
(cherry picked from commit 9f0bf51a3a7f6175de075198e00a55bfdc491f15)
Signed-off-by: Gengliang Wang <gengliang@apache.org>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
index 6afb51b..630a85e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
@@ -107,10 +107,10 @@
insertTableSchemaWithoutPartitionColumns.map { schema: StructType =>
val regenerated: InsertIntoStatement =
regenerateUserSpecifiedCols(i, schema)
- val expanded: LogicalPlan =
+ val (expanded: LogicalPlan, addedDefaults: Boolean) =
addMissingDefaultValuesForInsertFromInlineTable(node, schema, i.userSpecifiedCols.size)
val replaced: Option[LogicalPlan] =
- replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
+ replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults)
replaced.map { r: LogicalPlan =>
node = r
for (child <- children.reverse) {
@@ -131,10 +131,10 @@
insertTableSchemaWithoutPartitionColumns.map { schema =>
val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i, schema)
val project: Project = i.query.asInstanceOf[Project]
- val expanded: Project =
+ val (expanded: Project, addedDefaults: Boolean) =
addMissingDefaultValuesForInsertFromProject(project, schema, i.userSpecifiedCols.size)
val replaced: Option[LogicalPlan] =
- replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
+ replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults)
replaced.map { r =>
regenerated.copy(query = r)
}.getOrElse(i)
@@ -270,67 +270,83 @@
/**
* Updates an inline table to generate missing default column values.
+ * Returns the resulting plan plus a boolean indicating whether such values were added.
*/
- private def addMissingDefaultValuesForInsertFromInlineTable(
+ def addMissingDefaultValuesForInsertFromInlineTable(
node: LogicalPlan,
insertTableSchemaWithoutPartitionColumns: StructType,
- numUserSpecifiedColumns: Int): LogicalPlan = {
+ numUserSpecifiedColumns: Int): (LogicalPlan, Boolean) = {
val schema = insertTableSchemaWithoutPartitionColumns
- val newDefaultExpressions: Seq[Expression] =
- getDefaultExpressionsForInsert(schema, numUserSpecifiedColumns)
- val newNames: Seq[String] = if (numUserSpecifiedColumns > 0) {
- schema.fields.drop(numUserSpecifiedColumns).map(_.name)
- } else {
- schema.fields.map(_.name)
- }
- node match {
- case _ if newDefaultExpressions.isEmpty => node
+ val newDefaultExpressions: Seq[UnresolvedAttribute] =
+ getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, node.output.size)
+ val newNames: Seq[String] = schema.fields.map(_.name)
+ val resultPlan: LogicalPlan = node match {
+ case _ if newDefaultExpressions.isEmpty =>
+ node
case table: UnresolvedInlineTable =>
table.copy(
- names = table.names ++ newNames,
+ names = newNames,
rows = table.rows.map { row => row ++ newDefaultExpressions })
case local: LocalRelation =>
- // Note that we have consumed a LocalRelation but return an UnresolvedInlineTable, because
- // addMissingDefaultValuesForInsertFromProject must replace unresolved DEFAULT references.
- UnresolvedInlineTable(
- local.output.map(_.name) ++ newNames,
- local.data.map { row =>
- val colTypes = StructType(local.output.map(col => StructField(col.name, col.dataType)))
- row.toSeq(colTypes).map(Literal(_)) ++ newDefaultExpressions
+ val newDefaultExpressionsRow = new GenericInternalRow(
+ // Note that this code path only runs when there is a user-specified column list of fewer
+ // columns than the target table; otherwise, the above 'newDefaultExpressions' is empty and
+ // we match the first case in this list instead.
+ schema.fields.drop(local.output.size).map {
+ case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
+ analyze(f, "INSERT") match {
+ case lit: Literal => lit.value
+ case _ => null
+ }
+ case _ => null
})
- case _ => node
+ LocalRelation(
+ output = schema.toAttributes,
+ data = local.data.map { row =>
+ new JoinedRow(row, newDefaultExpressionsRow)
+ })
+ case _ =>
+ node
}
+ (resultPlan, newDefaultExpressions.nonEmpty)
}
/**
* Adds new expressions to a projection to generate missing default column values.
+ * Returns the logical plan plus a boolean indicating if such defaults were added.
*/
private def addMissingDefaultValuesForInsertFromProject(
project: Project,
insertTableSchemaWithoutPartitionColumns: StructType,
- numUserSpecifiedColumns: Int): Project = {
+ numUserSpecifiedColumns: Int): (Project, Boolean) = {
val schema = insertTableSchemaWithoutPartitionColumns
val newDefaultExpressions: Seq[Expression] =
- getDefaultExpressionsForInsert(schema, numUserSpecifiedColumns)
+ getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, project.projectList.size)
val newAliases: Seq[NamedExpression] =
newDefaultExpressions.zip(schema.fields).map {
case (expr, field) => Alias(expr, field.name)()
}
- project.copy(projectList = project.projectList ++ newAliases)
+ (project.copy(projectList = project.projectList ++ newAliases),
+ newDefaultExpressions.nonEmpty)
}
/**
* This is a helper for the addMissingDefaultValuesForInsertFromInlineTable methods above.
*/
- private def getDefaultExpressionsForInsert(
- schema: StructType,
- numUserSpecifiedColumns: Int): Seq[Expression] = {
+ private def getNewDefaultExpressionsForInsert(
+ insertTableSchemaWithoutPartitionColumns: StructType,
+ numUserSpecifiedColumns: Int,
+ numProvidedValues: Int): Seq[UnresolvedAttribute] = {
val remainingFields: Seq[StructField] = if (numUserSpecifiedColumns > 0) {
- schema.fields.drop(numUserSpecifiedColumns)
+ insertTableSchemaWithoutPartitionColumns.fields.drop(numUserSpecifiedColumns)
} else {
Seq.empty
}
val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size
+ // Limit the number of new DEFAULT expressions to the difference of the number of columns in
+ // the target table and the number of provided values in the source relation. This clamps the
+ // total final number of provided values to the number of columns in the target table.
+ .min(insertTableSchemaWithoutPartitionColumns.size - numProvidedValues)
Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
}
@@ -351,7 +367,8 @@
*/
private def replaceExplicitDefaultValuesForInputOfInsertInto(
insertTableSchemaWithoutPartitionColumns: StructType,
- input: LogicalPlan): Option[LogicalPlan] = {
+ input: LogicalPlan,
+ addedDefaults: Boolean): Option[LogicalPlan] = {
val schema = insertTableSchemaWithoutPartitionColumns
val defaultExpressions: Seq[Expression] = schema.fields.map {
case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "INSERT")
@@ -371,7 +388,11 @@
case project: Project =>
replaceExplicitDefaultValuesForProject(defaultExpressions, project)
case local: LocalRelation =>
- Some(local)
+ if (addedDefaults) {
+ Some(local)
+ } else {
+ None
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala
new file mode 100644
index 0000000..fc540e6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
+
+class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession {
+ val rule = ResolveDefaultColumns(catalog = null)
+ // This is the internal storage for the timestamp 2020-12-31 00:00:00.0.
+ val literal = Literal(1609401600000000L, TimestampType)
+ val table = UnresolvedInlineTable(
+ names = Seq("attr1"),
+ rows = Seq(Seq(literal)))
+ val localRelation = ResolveInlineTables(table).asInstanceOf[LocalRelation]
+
+ def asLocalRelation(result: LogicalPlan): LocalRelation = result match {
+ case r: LocalRelation => r
+ case _ => fail(s"invalid result operator type: $result")
+ }
+
+ test("SPARK-43018: Add DEFAULTs for INSERT from VALUES list with user-defined columns") {
+ // Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with one user-specified
+ // column. We add a default value of NULL to the row as a result.
+ val insertTableSchemaWithoutPartitionColumns = StructType(Seq(
+ StructField("c1", TimestampType),
+ StructField("c2", TimestampType)))
+ val (result: LogicalPlan, _: Boolean) =
+ rule.addMissingDefaultValuesForInsertFromInlineTable(
+ localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 1)
+ val relation = asLocalRelation(result)
+ assert(relation.output.map(_.name) == Seq("c1", "c2"))
+ val data: Seq[Seq[Any]] = relation.data.map { row =>
+ row.toSeq(StructType(relation.output.map(col => StructField(col.name, col.dataType))))
+ }
+ assert(data == Seq(Seq(literal.value, null)))
+ }
+
+ test("SPARK-43018: Add no DEFAULTs for INSERT from VALUES list with no user-defined columns") {
+ // Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with zero user-specified
+ // columns. The table is unchanged because there are no default columns to add in this case.
+ val insertTableSchemaWithoutPartitionColumns = StructType(Seq(
+ StructField("c1", TimestampType),
+ StructField("c2", TimestampType)))
+ val (result: LogicalPlan, _: Boolean) =
+ rule.addMissingDefaultValuesForInsertFromInlineTable(
+ localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 0)
+ assert(asLocalRelation(result) == localRelation)
+ }
+
+ test("SPARK-43018: INSERT timestamp values into a table with column DEFAULTs") {
+ withTable("t") {
+ sql("create table t(id int, ts timestamp) using parquet")
+ sql("insert into t (ts) values (timestamp'2020-12-31')")
+ checkAnswer(spark.table("t"),
+ sql("select null, timestamp'2020-12-31'").collect().head)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index d95a372..13ffa6d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -1100,9 +1100,15 @@
exception = intercept[AnalysisException] {
sql(s"INSERT INTO $t1(data, data) VALUES(5)")
},
- errorClass = "COLUMN_ALREADY_EXISTS",
- parameters = Map("columnName" -> "`data`")
- )
+ errorClass = "_LEGACY_ERROR_TEMP_2305",
+ parameters = Map(
+ "numCols" -> "3",
+ "rowSize" -> "2",
+ "ri" -> "0"),
+ context = ExpectedContext(
+ fragment = s"INSERT INTO $t1(data, data)",
+ start = 0,
+ stop = 26))
}
}
@@ -1123,14 +1129,20 @@
assert(intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1 VALUES(4)")
}.getMessage.contains("not enough data columns"))
- // Duplicate columns
+ // Duplicate columns
checkError(
exception = intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
},
- errorClass = "COLUMN_ALREADY_EXISTS",
- parameters = Map("columnName" -> "`data`")
- )
+ errorClass = "_LEGACY_ERROR_TEMP_2305",
+ parameters = Map(
+ "numCols" -> "3",
+ "rowSize" -> "2",
+ "ri" -> "0"),
+ context = ExpectedContext(
+ fragment = s"INSERT OVERWRITE $t1(data, data)",
+ start = 0,
+ stop = 31))
}
}
@@ -1152,14 +1164,20 @@
assert(intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1 VALUES('a', 4)")
}.getMessage.contains("not enough data columns"))
- // Duplicate columns
+ // Duplicate columns
checkError(
exception = intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
},
- errorClass = "COLUMN_ALREADY_EXISTS",
- parameters = Map("columnName" -> "`data`")
- )
+ errorClass = "_LEGACY_ERROR_TEMP_2305",
+ parameters = Map(
+ "numCols" -> "4",
+ "rowSize" -> "3",
+ "ri" -> "0"),
+ context = ExpectedContext(
+ fragment = s"INSERT OVERWRITE $t1(data, data)",
+ start = 0,
+ stop = 31))
}
}