/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
 */

package org.apache.spark.sql

import java.io.{ByteArrayOutputStream, PrintStream}
import java.nio.file.Files
import java.time.DateTimeException
import java.util.Properties
import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
import org.apache.commons.io.FileUtils
import org.apache.commons.io.output.TeeOutputStream
import org.scalactic.TolerantNumerics
import org.scalatest.PrivateMethodTester
import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException}
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, TableAlreadyExistsException, TempTableAlreadyExistsException}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession, SQLHelper}
import org.apache.spark.sql.test.SparkConnectServerUtils.port
import org.apache.spark.sql.types._
import org.apache.spark.util.SparkThreadUtils

class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
test("throw SparkException with null filename in stack trace elements") {
withSQLConf("spark.sql.connect.enrichError.enabled" -> "true") {
val session = spark
import session.implicits._
val throwException =
udf((_: String) => {
val testError = new SparkException("test")
val stackTrace = testError.getStackTrace()
stackTrace(0) = new StackTraceElement(
stackTrace(0).getClassName,
stackTrace(0).getMethodName,
null,
stackTrace(0).getLineNumber)
testError.setStackTrace(stackTrace)
throw testError
})
val ex = intercept[SparkException] {
Seq("1").toDS().withColumn("udf_val", throwException($"value")).collect()
}
assert(ex.getCause.isInstanceOf[SparkException])
assert(ex.getCause.getStackTrace().length > 0)
assert(ex.getCause.getStackTrace()(0).getFileName == null)
}
}
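  // Run the cause-exception checks with server-side error enrichment both disabled and enabled.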
for (enrichErrorEnabled <- Seq(false, true)) {
test(s"cause exception - ${enrichErrorEnabled}") {
withSQLConf(
"spark.sql.connect.enrichError.enabled" -> enrichErrorEnabled.toString,
"spark.sql.legacy.timeParserPolicy" -> "EXCEPTION") {
val ex = intercept[SparkUpgradeException] {
spark
.sql("""
|select from_json(
| '{"d": "02-29"}',
| 'd date',
| map('dateFormat', 'MM-dd'))
|""".stripMargin)
.collect()
}
assert(
ex.getErrorClass ===
"INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER")
assert(
ex.getMessageParameters.asScala == Map(
"datetime" -> "'02-29'",
"config" -> "\"spark.sql.legacy.timeParserPolicy\""))
if (enrichErrorEnabled) {
assert(ex.getCause.isInstanceOf[DateTimeException])
} else {
assert(ex.getCause == null)
}
}
}
}
test("throw SparkException with large cause exception") {
withSQLConf("spark.sql.connect.enrichError.enabled" -> "true") {
val session = spark
import session.implicits._
val throwException =
udf((_: String) => throw new SparkException("test" * 10000))
val ex = intercept[SparkException] {
Seq("1").toDS().withColumn("udf_val", throwException($"value")).collect()
}
assert(ex.getErrorClass != null)
assert(!ex.getMessageParameters.isEmpty)
assert(ex.getCause.isInstanceOf[SparkException])
val cause = ex.getCause.asInstanceOf[SparkException]
assert(cause.getErrorClass == null)
assert(cause.getMessageParameters.isEmpty)
assert(cause.getMessage.contains("test" * 10000))
}
}
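  // Run the analysis-error checks with the server-side stack trace config both disabled and enabled.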
for (isServerStackTraceEnabled <- Seq(false, true)) {
test(s"server-side stack trace is set in exceptions - ${isServerStackTraceEnabled}") {
withSQLConf(
"spark.sql.connect.serverStacktrace.enabled" -> isServerStackTraceEnabled.toString,
"spark.sql.pyspark.jvmStacktrace.enabled" -> "false") {
val ex = intercept[AnalysisException] {
spark.sql("select x").collect()
}
assert(ex.getErrorClass != null)
assert(!ex.messageParameters.isEmpty)
assert(ex.getSqlState != null)
assert(!ex.isInternalError)
assert(ex.getQueryContext.length == 1)
assert(ex.getQueryContext.head.startIndex() == 7)
assert(ex.getQueryContext.head.stopIndex() == 7)
assert(ex.getQueryContext.head.fragment() == "x")
assert(
ex.getStackTrace
.find(_.getClassName.contains("org.apache.spark.sql.catalyst.analysis.CheckAnalysis"))
.isDefined)
}
}
}
test("throw SparkArithmeticException") {
withSQLConf("spark.sql.ansi.enabled" -> "true") {
intercept[SparkArithmeticException] {
spark.sql("select 1/0").collect()
}
}
}
test("throw NoSuchDatabaseException") {
val ex = intercept[NoSuchDatabaseException] {
spark.sql("use database123")
}
assert(ex.getErrorClass != null)
}
test("table not found for spark.catalog.getTable") {
val ex = intercept[AnalysisException] {
spark.catalog.getTable("test_table")
}
assert(ex.getErrorClass != null)
}
test("throw NamespaceAlreadyExistsException") {
try {
spark.sql("create database test_db")
val ex = intercept[NamespaceAlreadyExistsException] {
spark.sql("create database test_db")
}
assert(ex.getErrorClass != null)
} finally {
spark.sql("drop database test_db")
}
}
test("throw TempTableAlreadyExistsException") {
try {
spark.sql("create temporary view test_view as select 1")
val ex = intercept[TempTableAlreadyExistsException] {
spark.sql("create temporary view test_view as select 1")
}
assert(ex.getErrorClass != null)
} finally {
spark.sql("drop view test_view")
}
}
test("throw TableAlreadyExistsException") {
withTable("testcat.test_table") {
spark.sql(s"create table testcat.test_table (id int)")
val ex = intercept[TableAlreadyExistsException] {
spark.sql(s"create table testcat.test_table (id int)")
}
assert(ex.getErrorClass != null)
}
}
test("throw ParseException") {
val ex = intercept[ParseException] {
spark.sql("selet 1").collect()
}
assert(ex.getErrorClass != null)
assert(!ex.messageParameters.isEmpty)
assert(ex.getSqlState != null)
assert(!ex.isInternalError)
}
test("spark deep recursion") {
var df = spark.range(1)
for (a <- 1 to 500) {
df = df.union(spark.range(a, a + 1))
}
assert(df.collect().length == 501)
}
test("handle unknown exception") {
var df = spark.range(1)
val limit = spark.conf.get("spark.connect.grpc.marshallerRecursionLimit").toInt + 1
for (a <- 1 to limit) {
df = df.union(spark.range(a, a + 1))
}
val ex = intercept[SparkException] {
df.collect()
}
assert(ex.getMessage.contains("io.grpc.StatusRuntimeException: UNKNOWN"))
}
test("many tables") {
withSQLConf("spark.sql.execution.arrow.maxRecordsPerBatch" -> "10") {
val numTables = 20
try {
for (i <- 0 to numTables) {
spark.sql(s"create table testcat.table${i} (id int)")
}
assert(spark.sql("show tables in testcat").collect().length == numTables + 1)
} finally {
for (i <- 0 to numTables) {
spark.sql(s"drop table if exists testcat.table${i}")
}
}
}
}
// Spark Result
test("spark result schema") {
val df = spark.sql("select val from (values ('Hello'), ('World')) as t(val)")
df.withResult { result =>
val schema = result.schema
assert(schema == StructType(StructField("val", StringType, nullable = false) :: Nil))
}
}
test("spark result array") {
val df = spark.sql("select val from (values ('Hello'), ('World')) as t(val)")
val result = df.collect()
assert(result.length == 2)
assert(result(0).getString(0) == "Hello")
assert(result(1).getString(0) == "World")
}
test("eager execution of sql") {
assume(IntegrationTestUtils.isSparkHiveJarAvailable)
withTable("test_martin") {
      // Fails because the table does not exist.
assertThrows[AnalysisException] {
spark.sql("select * from test_martin").collect()
}
      // Eagerly executed DDL.
spark.sql("create table test_martin (id int)")
// Execute read again.
val rows = spark.sql("select * from test_martin").collect()
assert(rows.length == 0)
spark.sql("insert into test_martin values (1), (2)")
val rows_new = spark.sql("select * from test_martin").collect()
assert(rows_new.length == 2)
}
}
test("simple dataset") {
val df = spark.range(10).limit(3)
val result = df.collect()
assert(result.length == 3)
assert(result(0) == 0)
assert(result(1) == 1)
assert(result(2) == 2)
}
test("read and write") {
assume(IntegrationTestUtils.isSparkHiveJarAvailable)
val testDataPath = java.nio.file.Paths
.get(
IntegrationTestUtils.sparkHome,
"connector",
"connect",
"common",
"src",
"test",
"resources",
"query-tests",
"test-data",
"people.csv")
.toAbsolutePath
val df = spark.read
.format("csv")
.option("path", testDataPath.toString)
.options(Map("header" -> "true", "delimiter" -> ";"))
.schema(
StructType(
StructField("name", StringType) ::
StructField("age", IntegerType) ::
StructField("job", StringType) :: Nil))
.load()
val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath
df.write
.format("csv")
.mode("overwrite")
.options(Map("header" -> "true", "delimiter" -> ";"))
.save(outputFolderPath.toString)
// We expect only one csv file saved.
val outputFile = outputFolderPath.toFile
.listFiles()
.filter(file => file.getPath.endsWith(".csv"))(0)
assert(FileUtils.contentEquals(testDataPath.toFile, outputFile))
}
test("read path collision") {
val testDataPath = java.nio.file.Paths
.get(
IntegrationTestUtils.sparkHome,
"connector",
"connect",
"common",
"src",
"test",
"resources",
"query-tests",
"test-data",
"people.csv")
.toAbsolutePath
val df = spark.read
.format("csv")
.option("path", testDataPath.toString)
.options(Map("header" -> "true", "delimiter" -> ";"))
.schema(
StructType(
StructField("name", StringType) ::
StructField("age", IntegerType) ::
StructField("job", StringType) :: Nil))
.csv(testDataPath.toString)
    // Fails because the path cannot be provided both via the option and the csv() method.
assertThrows[AnalysisException] {
df.collect()
}
}
test("textFile") {
assume(IntegrationTestUtils.isSparkHiveJarAvailable)
val testDataPath = java.nio.file.Paths
.get(
IntegrationTestUtils.sparkHome,
"connector",
"connect",
"common",
"src",
"test",
"resources",
"query-tests",
"test-data",
"people.txt")
.toAbsolutePath
val result = spark.read.textFile(testDataPath.toString).collect()
val expected = Array("Michael, 29", "Andy, 30", "Justin, 19")
assert(result.length == 3)
assert(result === expected)
}
test("write table") {
assume(IntegrationTestUtils.isSparkHiveJarAvailable)
withTable("myTable") {
val df = spark.range(10).limit(3)
df.write.mode(SaveMode.Overwrite).saveAsTable("myTable")
spark.range(2).write.insertInto("myTable")
val result = spark.sql("select * from myTable").sort("id").collect()
assert(result.length == 5)
assert(result(0).getLong(0) == 0)
assert(result(1).getLong(0) == 0)
assert(result(2).getLong(0) == 1)
assert(result(3).getLong(0) == 1)
assert(result(4).getLong(0) == 2)
}
}
test("different spark session join/union") {
val df = spark.range(10).limit(3)
val spark2 = SparkSession
.builder()
.client(
SparkConnectClient
.builder()
.port(port)
.build())
.create()
val df2 = spark2.range(10).limit(3)
assertThrows[SparkException] {
df.union(df2).collect()
}
assertThrows[SparkException] {
df.unionByName(df2).collect()
}
assertThrows[SparkException] {
df.join(df2).collect()
}
}
test("write without table or path") {
assume(IntegrationTestUtils.isSparkHiveJarAvailable)
    // Writing with the noop format should succeed even without a table or path.
spark.range(10).write.format("noop").mode("append").save()
}
test("write jdbc") {
assume(IntegrationTestUtils.isSparkHiveJarAvailable)
val url = "jdbc:derby:memory:1234"
val table = "t1"
try {
spark.range(10).write.jdbc(url = s"$url;create=true", table, new Properties())
val result = spark.read.jdbc(url = url, table, new Properties()).collect()
assert(result.length == 10)
} finally {
// clean up
assertThrows[SparkException] {
spark.read.jdbc(url = s"$url;drop=true", table, new Properties()).collect()
}
}
}
test("writeTo with create") {
withTable("testcat.myTableV2") {
val rows = Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))
val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
spark.createDataFrame(rows.asJava, schema).writeTo("testcat.myTableV2").create()
val outputRows = spark.table("testcat.myTableV2").collect()
assert(outputRows.length == 3)
}
}
test("writeTo with create and using") {
withTable("testcat.myTableV2") {
val rows = Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))
val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
spark.createDataFrame(rows.asJava, schema).writeTo("testcat.myTableV2").create()
val outputRows = spark.table("testcat.myTableV2").collect()
assert(outputRows.length == 3)
val columns = spark.table("testcat.myTableV2").columns
assert(columns.length == 2)
val sqlOutputRows = spark.sql("select * from testcat.myTableV2").collect()
      assert(sqlOutputRows.length == 3)
assert(sqlOutputRows(0).schema == schema)
assert(sqlOutputRows(1).getString(1) == "b")
}
}
test("writeTo with create and append") {
withTable("testcat.myTableV2") {
val rows = Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))
val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
spark.sql("CREATE TABLE testcat.myTableV2 (id bigint, data string) USING foo")
assert(spark.table("testcat.myTableV2").collect().isEmpty)
spark.createDataFrame(rows.asJava, schema).writeTo("testcat.myTableV2").append()
val outputRows = spark.table("testcat.myTableV2").collect()
assert(outputRows.length == 3)
}
}
test("WriteTo with overwrite") {
withTable("testcat.myTableV2") {
val rows1 = (1L to 3L).map { i =>
Row(i, "" + (i - 1 + 'a'))
}
val rows2 = (4L to 7L).map { i =>
Row(i, "" + (i - 1 + 'a'))
}
val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
spark.sql(
"CREATE TABLE testcat.myTableV2 (id bigint, data string) USING foo PARTITIONED BY (id)")
assert(spark.table("testcat.myTableV2").collect().isEmpty)
spark.createDataFrame(rows1.asJava, schema).writeTo("testcat.myTableV2").append()
val outputRows = spark.table("testcat.myTableV2").collect()
assert(outputRows.length == 3)
spark
.createDataFrame(rows2.asJava, schema)
.writeTo("testcat.myTableV2")
.overwrite(functions.expr("true"))
val outputRows2 = spark.table("testcat.myTableV2").collect()
assert(outputRows2.length == 4)
}
}
test("WriteTo with overwritePartitions") {
withTable("testcat.myTableV2") {
val rows = (4L to 7L).map { i =>
Row(i, "" + (i - 1 + 'a'))
}
val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
spark.sql(
"CREATE TABLE testcat.myTableV2 (id bigint, data string) USING foo PARTITIONED BY (id)")
assert(spark.table("testcat.myTableV2").collect().isEmpty)
spark
.createDataFrame(rows.asJava, schema)
.writeTo("testcat.myTableV2")
.overwritePartitions()
val outputRows = spark.table("testcat.myTableV2").collect()
assert(outputRows.length == 4)
}
}
test("write path collision") {
val df = spark.range(10)
val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath
    // Fails because the path cannot be provided both via the option and the save() method.
assertThrows[AnalysisException] {
df.write.option("path", outputFolderPath.toString).save(outputFolderPath.toString)
}
}
// TODO test large result when we can create table or view
// test("test spark large result")
private def captureStdOut(block: => Unit): String = {
val currentOut = Console.out
val capturedOut = new ByteArrayOutputStream()
val newOut = new PrintStream(new TeeOutputStream(currentOut, capturedOut))
Console.withOut(newOut) {
block
}
capturedOut.toString
}
private def checkFragments(result: String, fragmentsToCheck: Seq[String]): Unit = {
fragmentsToCheck.foreach { fragment =>
assert(result.contains(fragment))
}
}
private def testCapturedStdOut(block: => Unit, fragmentsToCheck: String*): Unit = {
checkFragments(captureStdOut(block), fragmentsToCheck)
}
private def testCapturedStdOut(
block: => Unit,
expectedNumLines: Int,
expectedMaxWidth: Int,
fragmentsToCheck: String*): Unit = {
val result = captureStdOut(block)
val lines = result.split('\n')
assert(lines.length === expectedNumLines)
assert(lines.map((s: String) => s.length).max <= expectedMaxWidth)
checkFragments(result, fragmentsToCheck)
}
private val simpleSchema = new StructType().add("id", "long", nullable = false)
// Dataset tests
test("Dataset inspection") {
val df = spark.range(10)
val local = spark.newDataFrame { builder =>
builder.getLocalRelationBuilder.setSchema(simpleSchema.catalogString)
}
assert(!df.isLocal)
assert(local.isLocal)
assert(!df.isStreaming)
assert(df.toString.contains("[id: bigint]"))
assert(df.inputFiles.isEmpty)
}
test("Dataset schema") {
val df = spark.range(10)
assert(df.schema === simpleSchema)
assert(df.dtypes === Array(("id", "LongType")))
assert(df.columns === Array("id"))
testCapturedStdOut(df.printSchema(), simpleSchema.treeString)
testCapturedStdOut(df.printSchema(5), simpleSchema.treeString(5))
}
test("Dataframe schema") {
val df = spark.sql("select * from range(10)")
val expectedSchema = new StructType().add("id", "long", nullable = false)
assert(df.schema === expectedSchema)
assert(df.dtypes === Array(("id", "LongType")))
assert(df.columns === Array("id"))
testCapturedStdOut(df.printSchema(), expectedSchema.treeString)
testCapturedStdOut(df.printSchema(5), expectedSchema.treeString(5))
}
test("Dataset explain") {
val df = spark.range(10)
val simpleExplainFragments = Seq("== Physical Plan ==")
testCapturedStdOut(df.explain(), simpleExplainFragments: _*)
testCapturedStdOut(df.explain(false), simpleExplainFragments: _*)
testCapturedStdOut(df.explain("simple"), simpleExplainFragments: _*)
val extendedExplainFragments = Seq(
"== Parsed Logical Plan ==",
"== Analyzed Logical Plan ==",
"== Optimized Logical Plan ==") ++
simpleExplainFragments
testCapturedStdOut(df.explain(true), extendedExplainFragments: _*)
testCapturedStdOut(df.explain("extended"), extendedExplainFragments: _*)
testCapturedStdOut(
df.explain("cost"),
simpleExplainFragments :+ "== Optimized Logical Plan ==": _*)
testCapturedStdOut(df.explain("codegen"), "WholeStageCodegen subtrees.")
testCapturedStdOut(df.explain("formatted"), "Range", "Arguments: ")
}
test("Dataset result collection") {
def checkResult(rows: IterableOnce[java.lang.Long], expectedValues: Long*): Unit = {
rows.iterator.zipAll(expectedValues.iterator, null, null).foreach {
case (actual, expected) => assert(actual === expected)
}
}
val df = spark.range(10)
checkResult(df.head() :: Nil, 0L)
checkResult(df.head(5), 0L, 1L, 2L, 3L, 4L)
checkResult(df.first() :: Nil, 0L)
assert(!df.isEmpty)
assert(df.filter("id > 100").isEmpty)
checkResult(df.take(3), 0L, 1L, 2L)
checkResult(df.tail(3), 7L, 8L, 9L)
checkResult(df.takeAsList(10).asScala, 0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L)
checkResult(df.filter("id % 3 = 0").collect(), 0L, 3L, 6L, 9L)
checkResult(df.filter("id < 3").collectAsList().asScala, 0L, 1L, 2L)
val iterator = df.filter("id > 5 and id < 9").toLocalIterator()
try {
checkResult(iterator.asScala, 6L, 7L, 8L)
} finally {
iterator.asInstanceOf[AutoCloseable].close()
}
}
test("Dataset show") {
val df = spark.range(20)
testCapturedStdOut(df.show(), 24, 5, "+---+", "| id|", "| 0|", "| 19|")
testCapturedStdOut(
df.show(10),
15,
24,
"+---+",
"| id|",
"| 0|",
"| 9|",
"only showing top 10 rows")
val wideDf =
spark.range(4).selectExpr("id", "concat('very_very_very_long_string', id) as val")
testCapturedStdOut(
wideDf.show(true),
8,
26,
"+---+--------------------+",
"| id| val|",
"| 0|very_very_very_lo...|")
testCapturedStdOut(
wideDf.show(false),
8,
33,
"+---+---------------------------+",
"|id |val |",
"|2 |very_very_very_long_string2|")
testCapturedStdOut(
wideDf.show(2, truncate = false),
7,
33,
"+---+---------------------------+",
"|id |val |",
"|1 |very_very_very_long_string1|",
"only showing top 2 rows")
testCapturedStdOut(
df.show(8, 10, vertical = true),
17,
23,
"-RECORD 3--",
"id | 7",
"only showing top 8 rows")
}
test("Dataset randomSplit") {
implicit val tolerance = TolerantNumerics.tolerantDoubleEquality(0.01)
val df = spark.range(100)
def checkSample(
ds: Dataset[java.lang.Long],
lower: Double,
upper: Double,
seed: Long): Unit = {
assert(ds.plan.getRoot.hasSample)
val sample = ds.plan.getRoot.getSample
assert(sample.getSeed === seed)
assert(sample.getLowerBound === lower)
assert(sample.getUpperBound === upper)
}
val Array(ds1, ds2, ds3) = df.randomSplit(Array(8, 9, 7), 123L)
checkSample(ds1, 0, 8.0 / 24.0, 123L)
checkSample(ds2, 8.0 / 24.0, 17.0 / 24.0, 123L)
checkSample(ds3, 17.0 / 24.0, 1.0, 123L)
val datasets = df.randomSplitAsList(Array(1, 2, 3, 4), 9L)
assert(datasets.size() === 4)
checkSample(datasets.get(0), 0, 1.0 / 10.0, 9L)
checkSample(datasets.get(1), 1.0 / 10.0, 3.0 / 10.0, 9L)
checkSample(datasets.get(2), 3.0 / 10.0, 6.0 / 10.0, 9L)
checkSample(datasets.get(3), 6.0 / 10.0, 1.0, 9L)
}
test("Dataset count") {
assert(spark.range(10).count() === 10)
}
test("Dataset collect tuple") {
val session = spark
import session.implicits._
val result = session
.range(3)
.select(col("id"), (col("id") % 2).cast("int").as("a"), (col("id") / lit(10.0d)).as("b"))
.as[(Long, Int, Double)]
.collect()
result.zipWithIndex.foreach { case ((id, a, b), i) =>
assert(id == i)
assert(a == id % 2)
assert(b == id / 10.0d)
}
}
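  // Columns covering the fields of MyType plus an extra column "d" that the encoder ignores,
  // deliberately listed out of order so that as[MyType] must bind them by name.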
private val generateMyTypeColumns = Seq(
(col("id") / lit(10.0d)).as("b"),
col("id"),
lit("world").as("d"),
(col("id") % 2).as("a"))
private def validateMyTypeResult(result: Array[MyType]): Unit = {
result.zipWithIndex.foreach { case (MyType(id, a, b), i) =>
assert(id == i)
assert(a == id % 2)
assert(b == id / 10.0d)
}
}
private def validateMyTypeResult(result: Array[(MyType, MyType, MyType)]): Unit = {
result.zipWithIndex.foreach { case (row, i) =>
val t1 = row._1
val t2 = row._2
val t3 = row._3
assert(t1 === t2)
assert(t2 === t3)
assert(t1.id == i)
assert(t1.a == t1.id % 2)
assert(t1.b == t1.id / 10.0d)
}
}
test("Dataset collect complex type") {
val session = spark
import session.implicits._
val result = session
.range(3)
.select(generateMyTypeColumns: _*)
.as[MyType]
.collect()
validateMyTypeResult(result)
}
test("Dataset typed select - simple column") {
val numRows = spark.range(1000).select(count("id")).first()
assert(numRows === 1000)
}
test("Dataset typed select - multiple columns") {
val result = spark.range(1000).select(count("id"), sum("id")).first()
assert(result.getLong(0) === 1000)
assert(result.getLong(1) === 499500)
}
test("Dataset typed select - complex column") {
val session = spark
import session.implicits._
val ds = session
.range(3)
.select(struct(generateMyTypeColumns: _*).as[MyType])
validateMyTypeResult(ds.collect())
}
test("Dataset typed select - multiple complex columns") {
val session = spark
import session.implicits._
val s = struct(generateMyTypeColumns: _*).as[MyType]
val ds = session
.range(3)
.select(s, s, s)
validateMyTypeResult(ds.collect())
}
test("lambda functions") {
// This test is mostly to validate lambda variables are properly resolved.
val result = spark
.range(3)
.select(
col("id"),
array(sequence(col("id"), lit(10)), sequence(col("id") * 2, lit(10))).as("data"))
.select(col("id"), transform(col("data"), x => transform(x, x => x + 1)).as("data"))
.select(
col("id"),
transform(col("data"), x => aggregate(x, lit(0L), (x, y) => x + y)).as("summaries"))
.collect()
val expected = Array(Row(0L, Seq(66L, 66L)), Row(1L, Seq(65L, 63L)), Row(2L, Seq(63L, 56L)))
assert(result === expected)
}
test("shuffle array") {
// We cannot do structural tests for shuffle because its random seed will always change.
val result = spark
.sql("select 1")
.select(shuffle(array(lit(1), lit(2), lit(3), lit(74))))
.head()
.getSeq[Int](0)
assert(result.toSet === Set(1, 2, 3, 74))
}
test("ambiguous joins") {
val left = spark.range(100).select(col("id"), rand(10).as("a"))
val right = spark.range(100).select(col("id"), rand(12).as("a"))
val joined = left.join(right, left("id") === right("id")).select(left("id"), right("a"))
assert(joined.schema.catalogString === "struct<id:bigint,a:double>")
val joined2 = left
.join(right, left.colRegex("id") === right.colRegex("id"))
.select(left("id"), right("a"))
assert(joined2.schema.catalogString === "struct<id:bigint,a:double>")
}
test("join with dataframe star") {
val left = spark.range(100)
val right = spark.range(100).select(col("id"), rand(12).as("a"))
val join1 = left.join(right, left("id") === right("id"))
assert(
join1.select(join1.col("*")).schema.catalogString ===
"struct<id:bigint,id:bigint,a:double>")
assert(join1.select(left.col("*")).schema.catalogString === "struct<id:bigint>")
assert(join1.select(right.col("*")).schema.catalogString === "struct<id:bigint,a:double>")
val join2 = left.join(right)
assert(
join2.select(join2.col("*")).schema.catalogString ===
"struct<id:bigint,id:bigint,a:double>")
assert(join2.select(left.col("*")).schema.catalogString === "struct<id:bigint>")
assert(join2.select(right.col("*")).schema.catalogString === "struct<id:bigint,a:double>")
val join3 = left.join(right, "id")
assert(
join3.select(join3.col("*")).schema.catalogString ===
"struct<id:bigint,a:double>")
assert(join3.select(left.col("*")).schema.catalogString === "struct<id:bigint>")
assert(join3.select(right.col("*")).schema.catalogString === "struct<id:bigint,a:double>")
}
test("SPARK-45509: ambiguous column reference") {
val session = spark
import session.implicits._
val df1 = Seq(1 -> "a").toDF("i", "j")
val df1_filter = df1.filter(df1("i") > 0)
val df2 = Seq(2 -> "b").toDF("i", "y")
checkSameResult(
Seq(Row(1)),
// df1("i") is not ambiguous, and it's still valid in the filtered df.
df1_filter.select(df1("i")))
val e1 = intercept[AnalysisException] {
// df1("i") is not ambiguous, but it's not valid in the projected df.
df1.select((df1("i") + 1).as("plus")).select(df1("i")).collect()
}
assert(e1.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION"))
checkSameResult(
Seq(Row(1, "a")),
// All these column references are not ambiguous and are still valid after join.
df1.join(df2, df1("i") + 1 === df2("i")).sort(df1("i").desc).select(df1("i"), df1("j")))
val e2 = intercept[AnalysisException] {
// df1("i") is ambiguous as df1 appears in both join sides.
df1.join(df1, df1("i") === 1).collect()
}
assert(e2.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
val e3 = intercept[AnalysisException] {
// df1("i") is ambiguous as df1 appears in both join sides.
df1.join(df1).select(df1("i")).collect()
}
assert(e3.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
// TODO(SPARK-47749): Dataframe.collect should accept duplicated column names
assert(
// df1.join(df1_filter, df1("i") === 1) fails in classic spark due to:
// org.apache.spark.sql.AnalysisException: Column i#24 are ambiguous
df1.join(df1_filter, df1("i") === 1).columns ===
Array("i", "j", "i", "j"))
checkSameResult(
Seq(Row("a")),
// df1_filter("i") is not ambiguous as df1_filter does not exist in the join left side.
df1.join(df1_filter, df1_filter("i") === 1).select(df1_filter("j")))
val e5 = intercept[AnalysisException] {
// df1("i") is ambiguous as df1 appears in both sides of the first join.
df1.join(df1_filter, df1_filter("i") === 1).join(df2, df1("i") === 1).collect()
}
assert(e5.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE"))
checkSameResult(
Seq(Row("a")),
// df1_filter("i") is not ambiguous as df1_filter only appears once.
df1.join(df1_filter).join(df2, df1_filter("i") === 1).select(df1_filter("j")))
}
test("broadcast join") {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
val left = spark.range(100).select(col("id"), rand(10).as("a"))
val right = spark.range(100).select(col("id"), rand(12).as("a"))
val joined =
left.join(broadcast(right), left("id") === right("id")).select(left("id"), right("a"))
assert(joined.schema.catalogString === "struct<id:bigint,a:double>")
testCapturedStdOut(joined.explain(), "BroadcastHashJoin")
}
}
test("test temp view") {
try {
spark.range(100).createTempView("test1")
assert(spark.sql("SELECT * FROM test1").count() == 100)
spark.range(1000).createOrReplaceTempView("test1")
assert(spark.sql("SELECT * FROM test1").count() == 1000)
spark.range(100).createGlobalTempView("view1")
assert(spark.sql("SELECT * FROM global_temp.view1").count() == 100)
spark.range(1000).createOrReplaceGlobalTempView("view1")
assert(spark.sql("SELECT * FROM global_temp.view1").count() == 1000)
} finally {
spark.sql("DROP VIEW IF EXISTS test1")
spark.sql("DROP VIEW IF EXISTS global_temp.view1")
}
}
test("time") {
val timeFragments = Seq("Time taken: ", " ms")
testCapturedStdOut(spark.time(spark.sql("select 1").collect()), timeFragments: _*)
}
test("RuntimeConfig") {
intercept[NoSuchElementException](spark.conf.get("foo.bar"))
assert(spark.conf.getOption("foo.bar").isEmpty)
spark.conf.set("foo.bar", value = true)
assert(spark.conf.getOption("foo.bar") === Option("true"))
spark.conf.set("foo.bar.numBaz", 100L)
assert(spark.conf.get("foo.bar.numBaz") === "100")
spark.conf.set("foo.bar.name", "donkey")
assert(spark.conf.get("foo.bar.name") === "donkey")
spark.conf.unset("foo.bar.name")
val allKeyValues = spark.conf.getAll
assert(allKeyValues("foo.bar") === "true")
assert(allKeyValues("foo.bar.numBaz") === "100")
assert(!spark.conf.isModifiable("foo.bar")) // This is a bit odd.
assert(spark.conf.isModifiable("spark.sql.ansi.enabled"))
assert(!spark.conf.isModifiable("spark.sql.globalTempDatabase"))
intercept[Exception](spark.conf.set("spark.sql.globalTempDatabase", "/dev/null"))
}
test("SparkVersion") {
assert(spark.version.nonEmpty)
assert(spark.version == SPARK_VERSION)
}
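  // Collects `dataset` through SparkResult and checks that the rows match `expected`, in order.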
private def checkSameResult[E](expected: scala.collection.Seq[E], dataset: Dataset[E]): Unit = {
dataset.withResult { result =>
assert(expected === result.iterator.toBuffer)
}
}
test("Local Relation implicit conversion") {
val session = spark
import session.implicits._
val simpleValues = Seq(1, 24, 3)
checkSameResult(simpleValues, simpleValues.toDS())
checkSameResult(simpleValues.map(v => Row(v)), simpleValues.toDF())
val complexValues = Seq((5, "a"), (6, "b"))
checkSameResult(complexValues, complexValues.toDS())
checkSameResult(
complexValues.map(kv => KV(kv._2, kv._1)),
complexValues.toDF("value", "key").as[KV])
}
test("SparkSession.createDataFrame - row") {
val rows = java.util.Arrays.asList(Row("bob", 99), Row("Club", 5), Row("Bag", 5))
val schema = new StructType().add("key", "string").add("value", "int")
checkSameResult(rows.asScala, spark.createDataFrame(rows, schema))
}
test("SparkSession.createDataFrame - bean") {
def bean(v: String): SimpleBean = {
val bean = new SimpleBean
bean.setValue(v)
bean
}
val beans = java.util.Arrays.asList(bean("x"), bean("s"), bean("d"))
checkSameResult(
beans.asScala.map(b => Row(b.value)),
spark.createDataFrame(beans, classOf[SimpleBean]))
}
test("SparkSession typed createDataSet/createDataframe") {
val session = spark
import session.implicits._
val list = java.util.Arrays.asList(KV("bob", 99), KV("Club", 5), KV("Bag", 5))
checkSameResult(list.asScala, session.createDataset(list))
checkSameResult(
list.asScala.map(kv => Row(kv.key, kv.value)),
session.createDataFrame(list.asScala.toSeq))
}
test("SparkSession newSession") {
val oldId = spark.sql("SELECT 1").analyze.getSessionId
val newId = spark.newSession().sql("SELECT 1").analyze.getSessionId
assert(oldId != newId)
}
test("createDataFrame from complex type schema") {
val schema = new StructType()
.add(
"c1",
new StructType()
.add("c1-1", StringType)
.add("c1-2", StringType))
val data = Seq(Row(Row(null, "a2")), Row(Row("b1", "b2")), Row(null))
val result = spark.createDataFrame(data.asJava, schema).collect()
assert(result === data)
}
test("SameSemantics") {
val plan = spark.sql("select 1")
val otherPlan = spark.sql("select 1")
assert(plan.sameSemantics(otherPlan))
}
test("sameSemantics and semanticHash") {
val df1 = spark.createDataFrame(Seq((1, 2), (4, 5)))
val df2 = spark.createDataFrame(Seq((1, 2), (4, 5)))
val df3 = spark.createDataFrame(Seq((0, 2), (4, 5)))
val df4 = spark.createDataFrame(Seq((0, 2), (4, 5)))
assert(df1.sameSemantics(df2) === true)
assert(df1.sameSemantics(df3) === false)
assert(df3.sameSemantics(df4) === true)
assert(df1.semanticHash() === df2.semanticHash())
assert(df1.semanticHash() !== df3.semanticHash())
assert(df3.semanticHash() === df4.semanticHash())
}
test("toJSON") {
val expected = Array(
"""{"b":0.0,"id":0,"d":"world","a":0}""",
"""{"b":0.1,"id":1,"d":"world","a":1}""",
"""{"b":0.2,"id":2,"d":"world","a":0}""")
val result = spark
.range(3)
.select(generateMyTypeColumns: _*)
.toJSON
.collect()
assert(result sameElements expected)
}
test("json from Dataset[String] inferSchema") {
val session = spark
import session.implicits._
val expected = Seq(
new GenericRowWithSchema(
Array(73, "Shandong", "Kong"),
new StructType().add("age", LongType).add("city", StringType).add("name", StringType)))
val ds = Seq("""{"name":"Kong","age":73,"city":'Shandong'}""").toDS()
val result = spark.read.option("allowSingleQuotes", "true").json(ds)
checkSameResult(expected, result)
}
test("json from Dataset[String] with schema") {
val session = spark
import session.implicits._
val schema = new StructType().add("city", StringType).add("name", StringType)
val expected = Seq(new GenericRowWithSchema(Array("Shandong", "Kong"), schema))
val ds = Seq("""{"name":"Kong","age":73,"city":'Shandong'}""").toDS()
val result = spark.read.schema(schema).option("allowSingleQuotes", "true").json(ds)
checkSameResult(expected, result)
}
test("json from Dataset[String] with invalid schema") {
val message = intercept[ParseException] {
spark.read.schema("123").json(spark.createDataset(Seq.empty[String])(StringEncoder))
}.getMessage
assert(message.contains("PARSE_SYNTAX_ERROR"))
}
test("csv from Dataset[String] inferSchema") {
val session = spark
import session.implicits._
val expected = Seq(
new GenericRowWithSchema(
Array("Meng", 84, "Shandong"),
new StructType().add("name", StringType).add("age", LongType).add("city", StringType)))
val ds = Seq("name,age,city", """"Meng",84,"Shandong"""").toDS()
val result = spark.read
.option("header", "true")
.option("inferSchema", "true")
.csv(ds)
checkSameResult(expected, result)
}
test("csv from Dataset[String] with schema") {
val session = spark
import session.implicits._
val schema = new StructType().add("name", StringType).add("age", LongType)
val expected = Seq(new GenericRowWithSchema(Array("Meng", 84), schema))
val ds = Seq(""""Meng",84,"Shandong"""").toDS()
val result = spark.read.schema(schema).csv(ds)
checkSameResult(expected, result)
}
test("csv from Dataset[String] with invalid schema") {
val message = intercept[ParseException] {
spark.read.schema("123").csv(spark.createDataset(Seq.empty[String])(StringEncoder))
}.getMessage
assert(message.contains("PARSE_SYNTAX_ERROR"))
}
test("Dataset result destructive iterator") {
    // Helpers for accessing the private `resultMap` field of SparkResult.
val getResultMap =
PrivateMethod[mutable.Map[Int, Any]](Symbol("resultMap"))
def assertResultsMapEmpty(result: SparkResult[_]): Unit = {
val resultMap = result invokePrivate getResultMap()
assert(resultMap.isEmpty)
}
val df = spark
.range(0, 10, 1, 10)
.filter("id > 5 and id < 9")
df.withResult { result =>
try {
// build and verify the destructive iterator
val iterator = result.destructiveIterator
        // resultMap is empty before the result iterator is traversed.
assertResultsMapEmpty(result)
val buffer = mutable.Set.empty[Long]
while (iterator.hasNext) {
// resultMap is empty during iteration because results get removed immediately on access.
assertResultsMapEmpty(result)
buffer += iterator.next()
}
        // resultMap is empty afterwards because all results have been removed.
assertResultsMapEmpty(result)
val expectedResult = Set(6L, 7L, 8L)
assert(buffer.size === 3 && expectedResult == buffer)
} finally {
result.close()
}
}
}
test("SparkSession.createDataFrame - large data set") {
val threshold = 1024 * 1024
withSQLConf(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY -> threshold.toString) {
val count = 2
val suffix = "abcdef"
val str = scala.util.Random.alphanumeric.take(1024 * 1024).mkString + suffix
val data = Seq.tabulate(count)(i => (i, str))
for (_ <- 0 until 2) {
val df = spark.createDataFrame(data)
assert(df.count() === count)
assert(!df.filter(df("_2").endsWith(suffix)).isEmpty)
}
}
}
test("sql() with positional parameters") {
val result0 = spark.sql("select 1", Array.empty).collect()
assert(result0.length == 1 && result0(0).getInt(0) === 1)
val result1 = spark.sql("select ?", Array(1)).collect()
assert(result1.length == 1 && result1(0).getInt(0) === 1)
val result2 = spark.sql("select ?, ?", Array(1, "abc")).collect()
assert(result2.length == 1)
assert(result2(0).getInt(0) === 1)
assert(result2(0).getString(1) === "abc")
val result3 = spark.sql("select element_at(?, 1)", Array(array(lit(1)))).collect()
assert(result3.length == 1 && result3(0).getInt(0) === 1)
}
test("sql() with named parameters") {
val result0 = spark.sql("select 1", Map.empty[String, Any]).collect()
assert(result0.length == 1 && result0(0).getInt(0) === 1)
val result1 = spark.sql("select :abc", Map("abc" -> 1)).collect()
assert(result1.length == 1 && result1(0).getInt(0) === 1)
val result2 = spark.sql("select :c0 limit :l0", Map("l0" -> 1, "c0" -> "abc")).collect()
assert(result2.length == 1 && result2(0).getString(0) === "abc")
val result3 =
spark.sql("select element_at(:m, 'a')", Map("m" -> map(lit("a"), lit(1)))).collect()
assert(result3.length == 1 && result3(0).getInt(0) === 1)
}
test("joinWith, flat schema") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(1, 2, 3).toDS().as("a")
val ds2 = Seq(1, 2).toDS().as("b")
val joined = ds1.joinWith(ds2, $"a.value" === $"b.value", "inner")
val expectedSchema = StructType(
Seq(
StructField("_1", IntegerType, nullable = false),
StructField("_2", IntegerType, nullable = false)))
assert(joined.schema === expectedSchema)
val expected = Seq((1, 1), (2, 2))
checkSameResult(expected, joined)
}
test("joinWith tuple with primitive, expression") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(1, 1, 2).toDS()
val ds2 = Seq(("a", 1), ("b", 2)).toDS()
val joined = ds1.joinWith(ds2, $"value" === $"_2")
    // This is an inner join, so both output fields are non-nullable.
val expectedSchema = StructType(
Seq(
StructField("_1", IntegerType, nullable = false),
StructField(
"_2",
StructType(
Seq(StructField("_1", StringType), StructField("_2", IntegerType, nullable = false))),
nullable = false)))
assert(joined.schema === expectedSchema)
checkSameResult(Seq((1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2))), joined)
}
test("joinWith tuple with primitive, rows") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(1, 1, 2).toDF()
val ds2 = Seq(("a", 1), ("b", 2)).toDF()
val joined = ds1.joinWith(ds2, $"value" === $"_2")
checkSameResult(
Seq((Row(1), Row("a", 1)), (Row(1), Row("a", 1)), (Row(2), Row("b", 2))),
joined)
}
test("joinWith class with primitive, toDF") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(1, 1, 2).toDS()
val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val df = ds1
.joinWith(ds2, $"value" === $"b")
.toDF()
.select($"_1", $"_2.a", $"_2.b")
checkSameResult(Seq(Row(1, "a", 1), Row(1, "a", 1), Row(2, "b", 2)), df)
}
test("multi-level joinWith") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
val joined = ds1
.joinWith(ds2, $"a._2" === $"b._2")
.as("ab")
.joinWith(ds3, $"ab._1._2" === $"c._2")
checkSameResult(
Seq(((("a", 1), ("a", 1)), ("a", 1)), ((("b", 2), ("b", 2)), ("b", 2))),
joined)
}
test("multi-level joinWith, rows") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(("a", 1), ("b", 2)).toDF().as("a")
val ds2 = Seq(("a", 1), ("b", 2)).toDF().as("b")
val ds3 = Seq(("a", 1), ("b", 2)).toDF().as("c")
val joined = ds1
.joinWith(ds2, $"a._2" === $"b._2")
.as("ab")
.joinWith(ds3, $"ab._1._2" === $"c._2")
checkSameResult(
Seq(((Row("a", 1), Row("a", 1)), Row("a", 1)), ((Row("b", 2), Row("b", 2)), Row("b", 2))),
joined)
}
test("self join") {
val session: SparkSession = spark
import session.implicits._
val ds = Seq("1", "2").toDS().as("a")
val joined = ds.joinWith(ds, lit(true), "cross")
checkSameResult(Seq(("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")), joined)
}
test("SPARK-11894: Incorrect results are returned when using null") {
val session: SparkSession = spark
import session.implicits._
val nullInt = null.asInstanceOf[java.lang.Integer]
val ds1 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS()
val ds2 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS()
checkSameResult(
Seq(
((nullInt, "1"), (nullInt, "1")),
((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")),
((java.lang.Integer.valueOf(22), "2"), (nullInt, "1")),
((java.lang.Integer.valueOf(22), "2"), (java.lang.Integer.valueOf(22), "2"))),
ds1.joinWith(ds2, lit(true), "cross"))
}
test("SPARK-15441: Dataset outer join") {
val session: SparkSession = spark
import session.implicits._
val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left")
val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right")
val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
val expectedSchema = StructType(
Seq(
StructField(
"_1",
StructType(
Seq(StructField("a", StringType), StructField("b", IntegerType, nullable = false))),
nullable = false),
// This is a left join, so the right output is nullable:
StructField(
"_2",
StructType(
Seq(StructField("a", StringType), StructField("b", IntegerType, nullable = false))))))
assert(joined.schema === expectedSchema)
val result = joined.collect().toSet
assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2)))
}
test("SPARK-37829: DataFrame outer join") {
// Same as "SPARK-15441: Dataset outer join" but using DataFrames instead of Datasets
val session: SparkSession = spark
import session.implicits._
val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
val joined = left.joinWith(right, $"left.b" === $"right.b", "left")
val leftFieldSchema = StructType(
Seq(StructField("a", StringType), StructField("b", IntegerType, nullable = false)))
val rightFieldSchema = StructType(
Seq(StructField("a", StringType), StructField("b", IntegerType, nullable = false)))
val expectedSchema = StructType(
Seq(
StructField("_1", leftFieldSchema, nullable = false),
// This is a left join, so the right output is nullable:
StructField("_2", rightFieldSchema)))
assert(joined.schema === expectedSchema)
val result = joined.collect().toSet
val expected = Set(
new GenericRowWithSchema(Array("a", 1), leftFieldSchema) ->
null,
new GenericRowWithSchema(Array("b", 2), leftFieldSchema) ->
new GenericRowWithSchema(Array("x", 2), rightFieldSchema))
assert(result == expected)
}
test("SPARK-24762: joinWith on Option[Product]") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a")
val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b")
val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner")
checkSameResult(Seq((Some((2, 3)), Some((1, 2)))), joined)
}
test("dropDuplicatesWithinWatermark not supported in batch DataFrame") {
def testAndVerify(df: Dataset[_]): Unit = {
val exc = intercept[AnalysisException] {
df.write.format("noop").mode(SaveMode.Append).save()
}
assert(exc.getMessage.contains("dropDuplicatesWithinWatermark is not supported"))
assert(exc.getMessage.contains("batch DataFrames/DataSets"))
}
val result = spark.range(10).dropDuplicatesWithinWatermark()
testAndVerify(result)
val result2 = spark
.range(10)
.withColumn("newcol", col("id"))
.dropDuplicatesWithinWatermark("newcol")
testAndVerify(result2)
}
test("Dataset.metadataColumn") {
val session: SparkSession = spark
import session.implicits._
withTempPath { file =>
val path = file.getAbsoluteFile.toURI.toString
spark
.range(0, 100, 1, 1)
.withColumn("_metadata", concat(lit("lol_"), col("id")))
.write
.parquet(file.toPath.toAbsolutePath.toString)
val df = spark.read.parquet(path)
val (filepath, rc) = df
.groupBy(df.metadataColumn("_metadata").getField("file_path"))
.count()
.as[(String, Long)]
.head()
assert(filepath.startsWith(path))
assert(rc == 100)
}
}
test("SPARK-45216: Non-deterministic functions with seed") {
val session: SparkSession = spark
import session.implicits._
val df = Seq(Array.range(0, 10)).toDF("a")
val r = rand()
val r2 = randn()
val r3 = random()
val r4 = uuid()
val r5 = shuffle(col("a"))
df.select(r, r.as("r"), r2, r2.as("r2"), r3, r3.as("r3"), r4, r4.as("r4"), r5, r5.as("r5"))
.collect()
.foreach { row =>
(0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1)))
}
}
test("Observable metrics") {
val df = spark.range(99).withColumn("extra", col("id") - 1)
val ob1 = new Observation("ob1")
val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
val observedObservedDf = observedDf.observe("ob2", min("extra"), avg("extra"), max("extra"))
val ob1Schema = new StructType()
.add("min(id)", LongType)
.add("avg(id)", DoubleType)
.add("max(id)", LongType)
val ob2Schema = new StructType()
.add("min(extra)", LongType)
.add("avg(extra)", DoubleType)
.add("max(extra)", LongType)
val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema))
val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema))
assert(df.collectResult().getObservedMetrics === Map.empty)
assert(observedDf.collectResult().getObservedMetrics === ob1Metrics)
assert(observedObservedDf.collectResult().getObservedMetrics === ob1Metrics ++ ob2Metrics)
}
test("Observation.get is blocked until the query is finished") {
val df = spark.range(99).withColumn("extra", col("id") - 1)
val observation = new Observation("ob1")
val observedDf = df.observe(observation, min("id"), avg("id"), max("id"))
// Start a new thread to get the observation
val future = Future(observation.get)(ExecutionContext.global)
// make sure the thread is blocked right now
val e = intercept[java.util.concurrent.TimeoutException] {
SparkThreadUtils.awaitResult(future, 2.seconds)
}
assert(e.getMessage.contains("Future timed out"))
observedDf.collect()
// make sure the thread is unblocked after the query is finished
val metrics = SparkThreadUtils.awaitResult(future, 2.seconds)
assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
}
}
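// Helper types used by the encoder and createDataFrame tests above.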
private[sql] case class ClassData(a: String, b: Int)
private[sql] case class MyType(id: Long, a: Double, b: Double)
private[sql] case class KV(key: String, value: Int)
private[sql] class SimpleBean {
@scala.beans.BeanProperty
var value: String = _
}