/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.griffin.measure.transformations

import org.apache.spark.sql.DataFrame
import org.scalatest._

import org.apache.griffin.measure.SparkSuiteBase
import org.apache.griffin.measure.configuration.dqdefinition._
import org.apache.griffin.measure.configuration.enums.ProcessType.BatchProcessType
import org.apache.griffin.measure.context.{ContextId, DQContext}
import org.apache.griffin.measure.datasource.DataSourceFactory
import org.apache.griffin.measure.job.builder.DQJobBuilder
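
/**
 * One row of Griffin's accuracy metric output: total source records, records
 * missing from the target, matched records, and matched/total as a fraction.
 */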
case class AccuracyResult(total: Long, miss: Long, matched: Long, matchedFraction: Double)
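
/**
 * End-to-end test of the ACCURACY transformation: builds a batch DQContext
 * over Hive tables, runs griffin-dsl accuracy rules through DQJobBuilder,
 * and checks the resulting metric DataFrame.
 */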
class AccuracyTransformationsIntegrationTest extends FlatSpec with Matchers with SparkSuiteBase {
  private val EMPTY_PERSON_TABLE = "empty_person"
  private val PERSON_TABLE = "person"

  override def beforeAll(): Unit = {
    super.beforeAll()
    dropTables()
    createPersonTable()
    createEmptyPersonTable()
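    // The generated accuracy SQL can be planned as a cross join by Spark,
    // which is disallowed by default, so enable it for this suite.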
    spark.conf.set("spark.sql.crossJoin.enabled", "true")
  }

  override def afterAll(): Unit = {
    dropTables()
    super.afterAll()
  }
"accuracy" should "basically work" in {
checkAccuracy(
sourceName = PERSON_TABLE,
targetName = PERSON_TABLE,
expectedResult = AccuracyResult(total = 2, miss = 0, matched = 2, matchedFraction = 1.0))
}
"accuracy" should "work with empty target" in {
checkAccuracy(
sourceName = PERSON_TABLE,
targetName = EMPTY_PERSON_TABLE,
expectedResult = AccuracyResult(total = 2, miss = 2, matched = 0, matchedFraction = 0.0))
}
"accuracy" should "work with empty source" in {
    checkAccuracy(
      sourceName = EMPTY_PERSON_TABLE,
      targetName = PERSON_TABLE,
      expectedResult = AccuracyResult(total = 0, miss = 0, matched = 0, matchedFraction = 1.0))
  }
"accuracy" should "work with empty source and target" in {
checkAccuracy(
sourceName = EMPTY_PERSON_TABLE,
targetName = EMPTY_PERSON_TABLE,
expectedResult = AccuracyResult(total = 0, miss = 0, matched = 0, matchedFraction = 1.0))
}
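
  /**
   * Runs a single griffin-dsl ACCURACY rule comparing sourceName against
   * targetName and asserts that exactly one metric row equal to
   * expectedResult is produced.
   */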
  private def checkAccuracy(
      sourceName: String,
      targetName: String,
      expectedResult: AccuracyResult): Unit = {
    val dqContext: DQContext = getDqContext(
      dataSourcesParam = List(
        DataSourceParam(
          name = "source",
          connectors = List(dataConnectorParam(tableName = sourceName))),
        DataSourceParam(
          name = "target",
          connectors = List(dataConnectorParam(tableName = targetName)))))
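
    // griffin-dsl accuracy: a source row counts as matched when some target
    // row satisfies the rule condition; unmatched rows are counted as missed.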
    val accuracyRule = RuleParam(
      dslType = "griffin-dsl",
      dqType = "ACCURACY",
      outDfName = "person_accuracy",
      rule = "source.name = target.name")
    val spark = this.spark
    import spark.implicits._

    val res = getRuleResults(dqContext, accuracyRule)
      .as[AccuracyResult]
      .collect()

    res.length shouldBe 1
    res(0) shouldEqual expectedResult
  }
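
  /**
   * Builds and executes the DQ job for the given rule, then reads the rule's
   * output (registered under its outDfName) back as a DataFrame.
   */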
  private def getRuleResults(dqContext: DQContext, rule: RuleParam): DataFrame = {
    val dqJob =
      DQJobBuilder.buildDQJob(dqContext, evaluateRuleParam = EvaluateRuleParam(List(rule)))

    dqJob.execute(dqContext)

    spark.sql(s"select * from ${rule.getOutDfName()}")
  }
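
  /** Creates the person Hive table and loads the CSV test fixture into it. */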
  private def createPersonTable(): Unit = {
    val personCsvPath = getClass.getResource("/hive/person_table.csv").getFile

    spark.sql(
      s"CREATE TABLE $PERSON_TABLE " +
        "( " +
        " name String," +
        " age int " +
        ") " +
        "ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' " +
        "STORED AS TEXTFILE")

    spark.sql(s"LOAD DATA LOCAL INPATH '$personCsvPath' OVERWRITE INTO TABLE $PERSON_TABLE")
  }
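
  /**
   * Creates an empty table with the person schema; the trailing show()
   * sanity-checks that the new table is queryable.
   */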
  private def createEmptyPersonTable(): Unit = {
    spark.sql(
      s"CREATE TABLE $EMPTY_PERSON_TABLE " +
        "( " +
        " name String," +
        " age int " +
        ") " +
        "ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' " +
        "STORED AS TEXTFILE")

    spark.sql(s"select * from $EMPTY_PERSON_TABLE").show()
  }

  private def dropTables(): Unit = {
    spark.sql(s"DROP TABLE IF EXISTS $PERSON_TABLE")
    spark.sql(s"DROP TABLE IF EXISTS $EMPTY_PERSON_TABLE")
  }
  private def getDqContext(
      dataSourcesParam: Seq[DataSourceParam],
      name: String = "test-context"): DQContext = {
    val dataSources = DataSourceFactory.getDataSources(spark, null, dataSourcesParam)
    dataSources.foreach(_.init())

    DQContext(ContextId(System.currentTimeMillis), name, dataSources, Nil, BatchProcessType)(
      spark)
  }
  private def dataConnectorParam(tableName: String): DataConnectorParam = {
    DataConnectorParam(
      conType = "HIVE",
      version = null,
      dataFrameName = null,
      config = Map("table.name" -> tableName),
      preProc = null)
  }
}