/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.io.{ByteArrayOutputStream, File}
import java.lang.{Long => JLong}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.util.{Locale, UUID}
import java.util.concurrent.atomic.AtomicLong

import scala.reflect.runtime.universe.TypeTag
import scala.util.Random

import org.scalatest.matchers.should.Matchers._

import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.FakeV2Provider
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom

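/**
 * End-to-end tests for the untyped DataFrame API: selection and filtering, explode/star
 * expansion, repartitioning, aggregation, show/describe/summary rendering, and a number of
 * regression tests referenced by their SPARK-XXXXX JIRA ids.
 */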
class DataFrameSuite extends QueryTest
with SharedSparkSession
with AdaptiveSparkPlanHelper {
import testImplicits._
test("analysis error should be eagerly reported") {
intercept[Exception] { testData.select("nonExistentName") }
intercept[Exception] {
testData.groupBy("key").agg(Map("nonExistentName" -> "sum"))
}
intercept[Exception] {
testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
}
intercept[Exception] {
testData.groupBy($"abcd").agg(Map("key" -> "sum"))
}
}
test("dataframe toString") {
assert(testData.toString === "[key: int, value: string]")
assert(testData("key").toString === "key")
assert($"test".toString === "test")
}
test("rename nested groupby") {
val df = Seq((1, (1, 1))).toDF()
checkAnswer(
df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"),
Row(1, 1) :: Nil)
}
test("access complex data") {
assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1)
assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1)
assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1)
}
test("table scan") {
checkAnswer(
testData,
testData.collect().toSeq)
}
test("empty data frame") {
assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String])
assert(spark.emptyDataFrame.count() === 0)
}
test("head, take and tail") {
assert(testData.take(2) === testData.collect().take(2))
assert(testData.head(2) === testData.collect().take(2))
assert(testData.tail(2) === testData.collect().takeRight(2))
assert(testData.head(2).head.schema === testData.schema)
}
test("dataframe alias") {
val df = Seq(Tuple1(1)).toDF("c").as("t")
val dfAlias = df.alias("t2")
df.col("t.c")
dfAlias.col("t2.c")
}
test("simple explode") {
val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words")
checkAnswer(
df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word),
Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil
)
}
test("explode") {
val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters")
val df2 =
df.explode('letters) {
case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
}
checkAnswer(
df2
.select('_1 as 'letter, 'number)
.groupBy('letter)
.agg(count_distinct('number)),
Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
)
}
test("Star Expansion - CreateStruct and CreateArray") {
val structDf = testData2.select("a", "b").as("record")
// CreateStruct and CreateArray in aggregateExpressions
assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).
sort("a").first() == Row(1, Row(1, 1)))
assert(structDf.groupBy($"a").agg(min(array($"record.*"))).
sort("a").first() == Row(1, Seq(1, 1)))
// CreateStruct and CreateArray in project list (unresolved alias)
assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1)))
assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1))
// CreateStruct and CreateArray in project list (alias)
assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1)))
assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1))
}
test("Star Expansion - hash") {
val structDf = testData2.select("a", "b").as("record")
checkAnswer(
structDf.groupBy($"a", $"b").agg(min(hash($"a", $"*"))),
structDf.groupBy($"a", $"b").agg(min(hash($"a", $"a", $"b"))))
checkAnswer(
structDf.groupBy($"a", $"b").agg(hash($"a", $"*")),
structDf.groupBy($"a", $"b").agg(hash($"a", $"a", $"b")))
checkAnswer(
structDf.select(hash($"*")),
structDf.select(hash($"record.*")))
checkAnswer(
structDf.select(hash($"a", $"*")),
structDf.select(hash($"a", $"record.*")))
}
test("Star Expansion - xxhash64") {
val structDf = testData2.select("a", "b").as("record")
checkAnswer(
structDf.groupBy($"a", $"b").agg(min(xxhash64($"a", $"*"))),
structDf.groupBy($"a", $"b").agg(min(xxhash64($"a", $"a", $"b"))))
checkAnswer(
structDf.groupBy($"a", $"b").agg(xxhash64($"a", $"*")),
structDf.groupBy($"a", $"b").agg(xxhash64($"a", $"a", $"b")))
checkAnswer(
structDf.select(xxhash64($"*")),
structDf.select(xxhash64($"record.*")))
checkAnswer(
structDf.select(xxhash64($"a", $"*")),
structDf.select(xxhash64($"a", $"record.*")))
}
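/**
 * Asserts the expected decimal-overflow behaviour of `df`: with ANSI mode disabled the
 * overflowed aggregate should evaluate to `expectedAnswer` (typically a null sum), while
 * with ANSI mode enabled collecting `df` should fail with an ArithmeticException wrapped
 * in a SparkException.
 */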
private def assertDecimalSumOverflow(
df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
if (!ansiEnabled) {
checkAnswer(df, expectedAnswer)
} else {
val e = intercept[SparkException] {
df.collect()
}
assert(e.getCause.isInstanceOf[ArithmeticException])
assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
e.getCause.getMessage.contains("Overflow in sum of decimals"))
}
}
test("SPARK-28224: Aggregate sum big decimal overflow") {
val largeDecimals = spark.sparkContext.parallelize(
DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) ::
DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF()
Seq(true, false).foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
val structDf = largeDecimals.select("a").agg(sum("a"))
assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
}
}
}
test("SPARK-28067: sum of null decimal values") {
Seq("true", "false").foreach { wholeStageEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
Seq("true", "false").foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) {
val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
checkAnswer(df.agg(sum($"d")), Row(null))
}
}
}
}
}
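/**
 * Runs the given aggregate function over decimal inputs whose aggregated value overflows
 * the maximum decimal precision, under every combination of whole-stage codegen and ANSI
 * mode, and checks the overflow handling via `assertDecimalSumOverflow`.
 */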
def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
Seq("true", "false").foreach { wholeStageEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
val df0 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
val df1 = Seq(
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
val df = df0.union(df1)
val df2 = df.withColumnRenamed("decNum", "decNum2").
join(df, "intNum").agg(aggFn($"decNum"))
val expectedAnswer = Row(null)
assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
val decStr = "1" + "0" * 19
val d1 = spark.range(0, 12, 1, 1)
val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
toDF("d")
assertDecimalSumOverflow(
nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, expectedAnswer)
val df3 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("50000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
val df4 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
val df5 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")
val df6 = df3.union(df4).union(df5)
val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
filter("intNum == 1")
assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
}
}
}
}
}
test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
checkAggResultsForDecimalOverflow(c => sum(c))
}
test("SPARK-35955: Aggregate avg should not return wrong results for decimal overflow") {
checkAggResultsForDecimalOverflow(c => avg(c))
}
test("Star Expansion - ds.explode should fail with a meaningful message if it takes a star") {
val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
val e = intercept[AnalysisException] {
df.explode($"*") { case Row(prefix: String, csv: String) =>
csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
}.queryExecution.assertAnalyzed()
}
assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF"))
checkAnswer(
df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) =>
csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
},
Row("1", "1,2", "1:1") ::
Row("1", "1,2", "1:2") ::
Row("2", "4", "2:4") ::
Row("3", "7,8,9", "3:7") ::
Row("3", "7,8,9", "3:8") ::
Row("3", "7,8,9", "3:9") :: Nil)
}
test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv")
val e = intercept[AnalysisException] {
df.select(explode($"*"))
}
assert(e.getMessage.contains("Invalid usage of '*' in expression 'explode'"))
}
test("explode on output of array-valued function") {
val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv")
checkAnswer(
df.select(explode(split($"csv", pattern = ","))),
Row("1") :: Row("2") :: Row("4") :: Row("7") :: Row("8") :: Row("9") :: Nil)
}
test("Star Expansion - explode alias and star") {
val df = Seq((Array("a"), 1)).toDF("a", "b")
checkAnswer(
df.select(explode($"a").as("a"), $"*"),
Row("a", Seq("a"), 1) :: Nil)
}
test("sort after generate with join=true") {
val df = Seq((Array("a"), 1)).toDF("a", "b")
checkAnswer(
df.select($"*", explode($"a").as("c")).sortWithinPartitions("b", "c"),
Row(Seq("a"), 1, "a") :: Nil)
}
test("selectExpr") {
checkAnswer(
testData.selectExpr("abs(key)", "value"),
testData.collect().map(row => Row(math.abs(row.getInt(0)), row.getString(1))).toSeq)
}
test("selectExpr with alias") {
checkAnswer(
testData.selectExpr("key as k").select("k"),
testData.select("key").collect().toSeq)
}
test("selectExpr with udtf") {
val df = Seq((Map("1" -> 1), 1)).toDF("a", "b")
checkAnswer(
df.selectExpr("explode(a)"),
Row("1", 1) :: Nil)
}
test("filterExpr") {
val res = testData.collect().filter(_.getInt(0) > 90).toSeq
checkAnswer(testData.filter("key > 90"), res)
checkAnswer(testData.filter("key > 9.0e1"), res)
checkAnswer(testData.filter("key > .9e+2"), res)
checkAnswer(testData.filter("key > 0.9e+2"), res)
checkAnswer(testData.filter("key > 900e-1"), res)
checkAnswer(testData.filter("key > 900.0E-1"), res)
checkAnswer(testData.filter("key > 9.e+1"), res)
}
test("filterExpr using where") {
checkAnswer(
testData.where("key > 50"),
testData.collect().filter(_.getInt(0) > 50).toSeq)
}
test("repartition") {
intercept[IllegalArgumentException] {
testData.select("key").repartition(0)
}
checkAnswer(
testData.select("key").repartition(10).select("key"),
testData.select("key").collect().toSeq)
}
test("repartition with SortOrder") {
// passing SortOrder expressions to .repartition() should result in an informative error
def checkSortOrderErrorMsg[T](data: => Dataset[T]): Unit = {
val ex = intercept[IllegalArgumentException](data)
assert(ex.getMessage.contains("repartitionByRange"))
}
checkSortOrderErrorMsg {
Seq(0).toDF("a").repartition(2, $"a".asc)
}
checkSortOrderErrorMsg {
Seq((0, 0)).toDF("a", "b").repartition(2, $"a".asc, $"b")
}
}
test("repartitionByRange") {
val data1d = Random.shuffle(0.to(9))
val data2d = data1d.map(i => (i, data1d.size - i))
checkAnswer(
data1d.toDF("val").repartitionByRange(data1d.size, $"val".asc)
.select(spark_partition_id().as("id"), $"val"),
data1d.map(i => Row(i, i)))
checkAnswer(
data1d.toDF("val").repartitionByRange(data1d.size, $"val".desc)
.select(spark_partition_id().as("id"), $"val"),
data1d.map(i => Row(i, data1d.size - 1 - i)))
checkAnswer(
data1d.toDF("val").repartitionByRange(data1d.size, lit(42))
.select(spark_partition_id().as("id"), $"val"),
data1d.map(i => Row(0, i)))
checkAnswer(
data1d.toDF("val").repartitionByRange(data1d.size, lit(null), $"val".asc, rand())
.select(spark_partition_id().as("id"), $"val"),
data1d.map(i => Row(i, i)))
// .repartitionByRange() assumes .asc by default if no explicit sort order is specified
checkAnswer(
data2d.toDF("a", "b").repartitionByRange(data2d.size, $"a".desc, $"b")
.select(spark_partition_id().as("id"), $"a", $"b"),
data2d.toDF("a", "b").repartitionByRange(data2d.size, $"a".desc, $"b".asc)
.select(spark_partition_id().as("id"), $"a", $"b"))
// at least one partition-by expression must be specified
intercept[IllegalArgumentException] {
data1d.toDF("val").repartitionByRange(data1d.size)
}
intercept[IllegalArgumentException] {
data1d.toDF("val").repartitionByRange(data1d.size, Seq.empty: _*)
}
}
test("coalesce") {
intercept[IllegalArgumentException] {
testData.select("key").coalesce(0)
}
assert(testData.select("key").coalesce(1).rdd.partitions.size === 1)
checkAnswer(
testData.select("key").coalesce(1).select("key"),
testData.select("key").collect().toSeq)
assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 0)
}
test("convert $\"attribute name\" into unresolved attribute") {
checkAnswer(
testData.where($"key" === lit(1)).select($"value"),
Row("1"))
}
test("convert Scala Symbol 'attrname into unresolved attribute") {
checkAnswer(
testData.where($"key" === lit(1)).select("value"),
Row("1"))
}
test("select *") {
checkAnswer(
testData.select($"*"),
testData.collect().toSeq)
}
test("simple select") {
checkAnswer(
testData.where($"key" === lit(1)).select("value"),
Row("1"))
}
test("select with functions") {
checkAnswer(
testData.select(sum("value"), avg("value"), count(lit(1))),
Row(5050.0, 50.5, 100))
checkAnswer(
testData2.select($"a" + $"b", $"a" < $"b"),
Seq(
Row(2, false),
Row(3, true),
Row(3, false),
Row(4, false),
Row(4, false),
Row(5, false)))
checkAnswer(
testData2.select(sum_distinct($"a")),
Row(6))
}
test("sorting with null ordering") {
val data = Seq[java.lang.Integer](2, 1, null).toDF("key")
checkAnswer(data.orderBy($"key".asc), Row(null) :: Row(1) :: Row(2) :: Nil)
checkAnswer(data.orderBy(asc("key")), Row(null) :: Row(1) :: Row(2) :: Nil)
checkAnswer(data.orderBy($"key".asc_nulls_first), Row(null) :: Row(1) :: Row(2) :: Nil)
checkAnswer(data.orderBy(asc_nulls_first("key")), Row(null) :: Row(1) :: Row(2) :: Nil)
checkAnswer(data.orderBy($"key".asc_nulls_last), Row(1) :: Row(2) :: Row(null) :: Nil)
checkAnswer(data.orderBy(asc_nulls_last("key")), Row(1) :: Row(2) :: Row(null) :: Nil)
checkAnswer(data.orderBy($"key".desc), Row(2) :: Row(1) :: Row(null) :: Nil)
checkAnswer(data.orderBy(desc("key")), Row(2) :: Row(1) :: Row(null) :: Nil)
checkAnswer(data.orderBy($"key".desc_nulls_first), Row(null) :: Row(2) :: Row(1) :: Nil)
checkAnswer(data.orderBy(desc_nulls_first("key")), Row(null) :: Row(2) :: Row(1) :: Nil)
checkAnswer(data.orderBy($"key".desc_nulls_last), Row(2) :: Row(1) :: Row(null) :: Nil)
checkAnswer(data.orderBy(desc_nulls_last("key")), Row(2) :: Row(1) :: Row(null) :: Nil)
}
test("global sorting") {
checkAnswer(
testData2.orderBy($"a".asc, $"b".asc),
Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2)))
checkAnswer(
testData2.orderBy(asc("a"), desc("b")),
Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1)))
checkAnswer(
testData2.orderBy($"a".asc, $"b".desc),
Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1)))
checkAnswer(
testData2.orderBy($"a".desc, $"b".desc),
Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1)))
checkAnswer(
testData2.orderBy($"a".desc, $"b".asc),
Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)))
checkAnswer(
arrayData.toDF().orderBy($"data".getItem(0).asc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
arrayData.toDF().orderBy($"data".getItem(0).desc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
arrayData.toDF().orderBy($"data".getItem(1).asc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
arrayData.toDF().orderBy($"data".getItem(1).desc),
arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
test("limit") {
checkAnswer(
testData.limit(10),
testData.take(10).toSeq)
checkAnswer(
arrayData.toDF().limit(1),
arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
checkAnswer(
mapData.toDF().limit(1),
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
// SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake
checkAnswer(
spark.range(2).toDF().limit(2147483638),
Row(0) :: Row(1) :: Nil
)
}
test("udf") {
val foo = udf((a: Int, b: String) => a.toString + b)
checkAnswer(
// SELECT *, foo(key, value) FROM testData
testData.select($"*", foo($"key", $"value")).limit(3),
Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil
)
}
test("callUDF without Hive Support") {
val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
df.sparkSession.udf.register("simpleUDF", (v: Int) => v * v)
checkAnswer(
df.select($"id", callUDF("simpleUDF", $"value")), // test deprecated one
Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil)
}
test("withColumn") {
val df = testData.toDF().withColumn("newCol", col("key") + 1)
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1)
}.toSeq)
assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
}
test("withColumns") {
val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
Seq(col("key") + 1, col("key") + 2))
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1, key + 2)
}.toSeq)
assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))
val err = intercept[IllegalArgumentException] {
testData.toDF().withColumns(Seq("newCol1"),
Seq(col("key") + 1, col("key") + 2))
}
assert(
err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2"))
val err2 = intercept[AnalysisException] {
testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
Seq(col("key") + 1, col("key") + 2))
}
assert(err2.getMessage.contains("Found duplicate column(s)"))
}
test("withColumns: case sensitive") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
Seq(col("key") + 1, col("key") + 2))
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1, key + 2)
}.toSeq)
assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1"))
val err = intercept[AnalysisException] {
testData.toDF().withColumns(Seq("newCol1", "newCol1"),
Seq(col("key") + 1, col("key") + 2))
}
assert(err.getMessage.contains("Found duplicate column(s)"))
}
}
test("withColumns: given metadata") {
def buildMetadata(num: Int): Seq[Metadata] = {
(0 until num).map { n =>
val builder = new MetadataBuilder
builder.putLong("key", n.toLong)
builder.build()
}
}
val df = testData.toDF().withColumns(
Seq("newCol1", "newCol2"),
Seq(col("key") + 1, col("key") + 2),
buildMetadata(2))
df.select("newCol1", "newCol2").schema.zipWithIndex.foreach { case (col, idx) =>
assert(col.metadata.getLong("key").toInt === idx)
}
val err = intercept[IllegalArgumentException] {
testData.toDF().withColumns(
Seq("newCol1", "newCol2"),
Seq(col("key") + 1, col("key") + 2),
buildMetadata(1))
}
assert(err.getMessage.contains(
"The size of column names: 2 isn't equal to the size of metadata elements: 1"))
}
test("replace column using withColumn") {
val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
val df3 = df2.withColumn("x", df2("x") + 1)
checkAnswer(
df3.select("x"),
Row(2) :: Row(3) :: Row(4) :: Nil)
}
test("replace column using withColumns") {
val df2 = sparkContext.parallelize(Seq((1, 2), (2, 3), (3, 4))).toDF("x", "y")
val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"),
Seq(df2("x") + 1, df2("y"), df2("y") + 1))
checkAnswer(
df3.select("x", "newCol1", "newCol2"),
Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil)
}
test("drop column using drop") {
val df = testData.drop("key")
checkAnswer(
df,
testData.collect().map(x => Row(x.getString(1))).toSeq)
assert(df.schema.map(_.name) === Seq("value"))
}
test("drop columns using drop") {
val src = Seq((0, 2, 3)).toDF("a", "b", "c")
val df = src.drop("a", "b")
checkAnswer(df, Row(3))
assert(df.schema.map(_.name) === Seq("c"))
}
test("drop unknown column (no-op)") {
val df = testData.drop("random")
checkAnswer(
df,
testData.collect().toSeq)
assert(df.schema.map(_.name) === Seq("key", "value"))
}
test("drop column using drop with column reference") {
val col = testData("key")
val df = testData.drop(col)
checkAnswer(
df,
testData.collect().map(x => Row(x.getString(1))).toSeq)
assert(df.schema.map(_.name) === Seq("value"))
}
test("SPARK-28189 drop column using drop with column reference with case-insensitive names") {
// With the SQL config caseSensitive set to false, case-insensitive column names should work
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
val col1 = testData("KEY")
val df1 = testData.drop(col1)
checkAnswer(df1, testData.selectExpr("value"))
assert(df1.schema.map(_.name) === Seq("value"))
val col2 = testData("Key")
val df2 = testData.drop(col2)
checkAnswer(df2, testData.selectExpr("value"))
assert(df2.schema.map(_.name) === Seq("value"))
}
}
test("drop unknown column (no-op) with column reference") {
val col = Column("random")
val df = testData.drop(col)
checkAnswer(
df,
testData.collect().toSeq)
assert(df.schema.map(_.name) === Seq("key", "value"))
}
test("drop unknown column with same name with column reference") {
val col = Column("key")
val df = testData.drop(col)
checkAnswer(
df,
testData.collect().map(x => Row(x.getString(1))).toSeq)
assert(df.schema.map(_.name) === Seq("value"))
}
test("drop column after join with duplicate columns using column reference") {
val newSalary = salary.withColumnRenamed("personId", "id")
val col = newSalary("id")
// this join will result in duplicate "id" columns
val joinedDf = person.join(newSalary,
person("id") === newSalary("id"), "inner")
// remove only the "id" column that was associated with newSalary
val df = joinedDf.drop(col)
checkAnswer(
df,
joinedDf.collect().map {
case Row(id: Int, name: String, age: Int, idToDrop: Int, salary: Double) =>
Row(id, name, age, salary)
}.toSeq)
assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary"))
assert(df("id") == person("id"))
}
test("drop top level columns that contains dot") {
val df1 = Seq((1, 2)).toDF("a.b", "a.c")
checkAnswer(df1.drop("a.b"), Row(2))
// Creates data set: {"a.b": 1, "a": {"b": 3}}
val df2 = Seq((1)).toDF("a.b").withColumn("a", struct(lit(3) as "b"))
// Unlike select(), drop() parses the column name "a.b" literally, without interpreting "."
checkAnswer(df2.drop("a.b").select("a.b"), Row(3))
// "`" is treated as a normal char here with no interpreting, "`a`b" is a valid column name.
assert(df2.drop("`a.b`").columns.size == 2)
}
test("drop(name: String) search and drop all top level columns that matches the name") {
val df1 = Seq((1, 2)).toDF("a", "b")
val df2 = Seq((3, 4)).toDF("a", "b")
checkAnswer(df1.crossJoin(df2), Row(1, 2, 3, 4))
// Finds and drops all columns that match the name (case insensitive).
checkAnswer(df1.crossJoin(df2).drop("A"), Row(2, 4))
}
test("withColumnRenamed") {
val df = testData.toDF().withColumn("newCol", col("key") + 1)
.withColumnRenamed("value", "valueRenamed")
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
Row(key, value, key + 1)
}.toSeq)
assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}
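// Fixtures for the describe()/summary() tests below: person2 holds distinct rows, while
// person3 contains duplicate names and heights to exercise count_distinct and
// approx_count_distinct.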
private lazy val person2: DataFrame = Seq(
("Bob", 16, 176),
("Alice", 32, 164),
("David", 60, 192),
("Amy", 24, 180)).toDF("name", "age", "height")
private lazy val person3: DataFrame = Seq(
("Luis", 1, 99),
("Luis", 16, 99),
("Luis", 16, 176),
("Fernando", 32, 99),
("Fernando", 32, 164),
("David", 60, 99),
("Amy", 24, 99)).toDF("name", "age", "height")
test("describe") {
val describeResult = Seq(
Row("count", "4", "4", "4"),
Row("mean", null, "33.0", "178.0"),
Row("stddev", null, "19.148542155126762", "11.547005383792516"),
Row("min", "Alice", "16", "164"),
Row("max", "David", "60", "192"))
val emptyDescribeResult = Seq(
Row("count", "0", "0", "0"),
Row("mean", null, null, null),
Row("stddev", null, null, null),
Row("min", null, null, null),
Row("max", null, null, null))
def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
val describeAllCols = person2.describe()
assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height"))
checkAnswer(describeAllCols, describeResult)
// All aggregate values should have been cast to strings
describeAllCols.collect().foreach { row =>
row.toSeq.foreach { value =>
if (value != null) {
assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
}
}
}
val describeOneCol = person2.describe("age")
assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} )
val describeNoCol = person2.select().describe()
assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s)} )
val emptyDescription = person2.limit(0).describe()
assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
checkAnswer(emptyDescription, emptyDescribeResult)
}
test("summary") {
val summaryResult = Seq(
Row("count", "4", "4", "4"),
Row("mean", null, "33.0", "178.0"),
Row("stddev", null, "19.148542155126762", "11.547005383792516"),
Row("min", "Alice", "16", "164"),
Row("25%", null, "16", "164"),
Row("50%", null, "24", "176"),
Row("75%", null, "32", "180"),
Row("max", "David", "60", "192"))
val emptySummaryResult = Seq(
Row("count", "0", "0", "0"),
Row("mean", null, null, null),
Row("stddev", null, null, null),
Row("min", null, null, null),
Row("25%", null, null, null),
Row("50%", null, null, null),
Row("75%", null, null, null),
Row("max", null, null, null))
def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
val summaryAllCols = person2.summary()
assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height"))
checkAnswer(summaryAllCols, summaryResult)
// All aggregate values should have been cast to strings
summaryAllCols.collect().foreach { row =>
row.toSeq.foreach { value =>
if (value != null) {
assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
}
}
}
val summaryOneCol = person2.select("age").summary()
assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age"))
checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d)} )
val summaryNoCol = person2.select().summary()
assert(getSchemaAsSeq(summaryNoCol) === Seq("summary"))
checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s)} )
val emptyDescription = person2.limit(0).summary()
assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
checkAnswer(emptyDescription, emptySummaryResult)
}
test("SPARK-34165: Add count_distinct to summary") {
val summaryDF = person3.summary("count", "count_distinct")
val summaryResult = Seq(
Row("count", "7", "7", "7"),
Row("count_distinct", "4", "5", "3"))
def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
assert(getSchemaAsSeq(summaryDF) === Seq("summary", "name", "age", "height"))
checkAnswer(summaryDF, summaryResult)
val approxSummaryDF = person3.summary("count", "approx_count_distinct")
val approxSummaryResult = Seq(
Row("count", "7", "7", "7"),
Row("approx_count_distinct", "4", "5", "3"))
assert(getSchemaAsSeq(summaryDF) === Seq("summary", "name", "age", "height"))
checkAnswer(approxSummaryDF, approxSummaryResult)
}
test("summary advanced") {
val stats = Array("count", "50.01%", "max", "mean", "min", "25%")
val orderMatters = person2.summary(stats: _*)
assert(orderMatters.collect().map(_.getString(0)) === stats)
val onlyPercentiles = person2.summary("0.1%", "99.9%")
assert(onlyPercentiles.count() === 2)
val fooE = intercept[IllegalArgumentException] {
person2.summary("foo")
}
assert(fooE.getMessage === "foo is not a recognised statistic")
val parseE = intercept[IllegalArgumentException] {
person2.summary("foo%")
}
assert(parseE.getMessage === "Unable to parse foo% as a percentile")
}
test("apply on query results (SPARK-5462)") {
val df = testData.sparkSession.sql("select key from testData")
checkAnswer(df.select(df("key")), testData.select("key").collect().toSeq)
}
test("inputFiles") {
Seq("csv", "").foreach { useV1List =>
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> useV1List) {
withTempDir { dir =>
val df = Seq((1, 22)).toDF("a", "b")
val parquetDir = new File(dir, "parquet").getCanonicalPath
df.write.parquet(parquetDir)
val parquetDF = spark.read.parquet(parquetDir)
assert(parquetDF.inputFiles.nonEmpty)
val csvDir = new File(dir, "csv").getCanonicalPath
df.write.json(csvDir)
val csvDF = spark.read.json(csvDir)
assert(csvDF.inputFiles.nonEmpty)
val unioned = csvDF.union(parquetDF).inputFiles.sorted
val allFiles = (csvDF.inputFiles ++ parquetDF.inputFiles).distinct.sorted
assert(unioned === allFiles)
}
}
}
}
ignore("show") {
// This test case is intentionally ignored, but is kept to make sure it compiles correctly
testData.select($"*").show()
testData.select($"*").show(1000)
}
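// The getRows()/showString() tests below exercise the package-private helpers that back
// Dataset.show(), so the rendered output can be compared against expected strings directly.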
test("getRows: truncate = [0, 20]") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = Seq(
Seq("value"),
Seq("1"),
Seq("111111111111111111111"))
assert(df.getRows(10, 0) === expectedAnswerForFalse)
val expectedAnswerForTrue = Seq(
Seq("value"),
Seq("1"),
Seq("11111111111111111..."))
assert(df.getRows(10, 20) === expectedAnswerForTrue)
}
test("getRows: truncate = [3, 17]") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = Seq(
Seq("value"),
Seq("1"),
Seq("111"))
assert(df.getRows(10, 3) === expectedAnswerForFalse)
val expectedAnswerForTrue = Seq(
Seq("value"),
Seq("1"),
Seq("11111111111111..."))
assert(df.getRows(10, 17) === expectedAnswerForTrue)
}
test("getRows: numRows = 0") {
val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1"))
assert(testData.select($"*").getRows(0, 20) === expectedAnswer)
}
test("getRows: array") {
val df = Seq(
(Array(1, 2, 3), Array(1, 2, 3)),
(Array(2, 3, 4), Array(2, 3, 4))
).toDF()
val expectedAnswer = Seq(
Seq("_1", "_2"),
Seq("[1, 2, 3]", "[1, 2, 3]"),
Seq("[2, 3, 4]", "[2, 3, 4]"))
assert(df.getRows(10, 20) === expectedAnswer)
}
test("getRows: binary") {
val df = Seq(
("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)),
("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8))
).toDF()
val expectedAnswer = Seq(
Seq("_1", "_2"),
Seq("[31 32]", "[41 42 43 2E]"),
Seq("[33 34]", "[31 32 33 34 36]"))
assert(df.getRows(10, 20) === expectedAnswer)
}
test("showString: truncate = [0, 20]") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = """+---------------------+
||value |
|+---------------------+
||1 |
||111111111111111111111|
|+---------------------+
|""".stripMargin
assert(df.showString(10, truncate = 0) === expectedAnswerForFalse)
val expectedAnswerForTrue = """+--------------------+
|| value|
|+--------------------+
|| 1|
||11111111111111111...|
|+--------------------+
|""".stripMargin
assert(df.showString(10, truncate = 20) === expectedAnswerForTrue)
}
test("showString: truncate = [0, 20], vertical = true") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = "-RECORD 0----------------------\n" +
" value | 1 \n" +
"-RECORD 1----------------------\n" +
" value | 111111111111111111111 \n"
assert(df.showString(10, truncate = 0, vertical = true) === expectedAnswerForFalse)
val expectedAnswerForTrue = "-RECORD 0---------------------\n" +
" value | 1 \n" +
"-RECORD 1---------------------\n" +
" value | 11111111111111111... \n"
assert(df.showString(10, truncate = 20, vertical = true) === expectedAnswerForTrue)
}
test("showString: truncate = [3, 17]") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = """+-----+
||value|
|+-----+
|| 1|
|| 111|
|+-----+
|""".stripMargin
assert(df.showString(10, truncate = 3) === expectedAnswerForFalse)
val expectedAnswerForTrue = """+-----------------+
|| value|
|+-----------------+
|| 1|
||11111111111111...|
|+-----------------+
|""".stripMargin
assert(df.showString(10, truncate = 17) === expectedAnswerForTrue)
}
test("showString: truncate = [3, 17], vertical = true") {
val longString = Array.fill(21)("1").mkString
val df = sparkContext.parallelize(Seq("1", longString)).toDF()
val expectedAnswerForFalse = "-RECORD 0----\n" +
" value | 1 \n" +
"-RECORD 1----\n" +
" value | 111 \n"
assert(df.showString(10, truncate = 3, vertical = true) === expectedAnswerForFalse)
val expectedAnswerForTrue = "-RECORD 0------------------\n" +
" value | 1 \n" +
"-RECORD 1------------------\n" +
" value | 11111111111111... \n"
assert(df.showString(10, truncate = 17, vertical = true) === expectedAnswerForTrue)
}
test("showString(negative)") {
val expectedAnswer = """+---+-----+
||key|value|
|+---+-----+
|+---+-----+
|only showing top 0 rows
|""".stripMargin
assert(testData.select($"*").showString(-1) === expectedAnswer)
}
test("showString(negative), vertical = true") {
val expectedAnswer = "(0 rows)\n"
assert(testData.select($"*").showString(-1, vertical = true) === expectedAnswer)
}
test("showString(0)") {
val expectedAnswer = """+---+-----+
||key|value|
|+---+-----+
|+---+-----+
|only showing top 0 rows
|""".stripMargin
assert(testData.select($"*").showString(0) === expectedAnswer)
}
test("showString(Int.MaxValue)") {
val df = Seq((1, 2), (3, 4)).toDF("a", "b")
val expectedAnswer = """+---+---+
|| a| b|
|+---+---+
|| 1| 2|
|| 3| 4|
|+---+---+
|""".stripMargin
assert(df.showString(Int.MaxValue) === expectedAnswer)
}
test("showString(0), vertical = true") {
val expectedAnswer = "(0 rows)\n"
assert(testData.select($"*").showString(0, vertical = true) === expectedAnswer)
}
test("showString: array") {
val df = Seq(
(Array(1, 2, 3), Array(1, 2, 3)),
(Array(2, 3, 4), Array(2, 3, 4))
).toDF()
val expectedAnswer = """+---------+---------+
|| _1| _2|
|+---------+---------+
||[1, 2, 3]|[1, 2, 3]|
||[2, 3, 4]|[2, 3, 4]|
|+---------+---------+
|""".stripMargin
assert(df.showString(10) === expectedAnswer)
}
test("showString: array, vertical = true") {
val df = Seq(
(Array(1, 2, 3), Array(1, 2, 3)),
(Array(2, 3, 4), Array(2, 3, 4))
).toDF()
val expectedAnswer = "-RECORD 0--------\n" +
" _1 | [1, 2, 3] \n" +
" _2 | [1, 2, 3] \n" +
"-RECORD 1--------\n" +
" _1 | [2, 3, 4] \n" +
" _2 | [2, 3, 4] \n"
assert(df.showString(10, vertical = true) === expectedAnswer)
}
test("showString: binary") {
val df = Seq(
("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)),
("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8))
).toDF()
val expectedAnswer = """+-------+----------------+
|| _1| _2|
|+-------+----------------+
||[31 32]| [41 42 43 2E]|
||[33 34]|[31 32 33 34 36]|
|+-------+----------------+
|""".stripMargin
assert(df.showString(10) === expectedAnswer)
}
test("showString: binary, vertical = true") {
val df = Seq(
("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)),
("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8))
).toDF()
val expectedAnswer = "-RECORD 0---------------\n" +
" _1 | [31 32] \n" +
" _2 | [41 42 43 2E] \n" +
"-RECORD 1---------------\n" +
" _1 | [33 34] \n" +
" _2 | [31 32 33 34 36] \n"
assert(df.showString(10, vertical = true) === expectedAnswer)
}
test("showString: minimum column width") {
val df = Seq(
(1, 1),
(2, 2)
).toDF()
val expectedAnswer = """+---+---+
|| _1| _2|
|+---+---+
|| 1| 1|
|| 2| 2|
|+---+---+
|""".stripMargin
assert(df.showString(10) === expectedAnswer)
}
test("showString: minimum column width, vertical = true") {
val df = Seq(
(1, 1),
(2, 2)
).toDF()
val expectedAnswer = "-RECORD 0--\n" +
" _1 | 1 \n" +
" _2 | 1 \n" +
"-RECORD 1--\n" +
" _1 | 2 \n" +
" _2 | 2 \n"
assert(df.showString(10, vertical = true) === expectedAnswer)
}
test("SPARK-33690: showString: escape meta-characters") {
val df1 = spark.sql("SELECT 'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh'")
assert(df1.showString(1, truncate = 0) ===
"""+--------------------------------------+
||aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh|
|+--------------------------------------+
||aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh|
|+--------------------------------------+
|""".stripMargin)
val df2 = spark.sql("SELECT array('aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
assert(df2.showString(1, truncate = 0) ===
"""+---------------------------------------------+
||array(aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh)|
|+---------------------------------------------+
||[aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh] |
|+---------------------------------------------+
|""".stripMargin)
val df3 =
spark.sql("SELECT map('aaa\nbbb\tccc', 'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
assert(df3.showString(1, truncate = 0) ===
"""+----------------------------------------------------------+
||map(aaa\nbbb\tccc, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh)|
|+----------------------------------------------------------+
||{aaa\nbbb\tccc -> aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh} |
|+----------------------------------------------------------+
|""".stripMargin)
val df4 =
spark.sql("SELECT named_struct('v', 'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
assert(df4.showString(1, truncate = 0) ===
"""+-------------------------------------------------------+
||named_struct(v, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh)|
|+-------------------------------------------------------+
||{aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh} |
|+-------------------------------------------------------+
|""".stripMargin)
}
test("SPARK-34308: printSchema: escape meta-characters") {
val captured = new ByteArrayOutputStream()
val df1 = spark.sql("SELECT 'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh'")
Console.withOut(captured) {
df1.printSchema()
}
assert(captured.toString ===
"""root
| |-- aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh: string (nullable = false)
|
|""".stripMargin)
captured.reset()
val df2 = spark.sql("SELECT array('aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
Console.withOut(captured) {
df2.printSchema()
}
assert(captured.toString ===
"""root
| |-- array(aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh): array (nullable = false)
| | |-- element: string (containsNull = false)
|
|""".stripMargin)
captured.reset()
val df3 =
spark.sql("SELECT map('aaa\nbbb\tccc', 'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
Console.withOut(captured) {
df3.printSchema()
}
assert(captured.toString ===
"""root
| |-- map(aaa\nbbb\tccc, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh): map (nullable = false)
| | |-- key: string
| | |-- value: string (valueContainsNull = false)
|
|""".stripMargin)
captured.reset()
val df4 =
spark.sql("SELECT named_struct('v', 'aaa\nbbb\tccc\rddd\feee\bfff\u000Bggg\u0007hhh')")
Console.withOut(captured) {
df4.printSchema()
}
assert(captured.toString ===
"""root
| |-- named_struct(v, aaa\nbbb\tccc\rddd\feee\bfff\vggg\ahhh): struct (nullable = false)
| | |-- v: string (nullable = false)
|
|""".stripMargin)
}
test("SPARK-7319 showString") {
val expectedAnswer = """+---+-----+
||key|value|
|+---+-----+
|| 1| 1|
|+---+-----+
|only showing top 1 row
|""".stripMargin
assert(testData.select($"*").showString(1) === expectedAnswer)
}
test("SPARK-7319 showString, vertical = true") {
val expectedAnswer = "-RECORD 0----\n" +
" key | 1 \n" +
" value | 1 \n" +
"only showing top 1 row\n"
assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer)
}
test("SPARK-23023 Cast rows to strings in showString") {
val df1 = Seq(Seq(1, 2, 3, 4)).toDF("a")
assert(df1.showString(10) ===
s"""+------------+
|| a|
|+------------+
||[1, 2, 3, 4]|
|+------------+
|""".stripMargin)
val df2 = Seq(Map(1 -> "a", 2 -> "b")).toDF("a")
assert(df2.showString(10) ===
s"""+----------------+
|| a|
|+----------------+
||{1 -> a, 2 -> b}|
|+----------------+
|""".stripMargin)
val df3 = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b")
assert(df3.showString(10) ===
s"""+------+---+
|| a| b|
|+------+---+
||{1, a}| 0|
||{2, b}| 0|
|+------+---+
|""".stripMargin)
}
test("SPARK-7327 show with empty dataFrame") {
val expectedAnswer = """+---+-----+
||key|value|
|+---+-----+
|+---+-----+
|""".stripMargin
assert(testData.select($"*").filter($"key" < 0).showString(1) === expectedAnswer)
}
test("SPARK-7327 show with empty dataFrame, vertical = true") {
assert(testData.select($"*").filter($"key" < 0).showString(1, vertical = true) === "(0 rows)\n")
}
test("SPARK-18350 show with session local timezone") {
val d = Date.valueOf("2016-12-01")
val ts = Timestamp.valueOf("2016-12-01 00:00:00")
val df = Seq((d, ts)).toDF("d", "ts")
val expectedAnswer = """+----------+-------------------+
||d |ts |
|+----------+-------------------+
||2016-12-01|2016-12-01 00:00:00|
|+----------+-------------------+
|""".stripMargin
assert(df.showString(1, truncate = 0) === expectedAnswer)
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
val expectedAnswer = """+----------+-------------------+
||d |ts |
|+----------+-------------------+
||2016-12-01|2016-12-01 08:00:00|
|+----------+-------------------+
|""".stripMargin
assert(df.showString(1, truncate = 0) === expectedAnswer)
}
}
test("SPARK-18350 show with session local timezone, vertical = true") {
val d = Date.valueOf("2016-12-01")
val ts = Timestamp.valueOf("2016-12-01 00:00:00")
val df = Seq((d, ts)).toDF("d", "ts")
val expectedAnswer = "-RECORD 0------------------\n" +
" d | 2016-12-01 \n" +
" ts | 2016-12-01 00:00:00 \n"
assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer)
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
val expectedAnswer = "-RECORD 0------------------\n" +
" d | 2016-12-01 \n" +
" ts | 2016-12-01 08:00:00 \n"
assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer)
}
}
test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
val df = spark.createDataFrame(rowRDD, schema)
df.rdd.collect()
}
test("SPARK-6899: type should match when using codegen") {
checkAnswer(decimalData.agg(avg("a")), Row(new java.math.BigDecimal(2)))
}
test("SPARK-7133: Implement struct, array, and map field accessor") {
assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1)
assert(complexData.filter(complexData("a")(complexData("s")("key")) === 1).count() == 1)
}
test("SPARK-7551: support backticks for DataFrame attribute resolution") {
withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") {
val df = spark.read.json(Seq("""{"a.b": {"c": {"d..e": {"f": 1}}}}""").toDS())
checkAnswer(
df.select(df("`a.b`.c.`d..e`.`f`")),
Row(1)
)
val df2 = spark.read.json(Seq("""{"a b": {"c": {"d e": {"f": 1}}}}""").toDS())
checkAnswer(
df2.select(df2("`a b`.c.d e.f")),
Row(1)
)
def checkError(testFun: => Unit): Unit = {
val e = intercept[org.apache.spark.sql.AnalysisException] {
testFun
}
assert(e.getMessage.contains("syntax error in attribute name:"))
}
checkError(df("`abc.`c`"))
checkError(df("`abc`..d"))
checkError(df("`a`.b."))
checkError(df("`a.b`.c.`d"))
}
}
test("SPARK-7324 dropDuplicates") {
val testData = sparkContext.parallelize(
(2, 1, 2) :: (1, 1, 1) ::
(1, 2, 1) :: (2, 1, 2) ::
(2, 2, 2) :: (2, 2, 1) ::
(2, 1, 1) :: (1, 1, 2) ::
(1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2")
checkAnswer(
testData.dropDuplicates(),
Seq(Row(2, 1, 2), Row(1, 1, 1), Row(1, 2, 1),
Row(2, 2, 2), Row(2, 1, 1), Row(2, 2, 1),
Row(1, 1, 2), Row(1, 2, 2)))
checkAnswer(
testData.dropDuplicates(Seq("key", "value1")),
Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
checkAnswer(
testData.dropDuplicates(Seq("value1", "value2")),
Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
checkAnswer(
testData.dropDuplicates(Seq("key")),
Seq(Row(2, 1, 2), Row(1, 1, 1)))
checkAnswer(
testData.dropDuplicates(Seq("value1")),
Seq(Row(2, 1, 2), Row(1, 2, 1)))
checkAnswer(
testData.dropDuplicates(Seq("value2")),
Seq(Row(2, 1, 2), Row(1, 1, 1)))
checkAnswer(
testData.dropDuplicates("key", "value1"),
Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
}
test("SPARK-8621: support empty string column name") {
val df = Seq(Tuple1(1)).toDF("").as("t")
// We should allow an empty string as a column name
df.col("")
df.col("t.``")
}
test("SPARK-8797: sort by float column containing NaN should not crash") {
val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat))
val df = Random.shuffle(inputData).toDF("a")
df.orderBy("a").collect()
}
test("SPARK-8797: sort by double column containing NaN should not crash") {
val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble))
val df = Random.shuffle(inputData).toDF("a")
df.orderBy("a").collect()
}
test("NaN is greater than all other non-NaN numeric values") {
val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue)
.map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
assert(java.lang.Double.isNaN(maxDouble.getDouble(0)))
val maxFloat = Seq(Float.NaN, Float.PositiveInfinity, Float.MaxValue)
.map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
assert(java.lang.Float.isNaN(maxFloat.getFloat(0)))
}
test("SPARK-8072: Better Exception for Duplicate Columns") {
// only one duplicate column present
val e = intercept[org.apache.spark.sql.AnalysisException] {
Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1")
.write.format("parquet").save("temp")
}
assert(e.getMessage.contains("Found duplicate column(s) when inserting into"))
assert(e.getMessage.contains("column1"))
assert(!e.getMessage.contains("column2"))
// multiple duplicate columns present
val f = intercept[org.apache.spark.sql.AnalysisException] {
Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7))
.toDF("column1", "column2", "column3", "column1", "column3")
.write.format("json").save("temp")
}
assert(f.getMessage.contains("Found duplicate column(s) when inserting into"))
assert(f.getMessage.contains("column1"))
assert(f.getMessage.contains("column3"))
assert(!f.getMessage.contains("column2"))
}
test("SPARK-6941: Better error message for inserting into RDD-based Table") {
withTempDir { dir =>
withTempView("parquet_base", "json_base", "rdd_base", "indirect_ds", "one_row") {
val tempParquetFile = new File(dir, "tmp_parquet")
val tempJsonFile = new File(dir, "tmp_json")
val df = Seq(Tuple1(1)).toDF()
val insertion = Seq(Tuple1(2)).toDF("col")
// pass case: parquet table (HadoopFsRelation)
df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath)
val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath)
pdf.createOrReplaceTempView("parquet_base")
insertion.write.insertInto("parquet_base")
// pass case: json table (InsertableRelation)
df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath)
val jdf = spark.read.json(tempJsonFile.getCanonicalPath)
jdf.createOrReplaceTempView("json_base")
insertion.write.mode(SaveMode.Overwrite).insertInto("json_base")
// error cases: insert into an RDD
df.createOrReplaceTempView("rdd_base")
val e1 = intercept[AnalysisException] {
insertion.write.insertInto("rdd_base")
}
assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed."))
// error case: insert into a logical plan that is not a LeafNode
val indirectDS = pdf.select("_1").filter($"_1" > 5)
indirectDS.createOrReplaceTempView("indirect_ds")
val e2 = intercept[AnalysisException] {
insertion.write.insertInto("indirect_ds")
}
assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))
// error case: insert into an OneRowRelation
Dataset.ofRows(spark, OneRowRelation()).createOrReplaceTempView("one_row")
val e3 = intercept[AnalysisException] {
insertion.write.insertInto("one_row")
}
assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed."))
}
}
}
test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") {
val df = testData.select(rand(33))
assert(df.showString(5) == df.showString(5))
// We will reuse the same Expression object for LocalRelation.
val df1 = (1 to 10).map(Tuple1.apply).toDF().select(rand(33))
assert(df1.showString(5) == df1.showString(5))
}
test("SPARK-8609: local DataFrame with random columns should return same value after sort") {
checkAnswer(testData.sort(rand(33)), testData.sort(rand(33)))
// We will reuse the same Expression object for LocalRelation.
val df = (1 to 10).map(Tuple1.apply).toDF()
checkAnswer(df.sort(rand(33)), df.sort(rand(33)))
}
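// rand(seed) draws from an XORShiftRandom seeded with `seed` plus the partition index, so
// with a single partition the sort keys can be reproduced exactly by a local
// XORShiftRandom(seed), which is what the test below relies on.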
test("SPARK-9083: sort with non-deterministic expressions") {
val seed = 33
val df = (1 to 100).map(Tuple1.apply).toDF("i").repartition(1)
val random = new XORShiftRandom(seed)
val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1)
val actual = df.sort(rand(seed)).collect().map(_.getInt(0))
assert(expected === actual)
}
test("Sorting columns are not in Filter and Project") {
checkAnswer(
upperCaseData.filter($"N" > 1).select("N").filter($"N" < 6).orderBy($"L".asc),
Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil)
}
test("SPARK-9323: DataFrame.orderBy should support nested column name") {
val df = spark.read.json(Seq("""{"a": {"b": 1}}""").toDS())
checkAnswer(df.orderBy("a.b"), Row(Row(1)))
}
test("SPARK-9950: correctly analyze grouping/aggregating on struct fields") {
val df = Seq(("x", (1, 1)), ("y", (2, 2))).toDF("a", "b")
checkAnswer(df.groupBy("b._1").agg(sum("b._2")), Row(1, 1) :: Row(2, 2) :: Nil)
}
test("SPARK-10093: Avoid transformations on executors") {
val df = Seq((1, 1)).toDF("a", "b")
df.where($"a" === 1)
.select($"a", $"b", struct($"b"))
.orderBy("a")
.select(struct($"b"))
.collect()
}
test("SPARK-10185: Read multiple Hadoop Filesystem paths and paths with a comma in it") {
withTempDir { dir =>
val df1 = Seq((1, 22)).toDF("a", "b")
val dir1 = new File(dir, "dir,1").getCanonicalPath
df1.write.format("json").save(dir1)
val df2 = Seq((2, 23)).toDF("a", "b")
val dir2 = new File(dir, "dir2").getCanonicalPath
df2.write.format("json").save(dir2)
checkAnswer(spark.read.format("json").load(dir1, dir2),
Row(1, 22) :: Row(2, 23) :: Nil)
checkAnswer(spark.read.format("json").load(dir1),
Row(1, 22) :: Nil)
}
}
test("Alias uses internally generated names 'aggOrder' and 'havingCondition'") {
val df = Seq(1 -> 2).toDF("i", "j")
val query1 = df.groupBy("i")
.agg(max("j").as("aggOrder"))
.orderBy(sum("j"))
checkAnswer(query1, Row(1, 2))
// In the plan, there are two attributes having the same name 'havingCondition':
// one is a user-provided alias name; the other is an internally generated one.
val query2 = df.groupBy("i")
.agg(max("j").as("havingCondition"))
.where(sum("j") > 0)
.orderBy($"havingCondition".asc)
checkAnswer(query2, Row(1, 2))
}
test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") {
withTempDir { dir =>
(1 to 10).toDF("id").write.mode(SaveMode.Overwrite).json(dir.getCanonicalPath)
val input = spark.read.json(dir.getCanonicalPath)
val df = input.select($"id", rand(0).as("r"))
df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row =>
assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001)
}
}
}
test("SPARK-10743: keep the name of expression if possible when do cast") {
val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src")
assert(df.select($"src.i".cast(StringType)).columns.head === "i")
assert(df.select($"src.i".cast(StringType).cast(IntegerType)).columns.head === "i")
}
test("SPARK-11301: fix case sensitivity for filter on partitioned columns") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
withTempPath { path =>
Seq(2012 -> "a").toDF("year", "val").write.partitionBy("year").parquet(path.getAbsolutePath)
val df = spark.read.parquet(path.getAbsolutePath)
checkAnswer(df.filter($"yEAr" > 2000).select($"val"), Row("a"))
}
}
}
/**
* Verifies that there is no Exchange between the Aggregations for `df`
*/
private def verifyNonExchangingAgg(df: DataFrame) = {
var atFirstAgg: Boolean = false
df.queryExecution.executedPlan.foreach {
case agg: HashAggregateExec =>
atFirstAgg = !atFirstAgg
case _ =>
if (atFirstAgg) {
fail("Should not have operators between the two aggregations")
}
}
}
/**
* Verifies that there is an Exchange between the Aggregations for `df`
*/
private def verifyExchangingAgg(df: DataFrame) = {
var atFirstAgg: Boolean = false
df.queryExecution.executedPlan.foreach {
case agg: HashAggregateExec =>
if (atFirstAgg) {
fail("Should not have back to back Aggregates")
}
atFirstAgg = true
case e: ShuffleExchangeExec => atFirstAgg = false
case _ =>
}
}
test("distributeBy and localSort") {
val original = testData.repartition(1)
assert(original.rdd.partitions.length == 1)
val df = original.repartition(5, $"key")
assert(df.rdd.partitions.length == 5)
checkAnswer(original.select(), df.select())
val df2 = original.repartition(10, $"key")
assert(df2.rdd.partitions.length == 10)
checkAnswer(original.select(), df2.select())
// Group by the column we are distributed by. This should generate a plan with no exchange
// between the aggregates
val df3 = testData.repartition($"key").groupBy("key").count()
verifyNonExchangingAgg(df3)
verifyNonExchangingAgg(testData.repartition($"key", $"value")
.groupBy("key", "value").count())
// Grouping by just the first distributeBy expression requires an exchange.
verifyExchangingAgg(testData.repartition($"key", $"value")
.groupBy("key").count())
val data = spark.sparkContext.parallelize(
(1 to 100).map(i => TestData2(i % 10, i))).toDF()
// Distribute and order by.
val df4 = data.repartition(5, $"a").sortWithinPartitions($"b".desc)
// Walk each partition and verify that it is sorted descending and does not contain all
// the values.
df4.rdd.foreachPartition { p =>
// Skip empty partition
if (p.hasNext) {
var previousValue: Int = -1
var allSequential: Boolean = true
p.foreach { r =>
val v: Int = r.getInt(1)
if (previousValue != -1) {
if (previousValue < v) throw new SparkException("Partition is not ordered.")
if (v + 1 != previousValue) allSequential = false
}
previousValue = v
}
if (allSequential) throw new SparkException("Partition should not be globally ordered")
}
}
// Distribute and order by with multiple order bys
val df5 = data.repartition(2, $"a").sortWithinPartitions($"b".asc, $"a".asc)
// Walk each partition and verify that it is sorted ascending
df5.rdd.foreachPartition { p =>
var previousValue: Int = -1
var allSequential: Boolean = true
p.foreach { r =>
val v: Int = r.getInt(1)
if (previousValue != -1) {
if (previousValue > v) throw new SparkException("Partition is not ordered.")
if (v - 1 != previousValue) allSequential = false
}
previousValue = v
}
if (allSequential) throw new SparkException("Partition should not be all sequential")
}
// Distribute into one partition and order by. This partition should contain all the values.
val df6 = data.repartition(1, $"a").sortWithinPartitions("b")
// Walk the single partition and verify that it is sorted ascending and contains all the
// sequential values.
df6.rdd.foreachPartition { p =>
var previousValue: Int = -1
var allSequential: Boolean = true
p.foreach { r =>
val v: Int = r.getInt(1)
if (previousValue != -1) {
if (previousValue > v) throw new SparkException("Partition is not ordered.")
if (v - 1 != previousValue) allSequential = false
}
previousValue = v
}
if (!allSequential) {
throw new SparkException("Partition should contain all sequential values")
}
}
}
test("fix case sensitivity of partition by") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
withTempPath { path =>
val p = path.getAbsolutePath
Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p)
checkAnswer(spark.read.parquet(p).select("YeaR"), Row(2012))
}
}
}
// This test case verifies the fix for a bug hit when making a new instance of LogicalRDD.
test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1)))
val df = spark.createDataFrame(
rdd,
new StructType().add("f1", IntegerType).add("f2", IntegerType))
.select($"F1", $"f2".as("f2"))
val df1 = df.as("a")
val df2 = df.as("b")
checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil)
}
}
test("SPARK-10656: completely support special chars") {
val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.")
checkAnswer(df.select(df("*")), Row(1, "a"))
checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a"))
}
test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
val df = sparkContext.parallelize(Seq(
java.lang.Integer.valueOf(22) -> "John",
null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name")
// Pass null into a UDF that can handle it.
val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
(i: java.lang.Integer) => if (i == null) -10 else null
}
checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil)
spark.udf.register("boxedUDF",
(i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer)
checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil)
val primitiveUDF = udf((i: Int) => i * 2)
checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)
}
test("SPARK-12398 truncated toString") {
val df1 = Seq((1L, "row1")).toDF("id", "name")
assert(df1.toString() === "[id: bigint, name: string]")
val df2 = Seq((1L, "c2", false)).toDF("c1", "c2", "c3")
assert(df2.toString === "[c1: bigint, c2: string ... 1 more field]")
val df3 = Seq((1L, "c2", false, 10)).toDF("c1", "c2", "c3", "c4")
assert(df3.toString === "[c1: bigint, c2: string ... 2 more fields]")
val df4 = Seq((1L, Tuple2(1L, "val"))).toDF("c1", "c2")
assert(df4.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string>]")
val df5 = Seq((1L, Tuple2(1L, "val"), 20.0)).toDF("c1", "c2", "c3")
assert(df5.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string> ... 1 more field]")
val df6 = Seq((1L, Tuple2(1L, "val"), 20.0, 1)).toDF("c1", "c2", "c3", "c4")
assert(df6.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string> ... 2 more fields]")
val df7 = Seq((1L, Tuple3(1L, "val", 2), 20.0, 1)).toDF("c1", "c2", "c3", "c4")
assert(
df7.toString ===
"[c1: bigint, c2: struct<_1: bigint, _2: string ... 1 more field> ... 2 more fields]")
val df8 = Seq((1L, Tuple7(1L, "val", 2, 3, 4, 5, 6), 20.0, 1)).toDF("c1", "c2", "c3", "c4")
assert(
df8.toString ===
"[c1: bigint, c2: struct<_1: bigint, _2: string ... 5 more fields> ... 2 more fields]")
val df9 =
Seq((1L, Tuple4(1L, Tuple4(1L, 2L, 3L, 4L), 2L, 3L), 20.0, 1)).toDF("c1", "c2", "c3", "c4")
assert(
df9.toString ===
"[c1: bigint, c2: struct<_1: bigint," +
" _2: struct<_1: bigint," +
" _2: bigint ... 2 more fields> ... 2 more fields> ... 2 more fields]")
}
test("reuse exchange") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") {
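// With such a low broadcast threshold the self-join shuffles both sides identically: one
// ShuffleExchange is planned and the other side reuses it as a ReusedExchange.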
val df = spark.range(100).toDF()
val join = df.join(df, "id")
val plan = join.queryExecution.executedPlan
checkAnswer(join, df)
assert(
collect(join.queryExecution.executedPlan) {
case e: ShuffleExchangeExec => true }.size === 1)
assert(
collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1)
val broadcasted = broadcast(join)
val join2 = join.join(broadcasted, "id").join(broadcasted, "id")
checkAnswer(join2, df)
assert(
collect(join2.queryExecution.executedPlan) {
case e: ShuffleExchangeExec => true }.size == 1)
assert(
collect(join2.queryExecution.executedPlan) {
case e: BroadcastExchangeExec => true }.size === 1)
assert(
collect(join2.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size == 4)
}
}
test("sameResult() on aggregate") {
val df = spark.range(100)
val agg1 = df.groupBy().count()
val agg2 = df.groupBy().count()
// Two aggregates with different ExprIds within them should have the same result.
assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan))
val agg3 = df.groupBy().sum()
assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan))
val df2 = spark.range(101)
val agg4 = df2.groupBy().count()
assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan))
}
test("SPARK-12512: support `.` in column name for withColumn()") {
val df = Seq("a" -> "b").toDF("col.a", "col.b")
checkAnswer(df.select(df("*")), Row("a", "b"))
checkAnswer(df.withColumn("col.a", lit("c")), Row("c", "b"))
checkAnswer(df.withColumn("col.c", lit("c")), Row("a", "b", "c"))
}
test("SPARK-12841: cast in filter") {
checkAnswer(
Seq(1 -> "a").toDF("i", "j").filter($"i".cast(StringType) === "1"),
Row(1, "a"))
}
test("SPARK-12982: Add table name validation in temp table registration") {
val df = Seq("foo", "bar").map(Tuple1.apply).toDF("col")
// invalid table names
Seq("11111", "t~", "#$@sum", "table!#").foreach { name =>
withTempView(name) {
val m = intercept[AnalysisException](df.createOrReplaceTempView(name)).getMessage
assert(m.contains(s"Invalid view name: $name"))
}
}
// valid table names
Seq("table1", "`11111`", "`t~`", "`#$@sum`", "`table!#`").foreach { name =>
withTempView(name) {
df.createOrReplaceTempView(name)
}
}
}
test("assertAnalyzed shouldn't replace original stack trace") {
val e = intercept[AnalysisException] {
spark.range(1).select($"id" as "a", $"id" as "b").groupBy("a").agg($"b")
}
assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName)
}
test("SPARK-13774: Check error message for non existent path without globbed paths") {
val uuid = UUID.randomUUID().toString
val baseDir = Utils.createTempDir()
try {
val e = intercept[AnalysisException] {
spark.read.format("csv").load(
new File(baseDir, "file").getAbsolutePath,
new File(baseDir, "file2").getAbsolutePath,
new File(uuid, "file3").getAbsolutePath,
uuid).rdd
}
assert(e.getMessage.startsWith("Path does not exist"))
} finally {
// Clean up the temp dir (mirrors the globbed-path variant of this test below).
Utils.deleteRecursively(baseDir)
}
}
test("SPARK-13774: Check error message for not existent globbed paths") {
// Non-existent initial path component:
val nonExistentBasePath = "/" + UUID.randomUUID().toString
assert(!new File(nonExistentBasePath).exists())
val e = intercept[AnalysisException] {
spark.read.format("text").load(s"$nonExistentBasePath/*")
}
assert(e.getMessage.startsWith("Path does not exist"))
// Existing initial path component, but no matching files:
val baseDir = Utils.createTempDir()
val childDir = Utils.createTempDir(baseDir.getAbsolutePath)
assert(childDir.exists())
try {
val e1 = intercept[AnalysisException] {
spark.read.json(s"${baseDir.getAbsolutePath}/*/*-xyz.json").rdd
}
assert(e1.getMessage.startsWith("Path does not exist"))
} finally {
Utils.deleteRecursively(baseDir)
}
}
test("SPARK-15230: distinct() does not handle column name with dot properly") {
val df = Seq(1, 1, 2).toDF("column.with.dot")
checkAnswer(df.distinct(), Row(1) :: Row(2) :: Nil)
}
test("SPARK-16181: outer join with isNull filter") {
val left = Seq("x").toDF("col")
val right = Seq("y").toDF("col").withColumn("new", lit(true))
val joined = left.join(right, left("col") === right("col"), "left_outer")
checkAnswer(joined, Row("x", null, null))
checkAnswer(joined.filter($"new".isNull), Row("x", null, null))
}
test("SPARK-16664: persist with more than 200 columns") {
val size = 201L
val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(Seq.range(0, size))))
val schemas = List.range(0, size).map(a => StructField("name" + a, LongType, true))
val df = spark.createDataFrame(rdd, StructType(schemas))
assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
}
test("SPARK-17409: Do Not Optimize Query in CTAS (Data source tables) More Than Once") {
withTable("bar") {
withTempView("foo") {
withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") {
sql("select 0 as id").createOrReplaceTempView("foo")
val df = sql("select * from foo group by id")
// If we optimize the query in CTAS more than once, the following saveAsTable will fail
// with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])`
df.write.mode("overwrite").saveAsTable("bar")
checkAnswer(spark.table("bar"), Row(0) :: Nil)
val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar"))
assert(tableMetadata.provider == Some("json"),
"the expected table is a data source table using json")
}
}
}
}
test("copy results for sampling with replacement") {
val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
val sampleDf = df.sample(true, 2.00)
val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect
assert(d.size == d.distinct.size)
}
private def verifyNullabilityInFilterExec(
df: DataFrame,
expr: String,
expectedNonNullableColumns: Seq[String]): Unit = {
val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
dfWithFilter.queryExecution.executedPlan.collect {
// When the child expression in isnotnull is null-intolerant (i.e. any null input will
// result in null output), the involved columns are converted to not nullable;
// otherwise, no change should be made.
case e: FilterExec =>
assert(e.output.forall { o =>
if (expectedNonNullableColumns.contains(o.name)) !o.nullable else o.nullable
})
}
}
test("SPARK-17957: no change on nullability in FilterExec output") {
val df = sparkContext.parallelize(Seq(
null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3),
java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer],
java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF()
verifyNullabilityInFilterExec(df,
expr = "Rand()", expectedNonNullableColumns = Seq.empty[String])
verifyNullabilityInFilterExec(df,
expr = "coalesce(_1, _2)", expectedNonNullableColumns = Seq.empty[String])
verifyNullabilityInFilterExec(df,
expr = "coalesce(_1, 0) + Rand()", expectedNonNullableColumns = Seq.empty[String])
verifyNullabilityInFilterExec(df,
expr = "cast(coalesce(cast(coalesce(_1, _2) as double), 0.0) as int)",
expectedNonNullableColumns = Seq.empty[String])
}
test("SPARK-17957: set nullability to false in FilterExec output") {
val df = sparkContext.parallelize(Seq(
null.asInstanceOf[java.lang.Integer] -> java.lang.Integer.valueOf(3),
java.lang.Integer.valueOf(1) -> null.asInstanceOf[java.lang.Integer],
java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF()
verifyNullabilityInFilterExec(df,
expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2"))
verifyNullabilityInFilterExec(df,
expr = "_1 + _2", expectedNonNullableColumns = Seq("_1", "_2"))
verifyNullabilityInFilterExec(df,
expr = "_1", expectedNonNullableColumns = Seq("_1"))
// `constructIsNotNullConstraints` infers the IsNotNull(_2) from IsNotNull(_2 + Rand())
// Thus, we are able to set nullability of _2 to false.
// If IsNotNull(_2) is not given from `constructIsNotNullConstraints`, the impl of
// isNullIntolerant in `FilterExec` needs an update for more advanced inference.
verifyNullabilityInFilterExec(df,
expr = "_2 + Rand()", expectedNonNullableColumns = Seq("_2"))
verifyNullabilityInFilterExec(df,
expr = "_2 * 3 + coalesce(_1, 0)", expectedNonNullableColumns = Seq("_2"))
verifyNullabilityInFilterExec(df,
expr = "cast((_1 + _2) as boolean)", expectedNonNullableColumns = Seq("_1", "_2"))
}
test("SPARK-17897: Fixed IsNotNull Constraint Inference Rule") {
val data = Seq[java.lang.Integer](1, null).toDF("key")
checkAnswer(data.filter(!$"key".isNotNull), Row(null))
checkAnswer(data.filter(!(- $"key").isNotNull), Row(null))
}
test("SPARK-17957: outer join + na.fill") {
withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") {
val df1 = Seq((1, 2), (2, 3)).toDF("a", "b")
val df2 = Seq((2, 5), (3, 4)).toDF("a", "c")
val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0)
val df3 = Seq((3, 1)).toDF("a", "d")
checkAnswer(joinedDf.join(df3, "a"), Row(3, 0, 4, 1))
}
}
test("SPARK-18070 binary operator should not consider nullability when comparing input types") {
val rows = Seq(Row(Seq(1), Seq(1)))
val schema = new StructType()
.add("array1", ArrayType(IntegerType))
.add("array2", ArrayType(IntegerType, containsNull = false))
val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
assert(df.filter($"array1" === $"array2").count() == 1)
}
test("SPARK-17913: compare long and string type column may return confusing result") {
val df = Seq(123L -> "123", 19157170390056973L -> "19157170390056971").toDF("i", "j")
checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil)
}
test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
checkAnswer(df, Row(BigDecimal(0)) :: Nil)
}
test("SPARK-20359: catalyst outer join optimization should not throw npe") {
val df1 = Seq("a", "b", "c").toDF("x")
.withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" }.apply($"x"))
val df2 = Seq("a", "b").toDF("x1")
df1
.join(df2, df1("x") === df2("x1"), "left_outer")
.filter($"x1".isNotNull || !$"y".isin("a!"))
.count
}
// The fix for SPARK-21720 avoids an exception regarding the JVM code size limit.
// TODO: Once the statement-splitting threshold (1024) is made configurable, re-enable this test
// with the maximum threshold so that the exception is triggered again.
// See https://github.com/apache/spark/pull/18972/files#r150223463
ignore("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") {
val N = 400
val rows = Seq(Row.fromSeq(Seq.fill(N)("string")))
val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType)))
val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
val filter = (0 until N)
.foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string"))
withSQLConf(SQLConf.CODEGEN_FALLBACK.key -> "true") {
df.filter(filter).count()
}
withSQLConf(SQLConf.CODEGEN_FALLBACK.key -> "false") {
val e = intercept[SparkException] {
df.filter(filter).count()
}.getMessage
assert(e.contains("grows beyond 64 KiB"))
}
}
test("SPARK-20897: cached self-join should not fail") {
// Force a sort merge join to be planned.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
val df = Seq(1 -> "a").toDF("i", "j")
val df1 = df.as("t1")
val df2 = df.as("t2")
assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1)
}
}
test("order-by ordinal.") {
checkAnswer(
testData2.select(lit(7), $"a", $"b").orderBy(lit(1), lit(2), lit(3)),
Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
}
test("SPARK-22271: mean overflows and returns null for some decimal variables") {
val d = 0.034567890
val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol")
val result = df.select($"DecimalCol" cast DecimalType(38, 33))
.select(col("DecimalCol")).describe()
val mean = result.select("DecimalCol").where($"summary" === "mean")
assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
}
test("SPARK-22520: support code generation for large CaseWhen") {
val N = 30
var expr1 = when($"id" === lit(0), 0)
var expr2 = when($"id" === lit(0), 10)
(1 to N).foreach { i =>
expr1 = expr1.when($"id" === lit(i), -i)
expr2 = expr2.when($"id" === lit(i + 10), i)
}
val df = spark.range(1).select(expr1, expr2.otherwise(0))
checkAnswer(df, Row(0, 10) :: Nil)
assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
}
test("SPARK-24165: CaseWhen/If - nullability of nested types") {
val rows = new java.util.ArrayList[Row]()
rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x")))
rows.add(Row(false, (null, 2), Seq(null, "z"), Map(0 -> null)))
val schema = StructType(Seq(
StructField("cond", BooleanType, true),
StructField("s", StructType(Seq(
StructField("val1", StringType, true),
StructField("val2", IntegerType, false)
)), false),
StructField("a", ArrayType(StringType, true)),
StructField("m", MapType(IntegerType, StringType, true))
))
val sourceDF = spark.createDataFrame(rows, schema)
def structWhenDF: DataFrame = sourceDF
.select(when($"cond",
struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise($"s") as "res")
.select($"res".getField("val1"))
def arrayWhenDF: DataFrame = sourceDF
.select(when($"cond", array(lit("a"), lit("b"))).otherwise($"a") as "res")
.select($"res".getItem(0))
def mapWhenDF: DataFrame = sourceDF
.select(when($"cond", map(lit(0), lit("a"))).otherwise($"m") as "res")
.select($"res".getItem(0))
def structIfDF: DataFrame = sourceDF
.select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res")
.select($"res".getField("val1"))
def arrayIfDF: DataFrame = sourceDF
.select(expr("if(cond, array('a', 'b'), a)") as "res")
.select($"res".getItem(0))
def mapIfDF: DataFrame = sourceDF
.select(expr("if(cond, map(0, 'a'), m)") as "res")
.select($"res".getItem(0))
def checkResult(): Unit = {
checkAnswer(structWhenDF, Seq(Row("a"), Row(null)))
checkAnswer(arrayWhenDF, Seq(Row("a"), Row(null)))
checkAnswer(mapWhenDF, Seq(Row("a"), Row(null)))
checkAnswer(structIfDF, Seq(Row("a"), Row(null)))
checkAnswer(arrayIfDF, Seq(Row("a"), Row(null)))
checkAnswer(mapIfDF, Seq(Row("a"), Row(null)))
}
// Test with a local relation; the Project is evaluated without codegen.
checkResult()
// Test with a cached relation; the Project is evaluated with codegen.
sourceDF.cache()
checkResult()
}
test("Uuid expressions should produce same results at retries in the same DataFrame") {
val df = spark.range(1).select($"id", new Column(Uuid()))
checkAnswer(df, df.collect())
}
test("SPARK-24313: access map with binary keys") {
val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
}
test("SPARK-24781: Using a reference from Dataset in Filter/Sort") {
val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
val filter1 = df.select(df("name")).filter(df("id") === 0)
val filter2 = df.select(col("name")).filter(col("id") === 0)
checkAnswer(filter1, filter2.collect())
val sort1 = df.select(df("name")).orderBy(df("id"))
val sort2 = df.select(col("name")).orderBy(col("id"))
checkAnswer(sort1, sort2.collect())
}
test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") {
withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name"))
val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name"))
checkAnswer(aggPlusSort1, aggPlusSort2.collect())
val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0)
val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0)
checkAnswer(aggPlusFilter1, aggPlusFilter2.collect())
}
}
test("SPARK-34806: observation on datasets") {
val namedObservation = Observation("named")
val unnamedObservation = Observation()
val df = spark
.range(100)
.observe(
namedObservation,
min($"id").as("min_val"),
max($"id").as("max_val"),
sum($"id").as("sum_val"),
count(when($"id" % 2 === 0, 1)).as("num_even")
)
.observe(
unnamedObservation,
avg($"id").cast("int").as("avg_val")
)
def checkMetrics(namedMetric: Row, unnamedMetric: Row): Unit = {
assert(namedMetric === Row(0L, 99L, 4950L, 50L))
assert(unnamedMetric === Row(49))
}
df.collect()
// we can get the result multiple times
checkMetrics(namedObservation.get, unnamedObservation.get)
checkMetrics(namedObservation.get, unnamedObservation.get)
// an observation can be used only once
val err = intercept[IllegalArgumentException] {
spark.range(100).observe(namedObservation, sum($"id").as("sum_val"))
}
assert(err.getMessage.contains("An Observation can be used with a Dataset only once"))
// streaming datasets are not supported
val streamDf = new MemoryStream[Int](0, sqlContext).toDF()
val streamObservation = Observation("stream")
val streamErr = intercept[IllegalArgumentException] {
streamDf.observe(streamObservation, avg($"value").cast("int").as("avg_val"))
}
assert(streamErr.getMessage.contains("Observation does not support streaming Datasets"))
}
test("SPARK-25159: json schema inference should only trigger one job") {
withTempPath { path =>
// This test is to prove that the `JsonInferSchema` does not use `RDD#toLocalIterator` which
// triggers one Spark job per RDD partition.
Seq(1 -> "a", 2 -> "b").toDF("i", "p")
// The data set has 2 partitions, so Spark will write at least 2 json files.
// Use a non-splittable compression (gzip), to make sure the json scan RDD has at least 2
// partitions.
.write.partitionBy("p").option("compression", "gzip").json(path.getCanonicalPath)
val numJobs = new AtomicLong(0)
sparkContext.addSparkListener(new SparkListener {
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
numJobs.incrementAndGet()
}
})
val df = spark.read.json(path.getCanonicalPath)
assert(df.columns === Array("i", "p"))
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(numJobs.get() == 1L)
}
}
test("SPARK-25402 Null handling in BooleanSimplification") {
val schema = StructType.fromDDL("a boolean, b int")
val rows = Seq(Row(null, 1))
val rdd = sparkContext.parallelize(rows)
val df = spark.createDataFrame(rdd, schema)
checkAnswer(df.where("(NOT a) OR a"), Seq.empty)
}
test("SPARK-25714 Null handling in BooleanSimplification") {
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ConvertToLocalRelation.ruleName) {
val df = Seq(("abc", 1), (null, 3)).toDF("col1", "col2")
checkAnswer(
df.filter("col1 = 'abc' OR (col1 != 'abc' AND col2 == 3)"),
Row ("abc", 1))
}
}
test("SPARK-25816 ResolveReferences works with nested extractors") {
val df = Seq((1, Map(1 -> "a")), (2, Map(2 -> "b"))).toDF("key", "map")
val swappedDf = df.select($"key".as("map"), $"map".as("key"))
checkAnswer(swappedDf.filter($"key"($"map") > "a"), Row(2, Map(2 -> "b")))
}
test("SPARK-26057: attribute deduplication on already analyzed plans") {
withTempView("a", "b", "v") {
val df1 = Seq(("1-1", 6)).toDF("id", "n")
df1.createOrReplaceTempView("a")
val df3 = Seq("1-1").toDF("id")
df3.createOrReplaceTempView("b")
spark.sql(
"""
|SELECT a.id, n as m
|FROM a
|WHERE EXISTS(
| SELECT 1
| FROM b
| WHERE b.id = a.id)
""".stripMargin).createOrReplaceTempView("v")
val res = spark.sql(
"""
|SELECT a.id, n, m
| FROM a
| LEFT OUTER JOIN v ON v.id = a.id
""".stripMargin)
checkAnswer(res, Row("1-1", 6, 6))
}
}
test("SPARK-27671: Fix analysis exception when casting null in nested field in struct") {
val df = sql("SELECT * FROM VALUES (('a', (10, null))), (('b', (10, 50))), " +
"(('c', null)) AS tab(x, y)")
checkAnswer(df, Row("a", Row(10, null)) :: Row("b", Row(10, 50)) :: Row("c", null) :: Nil)
val cast = sql("SELECT cast(struct(1, null) AS struct<a:int,b:int>)")
checkAnswer(cast, Row(Row(1, null)) :: Nil)
}
test("SPARK-27439: Explain result should match collected result after view change") {
withTempView("test", "test2", "tmp") {
spark.range(10).createOrReplaceTempView("test")
spark.range(5).createOrReplaceTempView("test2")
spark.sql("select * from test").createOrReplaceTempView("tmp")
val df = spark.sql("select * from tmp")
spark.sql("select * from test2").createOrReplaceTempView("tmp")
val captured = new ByteArrayOutputStream()
Console.withOut(captured) {
df.explain(extended = true)
}
checkAnswer(df, spark.range(10).toDF)
val output = captured.toString
assert(output.contains(
"""== Parsed Logical Plan ==
|'Project [*]
|+- 'UnresolvedRelation [tmp]""".stripMargin))
assert(output.contains(
"""== Physical Plan ==
|*(1) Range (0, 10, step=1, splits=2)""".stripMargin))
}
}
test("SPARK-29442 Set `default` mode should override the existing mode") {
val df = Seq(Tuple1(1)).toDF()
val writer = df.write.mode("overwrite").mode("default")
val modeField = classOf[DataFrameWriter[Tuple1[Int]]].getDeclaredField("mode")
modeField.setAccessible(true)
assert(SaveMode.ErrorIfExists === modeField.get(writer).asInstanceOf[SaveMode])
}
test("sample should not duplicated the input data") {
val df1 = spark.range(10).select($"id" as "id1", $"id" % 5 as "key1")
val df2 = spark.range(10).select($"id" as "id2", $"id" % 5 as "key2")
val sampled = df1.join(df2, $"key1" === $"key2")
.sample(0.5, 42)
.select("id1", "id2")
val idTuples = sampled.collect().map(row => row.getLong(0) -> row.getLong(1))
assert(idTuples.length == idTuples.toSet.size)
}
test("groupBy.as") {
val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c")
.repartition($"a", $"b").sortWithinPartitions("a", "b")
val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c")
.repartition($"a", $"b").sortWithinPartitions("a", "b")
implicit val valueEncoder = RowEncoder(df1.schema)
val df3 = df1.groupBy("a", "b").as[GroupByKey, Row]
.cogroup(df2.groupBy("a", "b").as[GroupByKey, Row]) { case (_, data1, data2) =>
data1.zip(data2).map { p =>
p._1.getInt(2) + p._2.getInt(2)
}
}.toDF
checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil)
// Assert that no extra shuffle is introduced by cogroup.
val exchanges = collect(df3.queryExecution.executedPlan) {
case h: ShuffleExchangeExec => h
}
assert(exchanges.size == 2)
}
test("groupBy.as: custom grouping expressions") {
val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a1", "b", "c")
.repartition($"a1", $"b").sortWithinPartitions("a1", "b")
val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a1", "b", "c")
.repartition($"a1", $"b").sortWithinPartitions("a1", "b")
implicit val valueEncoder = RowEncoder(df1.schema)
val groupedDataset1 = df1.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row]
val groupedDataset2 = df2.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row]
val df3 = groupedDataset1
.cogroup(groupedDataset2) { case (_, data1, data2) =>
data1.zip(data2).map { p =>
p._1.getInt(2) + p._2.getInt(2)
}
}.toDF
checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil)
}
test("groupBy.as: throw AnalysisException for unresolved grouping expr") {
val df = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c")
implicit val valueEncoder = RowEncoder(df.schema)
val err = intercept[AnalysisException] {
df.groupBy($"d", $"b").as[GroupByKey, Row]
}
assert(err.getMessage.contains("cannot resolve 'd'"))
}
test("emptyDataFrame should be foldable") {
val emptyDf = spark.emptyDataFrame.withColumn("id", lit(1L))
val joined = spark.range(10).join(emptyDf, "id")
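// Joining against an empty DataFrame produces no rows, so the optimizer should fold the whole
// join into an empty LocalRelation that still exposes the `id` attribute.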
joined.queryExecution.optimizedPlan match {
case LocalRelation(Seq(id), Nil, _) =>
assert(id.name == "id")
case _ =>
fail("emptyDataFrame should be foldable")
}
}
test("SPARK-30811: CTE should not cause stack overflow when " +
"it refers to non-existent table with same name") {
val e = intercept[AnalysisException] {
sql("WITH t AS (SELECT 1 FROM nonexist.t) SELECT * FROM t")
}
assert(e.getMessage.contains("Table or view not found:"))
}
test("SPARK-32680: Don't analyze CTAS with unresolved query") {
val v2Source = classOf[FakeV2Provider].getName
val e = intercept[AnalysisException] {
sql(s"CREATE TABLE t USING $v2Source AS SELECT * from nonexist")
}
assert(e.getMessage.contains("Table or view not found:"))
}
test("CalendarInterval reflection support") {
val df = Seq((1, new CalendarInterval(1, 2, 3))).toDF("a", "b")
checkAnswer(df.selectExpr("b"), Row(new CalendarInterval(1, 2, 3)))
}
test("SPARK-31552: array encoder with different types") {
// primitives
val booleans = Array(true, false)
checkAnswer(Seq(booleans).toDF(), Row(booleans))
val bytes = Array(1.toByte, 2.toByte)
checkAnswer(Seq(bytes).toDF(), Row(bytes))
val shorts = Array(1.toShort, 2.toShort)
checkAnswer(Seq(shorts).toDF(), Row(shorts))
val ints = Array(1, 2)
checkAnswer(Seq(ints).toDF(), Row(ints))
val longs = Array(1L, 2L)
checkAnswer(Seq(longs).toDF(), Row(longs))
val floats = Array(1.0F, 2.0F)
checkAnswer(Seq(floats).toDF(), Row(floats))
val doubles = Array(1.0D, 2.0D)
checkAnswer(Seq(doubles).toDF(), Row(doubles))
val strings = Array("2020-04-24", "2020-04-25")
checkAnswer(Seq(strings).toDF(), Row(strings))
// tuples
val decOne = Decimal(1, 38, 18)
val decTwo = Decimal(2, 38, 18)
val tuple1 = (1, 2.2, "3.33", decOne, Date.valueOf("2012-11-22"))
val tuple2 = (2, 3.3, "4.44", decTwo, Date.valueOf("2022-11-22"))
checkAnswer(Seq(Array(tuple1, tuple2)).toDF(), Seq(Seq(tuple1, tuple2)).toDF())
// case classes
val gbks = Array(GroupByKey(1, 2), GroupByKey(4, 5))
checkAnswer(Seq(gbks).toDF(), Row(Array(Row(1, 2), Row(4, 5))))
// We can move this implicit def to [[SQLImplicits]] once the array encoder is fully supported,
// as it already is for Seq and Set.
// For now, the cases below (decimals, datetimes, intervals, binary and nested types, etc.)
// are not covered by the built-in array encoder.
implicit def newArrayEncoder[T <: Array[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
// decimals
val decSpark = Array(decOne, decTwo)
val decScala = decSpark.map(_.toBigDecimal)
val decJava = decSpark.map(_.toJavaBigDecimal)
checkAnswer(Seq(decSpark).toDF(), Row(decJava))
checkAnswer(Seq(decScala).toDF(), Row(decJava))
checkAnswer(Seq(decJava).toDF(), Row(decJava))
// datetimes and intervals
val dates = strings.map(Date.valueOf)
checkAnswer(Seq(dates).toDF(), Row(dates))
val localDates = dates.map(d => DateTimeUtils.daysToLocalDate(DateTimeUtils.fromJavaDate(d)))
checkAnswer(Seq(localDates).toDF(), Row(dates))
val timestamps =
Array(Timestamp.valueOf("2020-04-24 12:34:56"), Timestamp.valueOf("2020-04-24 11:22:33"))
checkAnswer(Seq(timestamps).toDF(), Row(timestamps))
val instants =
timestamps.map(t => DateTimeUtils.microsToInstant(DateTimeUtils.fromJavaTimestamp(t)))
checkAnswer(Seq(instants).toDF(), Row(timestamps))
val intervals = Array(new CalendarInterval(1, 2, 3), new CalendarInterval(4, 5, 6))
checkAnswer(Seq(intervals).toDF(), Row(intervals))
// binary
val bins = Array(Array(1.toByte), Array(2.toByte), Array(3.toByte), Array(4.toByte))
checkAnswer(Seq(bins).toDF(), Row(bins))
// nested
val nestedIntArray = Array(Array(1), Array(2))
checkAnswer(Seq(nestedIntArray).toDF(), Row(nestedIntArray.map(wrapIntArray)))
val nestedDecArray = Array(decSpark)
checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
}
test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") {
withTempPath { f =>
sql("select cast(1 as decimal(38, 0)) as d")
.write.mode("overwrite")
.parquet(f.getAbsolutePath)
val df = spark.read.parquet(f.getAbsolutePath).as[BigDecimal]
assert(df.schema === new StructType().add(StructField("d", DecimalType(38, 0))))
}
}
test("SPARK-32640: ln(NaN) should return NaN") {
val df = Seq(Double.NaN).toDF("d")
checkAnswer(df.selectExpr("ln(d)"), Row(Double.NaN))
}
test("SPARK-32761: aggregating multiple distinct CONSTANT columns") {
checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1, 1))
}
test("SPARK-32764: -0.0 and 0.0 should be equal") {
val df = Seq(0.0 -> -0.0).toDF("pos", "neg")
checkAnswer(df.select($"pos" > $"neg"), Row(false))
}
test("SPARK-32635: Replace references with foldables coming only from the node's children") {
val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
val aub = a.union(b)
val c = aub.filter($"col1" === "2").cache()
val d = Seq("2").toDF("col4")
val r = d.join(aub, $"col2" === $"col4").select("col4")
val l = c.select("col2")
val df = l.join(r, $"col2" === $"col4", "LeftOuter")
checkAnswer(df, Row("2", "2"))
}
test("SPARK-33939: Make Column.named use UnresolvedAlias to assign name") {
val df = spark.range(1).selectExpr("id as id1", "id as id2")
val df1 = df.selectExpr("cast(struct(id1, id2).id1 as int)")
assert(df1.schema.head.name == "CAST(struct(id1, id2).id1 AS INT)")
val df2 = df.selectExpr("cast(array(struct(id1, id2))[0].id1 as int)")
assert(df2.schema.head.name == "CAST(array(struct(id1, id2))[0].id1 AS INT)")
val df3 = df.select(hex(expr("struct(id1, id2).id1")))
assert(df3.schema.head.name == "hex(struct(id1, id2).id1)")
// This check makes sure we don't regress here.
val df4 = df.selectExpr("id1 == null")
assert(df4.schema.head.name == "(id1 = NULL)")
}
test("SPARK-33989: Strip auto-generated cast when using Cast.sql") {
Seq("SELECT id == null FROM VALUES(1) AS t(id)",
"SELECT floor(1)",
"SELECT split(struct(c1, c2).c1, ',') FROM VALUES(1, 2) AS t(c1, c2)").foreach { sqlStr =>
assert(!sql(sqlStr).schema.fieldNames.head.toLowerCase(Locale.getDefault).contains("cast"))
}
Seq("SELECT id == CAST(null AS int) FROM VALUES(1) AS t(id)",
"SELECT floor(CAST(1 AS double))",
"SELECT split(CAST(struct(c1, c2).c1 AS string), ',') FROM VALUES(1, 2) AS t(c1, c2)"
).foreach { sqlStr =>
assert(sql(sqlStr).schema.fieldNames.head.toLowerCase(Locale.getDefault).contains("cast"))
}
}
test("SPARK-34318: colRegex should work with column names & qualifiers which contain newlines") {
val df = Seq(1, 2, 3).toDF("test\n_column").as("test\n_table")
val col1 = df.colRegex("`tes.*\n.*mn`")
checkAnswer(df.select(col1), Row(1) :: Row(2) :: Row(3) :: Nil)
val col2 = df.colRegex("test\n_table.`tes.*\n.*mn`")
checkAnswer(df.select(col2), Row(1) :: Row(2) :: Row(3) :: Nil)
}
test("SPARK-34763: col(), $\"<name>\", df(\"name\") should handle quoted column name properly") {
val df1 = spark.sql("SELECT 'col1' AS `a``b.c`")
checkAnswer(df1.selectExpr("`a``b.c`"), Row("col1"))
checkAnswer(df1.select(df1("`a``b.c`")), Row("col1"))
checkAnswer(df1.select(col("`a``b.c`")), Row("col1"))
checkAnswer(df1.select($"`a``b.c`"), Row("col1"))
val df2 = df1.as("d.e`f")
checkAnswer(df2.selectExpr("`a``b.c`"), Row("col1"))
checkAnswer(df2.select(df2("`a``b.c`")), Row("col1"))
checkAnswer(df2.select(col("`a``b.c`")), Row("col1"))
checkAnswer(df2.select($"`a``b.c`"), Row("col1"))
checkAnswer(df2.selectExpr("`d.e``f`.`a``b.c`"), Row("col1"))
checkAnswer(df2.select(df2("`d.e``f`.`a``b.c`")), Row("col1"))
checkAnswer(df2.select(col("`d.e``f`.`a``b.c`")), Row("col1"))
checkAnswer(df2.select($"`d.e``f`.`a``b.c`"), Row("col1"))
val df3 = df1.as("*-#&% ?")
checkAnswer(df3.selectExpr("`*-#&% ?`.`a``b.c`"), Row("col1"))
checkAnswer(df3.select(df3("*-#&% ?.`a``b.c`")), Row("col1"))
checkAnswer(df3.select(col("*-#&% ?.`a``b.c`")), Row("col1"))
checkAnswer(df3.select($"*-#&% ?.`a``b.c`"), Row("col1"))
}
test("SPARK-34776: Nested column pruning should not prune Window produced attributes") {
val df = Seq(
("t1", "123", "bob"),
("t1", "456", "bob"),
("t2", "123", "sam")
).toDF("type", "value", "name")
val test = df.select(
$"*",
struct(count($"*").over(Window.partitionBy($"type", $"value", $"name"))
.as("count"), $"name").as("name_count")
).select(
$"*",
max($"name_count").over(Window.partitionBy($"type", $"value")).as("best_name")
)
checkAnswer(test.select($"best_name.name"), Row("bob") :: Row("bob") :: Row("sam") :: Nil)
}
test("SPARK-34829: Multiple applications of typed ScalaUDFs in higher order functions work") {
val reverse = udf((s: String) => s.reverse)
val reverse2 = udf((b: Bar2) => Bar2(b.s.reverse))
val df = Seq(Array("abc", "def")).toDF("array")
val test = df.select(transform(col("array"), s => reverse(s)))
checkAnswer(test, Row(Array("cba", "fed")) :: Nil)
val df2 = Seq(Array(Bar2("abc"), Bar2("def"))).toDF("array")
val test2 = df2.select(transform(col("array"), b => reverse2(b)))
checkAnswer(test2, Row(Array(Row("cba"), Row("fed"))) :: Nil)
val df3 = Seq(Map("abc" -> 1, "def" -> 2)).toDF("map")
val test3 = df3.select(transform_keys(col("map"), (s, _) => reverse(s)))
checkAnswer(test3, Row(Map("cba" -> 1, "fed" -> 2)) :: Nil)
val df4 = Seq(Map(Bar2("abc") -> 1, Bar2("def") -> 2)).toDF("map")
val test4 = df4.select(transform_keys(col("map"), (b, _) => reverse2(b)))
checkAnswer(test4, Row(Map(Row("cba") -> 1, Row("fed") -> 2)) :: Nil)
val df5 = Seq(Map(1 -> "abc", 2 -> "def")).toDF("map")
val test5 = df5.select(transform_values(col("map"), (_, s) => reverse(s)))
checkAnswer(test5, Row(Map(1 -> "cba", 2 -> "fed")) :: Nil)
val df6 = Seq(Map(1 -> Bar2("abc"), 2 -> Bar2("def"))).toDF("map")
val test6 = df6.select(transform_values(col("map"), (_, b) => reverse2(b)))
checkAnswer(test6, Row(Map(1 -> Row("cba"), 2 -> Row("fed"))) :: Nil)
val reverseThenConcat = udf((s1: String, s2: String) => s1.reverse ++ s2.reverse)
val reverseThenConcat2 = udf((b1: Bar2, b2: Bar2) => Bar2(b1.s.reverse ++ b2.s.reverse))
val df7 = Seq((Map(1 -> "abc", 2 -> "def"), Map(1 -> "ghi", 2 -> "jkl"))).toDF("map1", "map2")
val test7 =
df7.select(map_zip_with(col("map1"), col("map2"), (_, s1, s2) => reverseThenConcat(s1, s2)))
checkAnswer(test7, Row(Map(1 -> "cbaihg", 2 -> "fedlkj")) :: Nil)
val df8 = Seq((Map(1 -> Bar2("abc"), 2 -> Bar2("def")),
Map(1 -> Bar2("ghi"), 2 -> Bar2("jkl")))).toDF("map1", "map2")
val test8 =
df8.select(map_zip_with(col("map1"), col("map2"), (_, b1, b2) => reverseThenConcat2(b1, b2)))
checkAnswer(test8, Row(Map(1 -> Row("cbaihg"), 2 -> Row("fedlkj"))) :: Nil)
val df9 = Seq((Array("abc", "def"), Array("ghi", "jkl"))).toDF("array1", "array2")
val test9 =
df9.select(zip_with(col("array1"), col("array2"), (s1, s2) => reverseThenConcat(s1, s2)))
checkAnswer(test9, Row(Array("cbaihg", "fedlkj")) :: Nil)
val df10 = Seq((Array(Bar2("abc"), Bar2("def")), Array(Bar2("ghi"), Bar2("jkl"))))
.toDF("array1", "array2")
val test10 =
df10.select(zip_with(col("array1"), col("array2"), (b1, b2) => reverseThenConcat2(b1, b2)))
checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
}
test("SPARK-34882: Aggregate with multiple distinct null sensitive aggregators") {
withUserDefinedFunction(("countNulls", true)) {
spark.udf.register("countNulls", udaf(new Aggregator[JLong, JLong, JLong] {
def zero: JLong = 0L
def reduce(b: JLong, a: JLong): JLong = if (a == null) {
b + 1
} else {
b
}
def merge(b1: JLong, b2: JLong): JLong = b1 + b2
def finish(r: JLong): JLong = r
def bufferEncoder: Encoder[JLong] = Encoders.LONG
def outputEncoder: Encoder[JLong] = Encoders.LONG
}))
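// testData has keys 1..100 and no nulls, so every countNulls variant returns 0 while
// count(DISTINCT key) sees all 100 keys.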
val result = testData.selectExpr(
"countNulls(key)",
"countNulls(DISTINCT key)",
"countNulls(key) FILTER (WHERE key > 50)",
"countNulls(DISTINCT key) FILTER (WHERE key > 50)",
"count(DISTINCT key)")
checkAnswer(result, Row(0, 0, 0, 0, 100))
}
}
test("SPARK-35410: SubExpr elimination should not include redundant child exprs " +
"for conditional expressions") {
val accum = sparkContext.longAccumulator("call")
val simpleUDF = udf((s: String) => {
accum.add(1)
s
})
val df1 = spark.range(5).select(when(functions.length(simpleUDF($"id")) > 0,
functions.length(simpleUDF($"id"))).otherwise(
functions.length(simpleUDF($"id")) + 1))
df1.collect()
assert(accum.value == 5)
val nondeterministicUDF = simpleUDF.asNondeterministic()
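// A non-deterministic UDF must not be subexpression-eliminated: per row the condition and the
// selected branch each invoke the UDF, i.e. 2 calls per row and 10 in total, raising the
// accumulator from 5 to 15.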
val df2 = spark.range(5).select(when(functions.length(nondeterministicUDF($"id")) > 0,
functions.length(nondeterministicUDF($"id"))).otherwise(
functions.length(nondeterministicUDF($"id")) + 1))
df2.collect()
assert(accum.value == 15)
}
test("SPARK-35560: Remove redundant subexpression evaluation in nested subexpressions") {
Seq(1, Int.MaxValue).foreach { splitThreshold =>
withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> splitThreshold.toString) {
val accum = sparkContext.longAccumulator("call")
val simpleUDF = udf((s: String) => {
accum.add(1)
s
})
// Common exprs:
// 1. simpleUDF($"id")
// 2. functions.length(simpleUDF($"id"))
// We should only evaluate `simpleUDF($"id")` once, i.e.
// subExpr1 = simpleUDF($"id");
// subExpr2 = functions.length(subExpr1);
val df = spark.range(5).select(
when(functions.length(simpleUDF($"id")) === 1, lower(simpleUDF($"id")))
.when(functions.length(simpleUDF($"id")) === 0, upper(simpleUDF($"id")))
.otherwise(simpleUDF($"id")).as("output"))
df.collect()
assert(accum.value == 5)
}
}
}
test("isLocal should consider CommandResult and LocalRelation") {
val df1 = sql("SHOW TABLES")
assert(df1.isLocal)
val df2 = (1 to 10).toDF()
assert(df2.isLocal)
}
test("SPARK-35886: PromotePrecision should be subexpr replaced") {
withTable("tbl") {
sql(
"""
|CREATE TABLE tbl (
| c1 DECIMAL(18,6),
| c2 DECIMAL(18,6),
| c3 DECIMAL(18,6))
|USING parquet;
|""".stripMargin)
sql("INSERT INTO tbl SELECT 1, 1, 1")
checkAnswer(sql("SELECT sum(c1 * c3) + sum(c2 * c3) FROM tbl"), Row(2.00000000000) :: Nil)
}
}
test("SPARK-36338: DataFrame.withSequenceColumn should append unique sequence IDs") {
val ids = spark.range(10).repartition(5)
.withSequenceColumn("default_index").collect().map(_.getLong(0))
assert(ids.toSet === Range(0, 10).toSet)
}
}
case class GroupByKey(a: Int, b: Int)
case class Bar2(s: String)