/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.jdbc

import java.sql.{Connection, DriverManager}
import java.util.Properties

import org.apache.spark.SparkConf
import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.functions.{lit, sum, udf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils

class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
import testImplicits._
val tempDir = Utils.createTempDir()
val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass"
var conn: java.sql.Connection = null
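// Register an H2-backed catalog named `h2` and enable aggregate push-down so the
// tests below exercise the DSv2 JDBC read path.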
override def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.h2.url", url)
.set("spark.sql.catalog.h2.driver", "org.h2.Driver")
.set("spark.sql.catalog.h2.pushDownAggregate", "true")
private def withConnection[T](f: Connection => T): T = {
val conn = DriverManager.getConnection(url, new Properties())
try {
f(conn)
} finally {
conn.close()
}
}
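// Create the "test" schema and the empty_table, people and employee tables used by the tests.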
override def beforeAll(): Unit = {
super.beforeAll()
Utils.classForName("org.h2.Driver")
withConnection { conn =>
conn.prepareStatement("CREATE SCHEMA \"test\"").executeUpdate()
conn.prepareStatement(
"CREATE TABLE \"test\".\"empty_table\" (name TEXT(32) NOT NULL, id INTEGER NOT NULL)")
.executeUpdate()
conn.prepareStatement(
"CREATE TABLE \"test\".\"people\" (name TEXT(32) NOT NULL, id INTEGER NOT NULL)")
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred', 1)").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate()
conn.prepareStatement(
"CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," +
" bonus DOUBLE)").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)")
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)")
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)")
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)")
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)")
.executeUpdate()
}
}
override def afterAll(): Unit = {
Utils.deleteRecursively(tempDir)
super.afterAll()
}
test("simple scan") {
checkAnswer(sql("SELECT * FROM h2.test.empty_table"), Seq())
checkAnswer(sql("SELECT * FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2)))
checkAnswer(sql("SELECT name, id FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2)))
}
test("scan with filter push-down") {
val df = spark.table("h2.test.people").filter($"id" > 1)
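// The predicate is pushed into the JDBC scan, so no Filter node should remain
// in the optimized plan.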
val filters = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
assert(filters.isEmpty)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedFilters: [IsNotNull(ID), GreaterThan(ID,1)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Row("mary", 2))
}
test("scan with column pruning") {
val df = spark.table("h2.test.people").select("id")
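// Column pruning should leave only ID in the pushed scan schema.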
val scan = df.queryExecution.optimizedPlan.collectFirst {
case s: DataSourceV2ScanRelation => s
}.get
assert(scan.schema.names.sameElements(Seq("ID")))
checkAnswer(df, Seq(Row(1), Row(2)))
}
test("scan with filter push-down and column pruning") {
val df = spark.table("h2.test.people").filter($"id" > 1).select("name")
val filters = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
assert(filters.isEmpty)
val scan = df.queryExecution.optimizedPlan.collectFirst {
case s: DataSourceV2ScanRelation => s
}.get
assert(scan.schema.names.sameElements(Seq("NAME")))
checkAnswer(df, Row("mary"))
}
test("read/write with partition info") {
withTable("h2.test.abc") {
sql("CREATE TABLE h2.test.abc AS SELECT * FROM h2.test.people")
val df1 = Seq(("evan", 3), ("cathy", 4), ("alex", 5)).toDF("NAME", "ID")
val e = intercept[IllegalArgumentException] {
df1.write
.option("partitionColumn", "id")
.option("lowerBound", "0")
.option("upperBound", "3")
.option("numPartitions", "0")
.insertInto("h2.test.abc")
}.getMessage
assert(e.contains("Invalid value `0` for parameter `numPartitions` in table writing " +
"via JDBC. The minimum value is 1."))
df1.write
.option("partitionColumn", "id")
.option("lowerBound", "0")
.option("upperBound", "3")
.option("numPartitions", "3")
.insertInto("h2.test.abc")
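// Read back with a partitioned JDBC scan (2 partitions over the id column).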
val df2 = spark.read
.option("partitionColumn", "id")
.option("lowerBound", "0")
.option("upperBound", "3")
.option("numPartitions", "2")
.table("h2.test.abc")
assert(df2.rdd.getNumPartitions === 2)
assert(df2.count() === 5)
}
}
test("show tables") {
checkAnswer(sql("SHOW TABLES IN h2.test"),
Seq(Row("test", "people", false), Row("test", "empty_table", false),
Row("test", "employee", false)))
}
test("SQL API: create table as select") {
withTable("h2.test.abc") {
sql("CREATE TABLE h2.test.abc AS SELECT * FROM h2.test.people")
checkAnswer(sql("SELECT name, id FROM h2.test.abc"), Seq(Row("fred", 1), Row("mary", 2)))
}
}
test("DataFrameWriterV2: create table as select") {
withTable("h2.test.abc") {
spark.table("h2.test.people").writeTo("h2.test.abc").create()
checkAnswer(sql("SELECT name, id FROM h2.test.abc"), Seq(Row("fred", 1), Row("mary", 2)))
}
}
test("SQL API: replace table as select") {
withTable("h2.test.abc") {
intercept[CannotReplaceMissingTableException] {
sql("REPLACE TABLE h2.test.abc AS SELECT 1 as col")
}
sql("CREATE OR REPLACE TABLE h2.test.abc AS SELECT 1 as col")
checkAnswer(sql("SELECT col FROM h2.test.abc"), Row(1))
sql("REPLACE TABLE h2.test.abc AS SELECT * FROM h2.test.people")
checkAnswer(sql("SELECT name, id FROM h2.test.abc"), Seq(Row("fred", 1), Row("mary", 2)))
}
}
test("DataFrameWriterV2: replace table as select") {
withTable("h2.test.abc") {
intercept[CannotReplaceMissingTableException] {
sql("SELECT 1 AS col").writeTo("h2.test.abc").replace()
}
sql("SELECT 1 AS col").writeTo("h2.test.abc").createOrReplace()
checkAnswer(sql("SELECT col FROM h2.test.abc"), Row(1))
spark.table("h2.test.people").writeTo("h2.test.abc").replace()
checkAnswer(sql("SELECT name, id FROM h2.test.abc"), Seq(Row("fred", 1), Row("mary", 2)))
}
}
test("SQL API: insert and overwrite") {
withTable("h2.test.abc") {
sql("CREATE TABLE h2.test.abc AS SELECT * FROM h2.test.people")
sql("INSERT INTO h2.test.abc SELECT 'lucy', 3")
checkAnswer(
sql("SELECT name, id FROM h2.test.abc"),
Seq(Row("fred", 1), Row("mary", 2), Row("lucy", 3)))
sql("INSERT OVERWRITE h2.test.abc SELECT 'bob', 4")
checkAnswer(sql("SELECT name, id FROM h2.test.abc"), Row("bob", 4))
}
}
test("DataFrameWriterV2: insert and overwrite") {
withTable("h2.test.abc") {
sql("CREATE TABLE h2.test.abc AS SELECT * FROM h2.test.people")
// `DataFrameWriterV2` resolves columns by name, so the select order below does not
// need to match the table's column order.
sql("SELECT 3 AS ID, 'lucy' AS NAME").writeTo("h2.test.abc").append()
checkAnswer(
sql("SELECT name, id FROM h2.test.abc"),
Seq(Row("fred", 1), Row("mary", 2), Row("lucy", 3)))
sql("SELECT 'bob' AS NAME, 4 AS ID").writeTo("h2.test.abc").overwrite(lit(true))
checkAnswer(sql("SELECT name, id FROM h2.test.abc"), Row("bob", 4))
}
}
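// The following tests verify that MAX/MIN/SUM/COUNT aggregates, filters and GROUP BY
// columns are pushed to H2, by checking the PushedAggregates/PushedFilters/PushedGroupby
// entries in the explain output.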
test("scan with aggregate push-down: MAX MIN with filter and group by") {
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
" group by DEPT")
val filters = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
assert(filters.isEmpty)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200)))
}
test("scan with aggregate push-down: MAX MIN with filter without group by") {
val df = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0")
val filters = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
assert(filters.isEmpty)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [MAX(ID), MIN(ID)], " +
"PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " +
"PushedGroupby: []"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(2, 1)))
}
test("scan with aggregate push-down: aggregate + number") {
val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [MAX(SALARY)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(12001)))
}
test("scan with aggregate push-down: COUNT(*)") {
val df = sql("select COUNT(*) FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [COUNT(*)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(5)))
}
test("scan with aggregate push-down: COUNT(col)") {
val df = sql("select COUNT(DEPT) FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [COUNT(DEPT)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(5)))
}
test("scan with aggregate push-down: COUNT(DISTINCT col)") {
val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [COUNT(DISTINCT DEPT)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(3)))
}
test("scan with aggregate push-down: SUM without filer and group by") {
val df = sql("SELECT SUM(SALARY) FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [SUM(SALARY)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(53000)))
}
test("scan with aggregate push-down: DISTINCT SUM without filer and group by") {
val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [SUM(DISTINCT SALARY)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(31000)))
}
test("scan with aggregate push-down: SUM with group by") {
val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [SUM(SALARY)], " +
"PushedFilters: [], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
}
test("scan with aggregate push-down: DISTINCT SUM with group by") {
val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [SUM(DISTINCT SALARY)], " +
"PushedFilters: [], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
}
test("scan with aggregate push-down: with multiple group by columns") {
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
" group by DEPT, NAME")
val filters = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
assert(filters.isEmpty)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT, NAME]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300),
Row(10000, 1000), Row(12000, 1200)))
}
test("scan with aggregate push-down: with having clause") {
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
" group by DEPT having MIN(BONUS) > 1000")
val filters = df.queryExecution.optimizedPlan.collect {
case f: Filter => f // the filter over the aggregate is not pushed down
}
assert(filters.nonEmpty)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200)))
}
test("scan with aggregate push-down: alias over aggregate") {
val df = sql("select * from h2.test.employee")
.groupBy($"DEPT")
.min("SALARY").as("total")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [MIN(SALARY)], " +
"PushedFilters: [], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000)))
}
test("scan with aggregate push-down: order by alias over aggregate") {
val df = spark.table("h2.test.employee")
val query = df.select($"DEPT", $"SALARY")
.filter($"DEPT" > 0)
.groupBy($"DEPT")
.agg(sum($"SALARY").as("total"))
.filter($"total" > 1000)
.orderBy($"total")
val filters = query.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
assert(filters.nonEmpty) // the filter over the aggregate is not pushed down
query.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [SUM(SALARY)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(query, expected_plan_fragment)
}
checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000)))
}
test("scan with aggregate push-down: udf over aggregate") {
val df = spark.table("h2.test.employee")
val decrease = udf { (x: Double, y: Double) => x - y }
val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value"))
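// Only the two SUMs are pushed down; the UDF is applied on top of the scan output.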
query.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [SUM(SALARY), SUM(BONUS)"
checkKeywordsExistsInExplain(query, expected_plan_fragment)
}
checkAnswer(query, Seq(Row(47100.0)))
}
test("scan with aggregate push-down: aggregate over alias NOT push down") {
val cols = Seq("a", "b", "c", "d")
val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
val df2 = df1.groupBy().sum("c")
df2.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: []" // aggregate over alias not push down
checkKeywordsExistsInExplain(df2, expected_plan_fragment)
}
checkAnswer(df2, Seq(Row(53000.00)))
}
}