/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
 */

package org.apache.spark.sql.connector

import scala.language.implicitConversions
import scala.util.Try

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkException
import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression}
import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, QueryExecutionListener}
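
/**
 * Tests for data sources that implement [[SupportsCatalogOptions]], verifying that
 * DataFrameReader and DataFrameWriter resolve the target table through the catalog
 * and identifier that the provider extracts from the data source options.
 */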
class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {
  import testImplicits._

private val catalogName = "testcat"
private val format = classOf[CatalogSupportingInMemoryTableProvider].getName
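
  /** Looks up the registered catalog plugin with the given name as a [[TableCatalog]]. */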
private def catalog(name: String): TableCatalog = {
spark.sessionState.catalogManager.catalog(name).asInstanceOf[TableCatalog]
}
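
  /** Lets a bare table name stand in for a single-part [[Identifier]] with an empty namespace. */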
private implicit def stringToIdentifier(value: String): Identifier = {
Identifier.of(Array.empty, value)
}
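
  // Install an in-memory implementation for the session catalog and register a
  // second in-memory catalog under the name "testcat".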
before {
spark.conf.set(
V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName)
spark.conf.set(
s"spark.sql.catalog.$catalogName", classOf[InMemoryTableCatalog].getName)
}
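
  // Drop all tables from both catalogs and restore the default catalog configuration.
  // clearTables is wrapped in Try because the session catalog may not be our
  // in-memory implementation if a test replaced the configuration.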
override def afterEach(): Unit = {
super.afterEach()
Try(catalog(SESSION_CATALOG_NAME).asInstanceOf[InMemoryTableSessionCatalog].clearTables())
catalog(catalogName).listTables(Array.empty).foreach(
catalog(catalogName).dropTable(_))
spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
spark.conf.unset(s"spark.sql.catalog.$catalogName")
}
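
  /**
   * Writes a ten-row DataFrame with the given save mode, optional "catalog" option and
   * partitioning, then loads the resulting table through the corresponding catalog and
   * checks its identifier, partitioning, schema, and contents.
   */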
private def testCreateAndRead(
saveMode: SaveMode,
withCatalogOption: Option[String],
partitionBy: Seq[String]): Unit = {
val df = spark.range(10).withColumn("part", 'id % 5)
val dfw = df.write.format(format).mode(saveMode).option("name", "t1")
withCatalogOption.foreach(cName => dfw.option("catalog", cName))
dfw.partitionBy(partitionBy: _*).save()
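
    // Without a "catalog" option the provider places the table in the session catalog's
    // "default" namespace; with one, it uses the root namespace of that catalog.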
val ident = if (withCatalogOption.isEmpty) {
Identifier.of(Array("default"), "t1")
} else {
Identifier.of(Array(), "t1")
}
val table = catalog(withCatalogOption.getOrElse(SESSION_CATALOG_NAME)).loadTable(ident)
val namespace = withCatalogOption.getOrElse("default")
assert(table.name() === s"$namespace.t1", "Table identifier was wrong")
assert(table.partitioning().length === partitionBy.length, "Partitioning did not match")
if (partitionBy.nonEmpty) {
table.partitioning.head match {
case IdentityTransform(FieldReference(field)) =>
assert(field === Seq(partitionBy.head), "Partitioning column did not match")
case otherTransform =>
fail(s"Unexpected partitioning ${otherTransform.describe()} received")
}
}
assert(table.partitioning().map(_.references().head.fieldNames().head) === partitionBy,
"Partitioning was incorrect")
assert(table.schema() === df.schema.asNullable, "Schema did not match")
checkAnswer(load("t1", withCatalogOption), df.toDF())
  }

  test("save works with ErrorIfExists - no table, no partitioning, session catalog") {
    testCreateAndRead(SaveMode.ErrorIfExists, None, Nil)
  }

  test("save works with ErrorIfExists - no table, with partitioning, session catalog") {
    testCreateAndRead(SaveMode.ErrorIfExists, None, Seq("part"))
  }

  test("save works with Ignore - no table, no partitioning, testcat catalog") {
    testCreateAndRead(SaveMode.Ignore, Some(catalogName), Nil)
  }

  test("save works with Ignore - no table, with partitioning, testcat catalog") {
    testCreateAndRead(SaveMode.Ignore, Some(catalogName), Seq("part"))
  }

test("save fails with ErrorIfExists if table exists - session catalog") {
sql(s"create table t1 (id bigint) using $format")
val df = spark.range(10)
intercept[TableAlreadyExistsException] {
val dfw = df.write.format(format).option("name", "t1")
dfw.save()
}
  }

test("save fails with ErrorIfExists if table exists - testcat catalog") {
sql(s"create table $catalogName.t1 (id bigint) using $format")
val df = spark.range(10)
intercept[TableAlreadyExistsException] {
val dfw = df.write.format(format).option("name", "t1").option("catalog", catalogName)
dfw.save()
}
  }

test("Ignore mode if table exists - session catalog") {
sql(s"create table t1 (id bigint) using $format")
val df = spark.range(10).withColumn("part", 'id % 5)
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
dfw.save()
val table = catalog(SESSION_CATALOG_NAME).loadTable(Identifier.of(Array("default"), "t1"))
assert(table.partitioning().isEmpty, "Partitioning should be empty")
assert(table.schema() === new StructType().add("id", LongType), "Schema did not match")
assert(load("t1", None).count() === 0)
  }

test("Ignore mode if table exists - testcat catalog") {
sql(s"create table $catalogName.t1 (id bigint) using $format")
val df = spark.range(10).withColumn("part", 'id % 5)
val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1")
dfw.option("catalog", catalogName).save()
val table = catalog(catalogName).loadTable("t1")
assert(table.partitioning().isEmpty, "Partitioning should be empty")
assert(table.schema() === new StructType().add("id", LongType), "Schema did not match")
assert(load("t1", Some(catalogName)).count() === 0)
  }

test("append and overwrite modes - session catalog") {
sql(s"create table t1 (id bigint) using $format")
val df = spark.range(10)
df.write.format(format).option("name", "t1").mode(SaveMode.Append).save()
checkAnswer(load("t1", None), df.toDF())
val df2 = spark.range(10, 20)
df2.write.format(format).option("name", "t1").mode(SaveMode.Overwrite).save()
checkAnswer(load("t1", None), df2.toDF())
  }

test("append and overwrite modes - testcat catalog") {
sql(s"create table $catalogName.t1 (id bigint) using $format")
val df = spark.range(10)
df.write.format(format).option("name", "t1").option("catalog", catalogName)
.mode(SaveMode.Append).save()
checkAnswer(load("t1", Some(catalogName)), df.toDF())
val df2 = spark.range(10, 20)
df2.write.format(format).option("name", "t1").option("catalog", catalogName)
.mode(SaveMode.Overwrite).save()
checkAnswer(load("t1", Some(catalogName)), df2.toDF())
  }

test("fail on user specified schema when reading - session catalog") {
sql(s"create table t1 (id bigint) using $format")
val e = intercept[IllegalArgumentException] {
spark.read.format(format).option("name", "t1").schema("id bigint").load()
}
assert(e.getMessage.contains("not support user specified schema"))
  }

test("fail on user specified schema when reading - testcat catalog") {
sql(s"create table $catalogName.t1 (id bigint) using $format")
val e = intercept[IllegalArgumentException] {
spark.read.format(format).option("name", "t1").option("catalog", catalogName)
.schema("id bigint").load()
}
assert(e.getMessage.contains("not support user specified schema"))
  }

test("DataFrameReader creates v2Relation with identifiers") {
sql(s"create table $catalogName.t1 (id bigint) using $format")
val df = load("t1", Some(catalogName))
checkV2Identifiers(df.logicalPlan)
  }

test("DataFrameWriter creates v2Relation with identifiers") {
sql(s"create table $catalogName.t1 (id bigint) using $format")
var plan: LogicalPlan = null
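    // Capture the analyzed logical plan of every successful execution so each write
    // below can assert on the relation it targeted.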
val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
plan = qe.analyzed
}
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
}
spark.listenerManager.register(listener)
try {
// Test append
save("t1", SaveMode.Append, Some(catalogName))
sparkContext.listenerBus.waitUntilEmpty()
assert(plan.isInstanceOf[AppendData])
val appendRelation = plan.asInstanceOf[AppendData].table
      checkV2Identifiers(appendRelation)

// Test overwrite
save("t1", SaveMode.Overwrite, Some(catalogName))
sparkContext.listenerBus.waitUntilEmpty()
assert(plan.isInstanceOf[OverwriteByExpression])
val overwriteRelation = plan.asInstanceOf[OverwriteByExpression].table
      checkV2Identifiers(overwriteRelation)

// Test insert
spark.range(10).write.format(format).insertInto(s"$catalogName.t1")
sparkContext.listenerBus.waitUntilEmpty()
assert(plan.isInstanceOf[AppendData])
val insertRelation = plan.asInstanceOf[AppendData].table
      checkV2Identifiers(insertRelation)

// Test saveAsTable append
spark.range(10).write.format(format).mode(SaveMode.Append).saveAsTable(s"$catalogName.t1")
sparkContext.listenerBus.waitUntilEmpty()
assert(plan.isInstanceOf[AppendData])
val saveAsTableRelation = plan.asInstanceOf[AppendData].table
checkV2Identifiers(saveAsTableRelation)
} finally {
spark.listenerManager.unregister(listener)
}
  }

test("SPARK-33240: fail the query when instantiation on session catalog fails") {
try {
spark.sessionState.catalogManager.reset()
spark.conf.set(
V2_SESSION_CATALOG_IMPLEMENTATION.key, "InvalidCatalogClass")
val e = intercept[SparkException] {
sql(s"create table t1 (id bigint) using $format")
}
assert(e.getMessage.contains("Cannot find catalog plugin class"))
assert(e.getMessage.contains("InvalidCatalogClass"))
} finally {
spark.sessionState.catalogManager.reset()
}
}
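
  /**
   * Asserts that the plan is a [[DataSourceV2Relation]] carrying the expected table
   * identifier and catalog plugin.
   */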
private def checkV2Identifiers(
plan: LogicalPlan,
identifier: String = "t1",
catalogPlugin: TableCatalog = catalog(catalogName)): Unit = {
assert(plan.isInstanceOf[DataSourceV2Relation])
val v2 = plan.asInstanceOf[DataSourceV2Relation]
assert(v2.identifier.exists(_.name() == identifier))
assert(v2.catalog.exists(_ == catalogPlugin))
}
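
  /** Loads the named table through the data source, optionally via an explicit catalog. */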
private def load(name: String, catalogOpt: Option[String]): DataFrame = {
val dfr = spark.read.format(format).option("name", name)
catalogOpt.foreach(cName => dfr.option("catalog", cName))
dfr.load()
}
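
  /** Writes a ten-row range to the named table, optionally via an explicit catalog. */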
private def save(name: String, mode: SaveMode, catalogOpt: Option[String]): Unit = {
val df = spark.range(10).write.format(format).option("name", name)
catalogOpt.foreach(cName => df.option("catalog", cName))
df.mode(mode).save()
}
}
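
/**
 * A [[FakeV2Provider]] that also implements [[SupportsCatalogOptions]], deriving the
 * table identifier and catalog from the "name" and "catalog" data source options.
 */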
class CatalogSupportingInMemoryTableProvider
extends FakeV2Provider
with SupportsCatalogOptions {
override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = {
val name = options.get("name")
assert(name != null, "The name should be provided for this table")
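    // An explicit catalog option means the table lives in that catalog's root namespace;
    // otherwise fall back to the session catalog's "default" namespace.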
val namespace = if (options.containsKey("catalog")) {
Array[String]()
} else {
Array("default")
}
Identifier.of(namespace, name)
}
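
  /**
   * Returns the catalog name from the options; may return null when none is set, in
   * which case Spark is expected to fall back to the session catalog.
   */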
override def extractCatalog(options: CaseInsensitiveStringMap): String = {
options.get("catalog")
}
}