/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.griffin.measure.transformations

import org.apache.spark.sql.DataFrame
import org.scalatest._

import org.apache.griffin.measure.SparkSuiteBase
import org.apache.griffin.measure.configuration.dqdefinition._
import org.apache.griffin.measure.configuration.enums.ProcessType.BatchProcessType
import org.apache.griffin.measure.context.{ContextId, DQContext}
import org.apache.griffin.measure.datasource.DataSourceFactory
import org.apache.griffin.measure.job.builder.DQJobBuilder
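
/** Mirrors the metric row produced by an ACCURACY rule, so results can be read back via Dataset.as[AccuracyResult]. */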
case class AccuracyResult(total: Long, miss: Long, matched: Long, matchedFraction: Double)
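
/**
 * Batch integration test: builds DQ jobs with griffin-dsl ACCURACY rules over
 * small Hive tables and checks the metric rows they produce.
 */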
class AccuracyTransformationsIntegrationTest extends FlatSpec with Matchers with SparkSuiteBase {

private val EMPTY_PERSON_TABLE = "empty_person"
  private val PERSON_TABLE = "person"

override def beforeAll(): Unit = {
super.beforeAll()
dropTables()
createPersonTable()
createEmptyPersonTable()
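    // The SQL that Griffin generates for accuracy can involve cross joins,
    // so enable them explicitly for these tests.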
spark.conf.set("spark.sql.crossJoin.enabled", "true")
  }

override def afterAll(): Unit = {
dropTables()
super.afterAll()
  }

"accuracy" should "basically work" in {
checkAccuracy(
sourceName = PERSON_TABLE,
targetName = PERSON_TABLE,
expectedResult = AccuracyResult(total = 2, miss = 0, matched = 2, matchedFraction = 1.0))
  }

"accuracy" should "work with empty target" in {
checkAccuracy(
sourceName = PERSON_TABLE,
targetName = EMPTY_PERSON_TABLE,
expectedResult = AccuracyResult(total = 2, miss = 2, matched = 0, matchedFraction = 0.0)
)
  }

"accuracy" should "work with empty source" in {
checkAccuracy(
sourceName = EMPTY_PERSON_TABLE,
targetName = PERSON_TABLE,
expectedResult = AccuracyResult(total = 0, miss = 0, matched = 0, matchedFraction = 1.0))
  }

"accuracy" should "work with empty source and target" in {
checkAccuracy(
sourceName = EMPTY_PERSON_TABLE,
targetName = EMPTY_PERSON_TABLE,
expectedResult = AccuracyResult(total = 0, miss = 0, matched = 0, matchedFraction = 1.0))
}
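
  /**
   * Runs an ACCURACY rule comparing the given source and target tables and
   * asserts that exactly one metric row with the expected counts is produced.
   */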
  private def checkAccuracy(
      sourceName: String,
      targetName: String,
      expectedResult: AccuracyResult) = {
val dqContext: DQContext = getDqContext(
dataSourcesParam = List(
DataSourceParam(
name = "source",
connectors = List(dataConnectorParam(tableName = sourceName))
),
DataSourceParam(
name = "target",
connectors = List(dataConnectorParam(tableName = targetName))
)
))
val accuracyRule = RuleParam(
dslType = "griffin-dsl",
dqType = "ACCURACY",
outDfName = "person_accuracy",
rule = "source.name = target.name"
)
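    // Bind the session to a stable identifier so that spark.implicits._ can be imported.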
val spark = this.spark
import spark.implicits._
val res = getRuleResults(dqContext, accuracyRule)
.as[AccuracyResult]
.collect()
res.length shouldBe 1
res(0) shouldEqual expectedResult
}
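
  /**
   * Builds and executes a DQ job for the given rule, then reads the rule's
   * output table back as a DataFrame.
   */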
private def getRuleResults(dqContext: DQContext, rule: RuleParam): DataFrame = {
val dqJob = DQJobBuilder.buildDQJob(
dqContext,
evaluateRuleParam = EvaluateRuleParam(List(rule))
)
dqJob.execute(dqContext)
spark.sql(s"select * from ${rule.getOutDfName()}")
}
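
  /** Creates the person Hive table and loads the CSV fixture from the test resources. */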
  private def createPersonTable(): Unit = {
    val personCsvPath = getClass.getResource("/hive/person_table.csv").getFile
    createTable(PERSON_TABLE)
    spark.sql(s"LOAD DATA LOCAL INPATH '$personCsvPath' OVERWRITE INTO TABLE ${PERSON_TABLE}")
  }
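
  /**
   * Creates an empty table with the person schema; the final query simply
   * verifies that the empty table is readable.
   */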
  private def createEmptyPersonTable(): Unit = {
    createTable(EMPTY_PERSON_TABLE)
    spark.sql(s"select * from ${EMPTY_PERSON_TABLE}").show()
  }

  /** Shared DDL for the person-shaped test tables. */
  private def createTable(tableName: String): Unit = {
    spark.sql(
      s"CREATE TABLE $tableName " +
        "( " +
        " name String," +
        " age int " +
        ") " +
        "ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' " +
        "STORED AS TEXTFILE")
  }
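
  /** Drops the fixture tables so that reruns start from a clean metastore. */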
private def dropTables(): Unit = {
spark.sql(s"DROP TABLE IF EXISTS ${PERSON_TABLE} ")
spark.sql(s"DROP TABLE IF EXISTS ${EMPTY_PERSON_TABLE} ")
}
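
  /**
   * Initialises the configured data sources and wraps them in a batch-mode
   * DQContext with a fresh ContextId.
   */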
  private def getDqContext(
      dataSourcesParam: Seq[DataSourceParam],
      name: String = "test-context"): DQContext = {
    // The second argument is a streaming context, which these batch tests do not need.
    val dataSources = DataSourceFactory.getDataSources(spark, null, dataSourcesParam)
dataSources.foreach(_.init())
DQContext(
ContextId(System.currentTimeMillis),
name,
dataSources,
Nil,
BatchProcessType
)(spark)
}
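
  /** Minimal HIVE connector for the given table; fields the test does not need are left null. */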
private def dataConnectorParam(tableName: String) = {
DataConnectorParam(
conType = "HIVE",
version = null,
dataFrameName = null,
config = Map("table.name" -> tableName),
preProc = null
)
}
}