/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.test.util

import java.util.{Locale, TimeZone}

import scala.collection.JavaConverters._

import org.apache.spark.sql.{CarbonEnv, CarbonToSparkAdapter, DataFrame, Row, SQLContext}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.test.TestQueryExecutor

import org.apache.carbondata.common.logging.LogServiceFactory
import org.apache.carbondata.core.cache.CacheProvider
import org.apache.carbondata.core.constants.{CarbonCommonConstants, CarbonLoadOptionConstants}
import org.apache.carbondata.core.util.{CarbonProperties, ThreadLocalSessionInfo}

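/**
 * Base class for CarbonData SQL test suites: provides helpers to run SQL statements,
 * compare query results against expected rows, and reset Carbon and Spark configuration
 * between tests.
 */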
class QueryTest extends PlanTest {
val LOGGER = LogServiceFactory.getLogService(this.getClass.getCanonicalName)
// The timezone is fixed to America/Los_Angeles for timezone-sensitive tests (timestamp_*)
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
// Fix the default locale to US so that formatting is deterministic across environments
Locale.setDefault(Locale.US)
/**
 * Runs the plan and makes sure the answer either contains all of the keywords or
 * contains none of them, depending on `exists`.
 * @param df the [[DataFrame]] to be executed
 * @param exists true to assert that every keyword appears in the output,
 *               false to assert that none of the keywords appear
 * @param keywords keywords to search for
 */
def checkExistence(df: DataFrame, exists: Boolean, keywords: String*): Unit = {
val outputs = df.collect().map(_.mkString(" ")).mkString(" ")
for (key <- keywords) {
if (exists) {
assert(outputs.contains(key), s"Failed for $df ($key doesn't exist in the result)")
} else {
assert(!outputs.contains(key), s"Failed for $df ($key exists in the result)")
}
}
}
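// Example (hypothetical table and keyword):
//   checkExistence(sql("DESCRIBE FORMATTED t1"), true, "SORT_COLUMNS")
// asserts that the keyword appears somewhere in the collected output.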
/**
 * Runs the plan and counts the occurrences of the keyword in the answer.
 * @param df the [[DataFrame]] to be executed
 * @param count expected number of occurrences
 * @param keyword keyword to search for
 */
def checkExistenceCount(df: DataFrame, count: Long, keyword: String): Unit = {
val outputs = df.collect().map(_.mkString).mkString
assert(outputs.sliding(keyword.length).count(_ == keyword) === count)
}
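/**
 * Registers a test case, named after the SQL text, that runs it and checks the result
 * against the expected answer.
 */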
def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
test(sqlString) {
checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
}
}
/**
 * Runs the plan and makes sure the answer matches the expected result.
 * @param df the [[DataFrame]] to be executed
 * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
 */
protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = {
QueryTest.checkAnswer(df, expectedAnswer) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
}
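// Convenience overloads: compare against a single expected row, or against the collected
// contents of another DataFrame. For example (hypothetical table), a test might assert:
//   checkAnswer(sql("SELECT count(*) FROM t1"), Row(10))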
protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = {
checkAnswer(df, Seq(expectedAnswer))
}
protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = {
checkAnswer(df, expectedAnswer.collect())
}
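/** Drops the given table if it exists. */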
protected def dropTable(tableName: String): Unit = {
sql(s"DROP TABLE IF EXISTS $tableName")
}
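/** Drops each of the given indexes on the table, if they exist. */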
protected def dropIndexes(tableName: String, indexNames: String*): Unit = {
for (indexName <- indexNames) {
sql(s"DROP INDEX IF EXISTS $indexName ON TABLE $tableName")
}
}
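// Shared test context: SQL execution is delegated to the singleton TestQueryExecutor,
// and the values below expose the Hive client and the common test paths.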
def sql(sqlText: String): DataFrame = TestQueryExecutor.INSTANCE.sql(sqlText)
val sqlContext: SQLContext = TestQueryExecutor.INSTANCE.sqlContext
val hiveClient = CarbonToSparkAdapter.getHiveExternalCatalog(sqlContext.sparkSession).client
lazy val projectPath = TestQueryExecutor.projectPath
lazy val warehouse = TestQueryExecutor.warehouse
lazy val storeLocation = warehouse
val resourcesPath = TestQueryExecutor.resourcesPath
val target = TestQueryExecutor.target
val integrationPath = TestQueryExecutor.integrationPath
val dblocation = TestQueryExecutor.location
val defaultParallelism = sqlContext.sparkContext.defaultParallelism
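/**
 * Resets the Carbon session parameters, thread-local session info and carbon cache,
 * and restores the default sort scope and bad-record handling so that each test
 * starts from a known configuration.
 */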
def defaultConfig(): Unit = {
CarbonEnv.getInstance(sqlContext.sparkSession).carbonSessionInfo.getSessionParams.clear()
ThreadLocalSessionInfo.unsetAll()
// Clear the carbon cache if one has been created
Option(CacheProvider.getInstance().getCarbonCache).foreach(_.clear())
CarbonProperties.getInstance()
.addProperty("enable.unsafe.sort", "true")
.removeProperty(CarbonCommonConstants.LOAD_SORT_SCOPE)
.removeProperty(CarbonLoadOptionConstants.CARBON_OPTIONS_SORT_SCOPE)
sqlContext.setConf("enable.unsafe.sort", "true")
sqlContext.setConf("carbon.options.sort.scope", "NO_SORT")
sqlContext.setConf(
CarbonLoadOptionConstants.CARBON_OPTIONS_IS_EMPTY_DATA_BAD_RECORD,
CarbonLoadOptionConstants.CARBON_OPTIONS_IS_EMPTY_DATA_BAD_RECORD_DEFAULT)
sqlContext.setConf("carbon.custom.block.distribution", "block")
sqlContext.setConf("carbon.options.bad.records.action", "FORCE")
sqlContext.setConf("carbon.options.bad.records.logger.enable", "false")
Seq(
CarbonCommonConstants.LOAD_SORT_SCOPE,
"carbon.enable.vector.reader",
"carbon.push.rowfilters.for.vector",
"spark.sql.sources.commitProtocolClass").foreach { key =>
sqlContext.sparkSession.conf.unset(key)
}
}
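/**
 * Dumps the CarbonProperties, the Spark session configuration and the CarbonEnv
 * session parameters to the log at ERROR level, which is useful when diagnosing
 * test failures.
 */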
def printConfiguration(): Unit = {
CarbonProperties.getInstance().print()
LOGGER.error("------spark conf--------------------------")
LOGGER.error(sqlContext.sessionState.conf.getAllConfs
.map(x => x._1 + "=" + x._2).mkString(", "))
val sessionParams =
CarbonEnv.getInstance(sqlContext.sparkSession).carbonSessionInfo.getSessionParams.getAll
LOGGER.error("------CarbonEnv sessionParam--------------------------")
LOGGER.error(sessionParams.asScala.map(x => x._1 + "=" + x._2).mkString(", "))
}
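/** Prints the current database and the formatted description of the given table. */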
def printTable(table: String, database: String = "default"): Unit = {
sql("SELECT current_database()").show(100, false)
sql(s"DESCRIBE FORMATTED $database.$table").show(100, false)
}
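/**
 * Applies a ", "-separated list of key=value pairs to CarbonProperties, skipping the
 * bad-records location and store location keys.
 */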
def setCarbonProperties(propertiesString: String): Unit = {
val properties = propertiesString.split(", ", -1)
val exclude = Set("carbon.badRecords.location",
"carbon.storelocation")
properties.foreach { property =>
// Split on the first '=' only and skip malformed entries that have no value
val entry = property.split("=", 2)
if (entry.length == 2 && !exclude.contains(entry(0))) {
CarbonProperties.getInstance().addProperty(entry(0), entry(1))
}
}
}
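/**
 * Applies a ", "-separated list of key=value pairs to the Spark session configuration,
 * skipping keys that identify the running session (warehouse dir, catalog implementation,
 * extensions, app/driver/executor identity and master).
 */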
def confSparkSession(confString: String): Unit = {
val confs = confString.split(", ", -1)
val exclude = Set("spark.sql.warehouse.dir",
"carbon.options.bad.record.path",
"spark.sql.catalogImplementation",
"spark.sql.extensions",
"spark.app.name",
"spark.driver.host",
"spark.driver.port",
"spark.executor.id",
"spark.master",
"spark.app.id")
confs.foreach { conf =>
// Split on the first '=' only so values containing '=' are preserved intact
val entry = conf.split("=", 2)
if (!exclude.contains(entry(0))) {
if (entry.length == 2) {
sqlContext.sessionState.conf.setConfString(entry(0), entry(1))
} else {
sqlContext.sessionState.conf.setConfString(entry(0), "")
}
}
}
}
}
object QueryTest {
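/**
 * Java-friendly variant of [[checkAnswer]]: returns the error message on mismatch or
 * null when the result matches.
 */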
def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = {
checkAnswer(df, expectedAnswer.asScala) match {
case Some(errorMessage) => errorMessage
case None => null
}
}
/**
 * Runs the plan and makes sure the answer matches the expected result.
 * If an exception is thrown during execution, or the contents of the DataFrame do not
 * match the expected result, an error message is returned; otherwise [[None]] is
 * returned.
 * @param df the [[DataFrame]] to be executed
 * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
 */
def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
// Java's java.math.BigDecimal.compareTo).
// For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
// equality test.
val converted: Seq[Row] = answer.map { s =>
Row.fromSeq(s.toSeq.map {
case d: java.math.BigDecimal => BigDecimal(d)
case b: Array[Byte] => b.toSeq
// Round finite doubles to 5 decimal places so minor floating-point differences
// do not cause spurious mismatches
case d: Double =>
if (!d.isInfinite && !d.isNaN) {
BigDecimal(d).setScale(5, BigDecimal.RoundingMode.UP).doubleValue()
} else {
d
}
case o => o
})
}
if (!isSorted) converted.sortBy(_.toString()) else converted
}
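// Execute the query, turning any exception into a returned error message that includes
// the query plan and the stack trace.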
val sparkAnswer = try df.collect().toSeq catch {
case e: Exception =>
val errorMessage =
s"""
|Exception thrown while executing query:
|${df.queryExecution}
|== Exception ==
|$e
|${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin
return Some(errorMessage)
}
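// Compare the normalized answers; on mismatch, return a side-by-side diff of the
// expected and actual rows.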
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
val errorMessage =
s"""
|Results do not match for query:
|${df.queryExecution}
|== Results ==
|${
sideBySide(
s"== Correct Answer - ${expectedAnswer.size} ==" +:
prepareAnswer(expectedAnswer).map(_.toString()),
s"== Spark Answer - ${sparkAnswer.size} ==" +:
prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")
}
""".stripMargin
return Some(errorMessage)
}
None
}
}