[GRIFFIN-358] Added AccuracyMeasureTest
diff --git a/measure/src/main/scala/org/apache/griffin/measure/execution/impl/AccuracyMeasure.scala b/measure/src/main/scala/org/apache/griffin/measure/execution/impl/AccuracyMeasure.scala
index 506a5b5..b618010 100644
--- a/measure/src/main/scala/org/apache/griffin/measure/execution/impl/AccuracyMeasure.scala
+++ b/measure/src/main/scala/org/apache/griffin/measure/execution/impl/AccuracyMeasure.scala
@@ -25,75 +25,84 @@
import org.apache.griffin.measure.configuration.dqdefinition.MeasureParam
import org.apache.griffin.measure.execution.Measure
+import org.apache.griffin.measure.step.builder.ConstantColumns
case class AccuracyMeasure(measureParam: MeasureParam) extends Measure {
case class AccuracyExpr(sourceCol: String, targetCol: String)
+ import AccuracyMeasure._
import Measure._
- private final val TargetSourceStr: String = "target.source"
- private final val SourceColStr: String = "source.col"
- private final val TargetColStr: String = "target.col"
-
- private final val AccurateStr: String = "accurate"
- private final val InAccurateStr: String = "inaccurate"
-
override val supportsRecordWrite: Boolean = true
override val supportsMetricWrite: Boolean = true
val targetSource: String = getFromConfig[String](TargetSourceStr, null)
- val exprOpt: Option[Seq[AccuracyExpr]] =
- Option(getFromConfig[Seq[Map[String, String]]](Expression, null).map(toAccuracyExpr).distinct)
+ val exprOpt: Option[Seq[Map[String, String]]] =
+ Option(getFromConfig[Seq[Map[String, String]]](Expression, null))
validate()
override def impl(sparkSession: SparkSession): (DataFrame, DataFrame) = {
- import org.apache.griffin.measure.step.builder.ConstantColumns
- datasetValidations(sparkSession)
+ val originalSource = sparkSession.read.table(measureParam.getDataSource)
+ val originalCols = originalSource.columns
- val dataSource = sparkSession.read.table(measureParam.getDataSource)
- val targetDataSource = sparkSession.read.table(targetSource).drop(ConstantColumns.tmst)
+ val dataSource = addColumnPrefix(originalSource, SourcePrefixStr)
- exprOpt match {
- case Some(accuracyExpr) =>
- val joinExpr =
- accuracyExpr.map(e => col(e.sourceCol) === col(e.targetCol)).reduce(_ and _)
+ val targetDataSource =
+ addColumnPrefix(
+ sparkSession.read.table(targetSource).drop(ConstantColumns.tmst),
+ TargetPrefixStr)
- val indicatorExpr =
- accuracyExpr
- .map(e =>
- coalesce(col(e.sourceCol), emptyCol) notEqual coalesce(col(e.targetCol), emptyCol))
- .reduce(_ or _)
+ val accuracyExprs = exprOpt.get
+ .map(toAccuracyExpr)
+ .distinct
+ .map(x =>
+ AccuracyExpr(s"$SourcePrefixStr${x.sourceCol}", s"$TargetPrefixStr${x.targetCol}"))
- val recordsDf = targetDataSource
- .join(dataSource, joinExpr, "outer")
- .withColumn(valueColumn, when(indicatorExpr, 1).otherwise(0))
- .selectExpr(s"${measureParam.getDataSource}.*", valueColumn)
+ val joinExpr =
+ accuracyExprs
+ .map(e => col(e.sourceCol) === col(e.targetCol))
+ .reduce(_ and _)
- val selectCols = Seq(Total, AccurateStr, InAccurateStr).flatMap(e => Seq(lit(e), col(e)))
- val metricColumn: Column = map(selectCols: _*).as(valueColumn)
+ val indicatorExpr =
+ accuracyExprs
+ .map(e =>
+ coalesce(col(e.sourceCol), emptyCol) notEqual coalesce(col(e.targetCol), emptyCol))
+ .reduce(_ or _)
- val metricDf = recordsDf
- .withColumn(Total, lit(1))
- .agg(sum(Total).as(Total), sum(valueColumn).as(InAccurateStr))
- .withColumn(AccurateStr, col(Total) - col(InAccurateStr))
- .select(metricColumn)
+ val nullExpr = accuracyExprs.map(e => col(e.sourceCol).isNull).reduce(_ or _)
- (recordsDf, metricDf)
- case None =>
- throw new IllegalArgumentException(s"'$Expression' must be defined.")
- }
+ val recordsDf = removeColumnPrefix(
+ targetDataSource
+ .join(dataSource, joinExpr, "outer")
+ .withColumn(valueColumn, when(indicatorExpr or nullExpr, 1).otherwise(0)),
+ SourcePrefixStr)
+ .select((originalCols :+ valueColumn).map(col): _*)
+
+ val selectCols =
+ Seq(Total, AccurateStr, InAccurateStr).flatMap(e => Seq(lit(e), col(e).cast("string")))
+ val metricColumn: Column = map(selectCols: _*).as(valueColumn)
+
+ val metricDf = recordsDf
+ .withColumn(Total, lit(1))
+ .agg(sum(Total).as(Total), sum(valueColumn).as(InAccurateStr))
+ .withColumn(AccurateStr, col(Total) - col(InAccurateStr))
+ .select(metricColumn)
+
+ (recordsDf, metricDf)
}
private def validate(): Unit = {
assert(exprOpt.isDefined, s"'$Expression' must be defined.")
- assert(exprOpt.nonEmpty, s"'$Expression' must not be empty.")
+ assert(exprOpt.get.flatten.nonEmpty, s"'$Expression' must not be empty or of invalid type.")
assert(
!StringUtil.isNullOrEmpty(targetSource),
- s"'$TargetSourceStr' must not be null or empty.")
+ s"'$TargetSourceStr' must not be null, empty or of invalid type.")
+
+ datasetValidations()
}
private def toAccuracyExpr(map: Map[String, String]): AccuracyExpr = {
@@ -103,7 +112,9 @@
AccuracyExpr(map(SourceColStr), map(TargetColStr))
}
- private def datasetValidations(sparkSession: SparkSession): Unit = {
+ private def datasetValidations(): Unit = {
+ val sparkSession = SparkSession.getDefaultSession.get
+
assert(
sparkSession.catalog.tableExists(targetSource),
s"Target source with name '$targetSource' does not exist.")
@@ -115,7 +126,7 @@
val targetDataSourceCols =
sparkSession.read.table(targetSource).columns.map(_.toLowerCase(Locale.ROOT)).toSet
- val accuracyExpr = exprOpt.get
+ val accuracyExpr = exprOpt.get.map(toAccuracyExpr).distinct
val (forDataSource, forTarget) =
accuracyExpr
.map(
@@ -138,4 +149,26 @@
s"Column(s) [${invalidColsTarget.map(_._1).mkString(", ")}] " +
s"do not exist in target data set with name '$targetSource'")
}
+
+ private def addColumnPrefix(dataFrame: DataFrame, prefix: String): DataFrame = {
+ val columns = dataFrame.columns
+ columns.foldLeft(dataFrame)((df, c) => df.withColumnRenamed(c, s"$prefix$c"))
+ }
+
+ private def removeColumnPrefix(dataFrame: DataFrame, prefix: String): DataFrame = {
+ val columns = dataFrame.columns
+ columns.foldLeft(dataFrame)((df, c) => df.withColumnRenamed(c, c.stripPrefix(prefix)))
+ }
+}
+
+object AccuracyMeasure {
+ final val SourcePrefixStr: String = "__source_"
+ final val TargetPrefixStr: String = "__target_"
+
+ final val TargetSourceStr: String = "target.source"
+ final val SourceColStr: String = "source.col"
+ final val TargetColStr: String = "target.col"
+
+ final val AccurateStr: String = "accurate"
+ final val InAccurateStr: String = "inaccurate"
}
diff --git a/measure/src/test/scala/org/apache/griffin/measure/execution/impl/AccuracyMeasureTest.scala b/measure/src/test/scala/org/apache/griffin/measure/execution/impl/AccuracyMeasureTest.scala
new file mode 100644
index 0000000..be596e3
--- /dev/null
+++ b/measure/src/test/scala/org/apache/griffin/measure/execution/impl/AccuracyMeasureTest.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.griffin.measure.execution.impl
+
+import org.apache.commons.lang3.StringUtils
+
+import org.apache.griffin.measure.configuration.dqdefinition.MeasureParam
+import org.apache.griffin.measure.execution.Measure._
+import org.apache.griffin.measure.execution.impl.AccuracyMeasure._
+
+class AccuracyMeasureTest extends MeasureTest {
+ var param: MeasureParam = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ param = MeasureParam(
+ "param",
+ "Accuracy",
+ "source",
+ Map(
+ Expression -> Seq(Map(SourceColStr -> "gender", TargetColStr -> "gender")),
+ TargetSourceStr -> "target"))
+ }
+
+ "AccuracyMeasure" should "validate expression config" in {
+
+ // Validations for Accuracy Expr
+
+ // Empty
+ assertThrows[AssertionError] {
+ AccuracyMeasure(param.copy(config = Map.empty[String, String]))
+ }
+
+ // Incorrect Type and Empty
+ assertThrows[AssertionError] {
+ AccuracyMeasure(param.copy(config = Map(Expression -> StringUtils.EMPTY)))
+ }
+
+ // Null
+ assertThrows[AssertionError] {
+ AccuracyMeasure(param.copy(config = Map(Expression -> null)))
+ }
+
+ // Incorrect Type
+ assertThrows[AssertionError] {
+ AccuracyMeasure(param.copy(config = Map(Expression -> "gender")))
+ }
+
+ // Correct Type and Empty
+ assertThrows[AssertionError] {
+ AccuracyMeasure(param.copy(config = Map(Expression -> Seq.empty[Map[String, String]])))
+ }
+
+ // Invalid Expr
+ assertThrows[AssertionError] {
+ AccuracyMeasure(
+ param.copy(config = Map(Expression -> Seq(Map("a" -> "b")), TargetSourceStr -> "target")))
+ }
+
+ // Invalid Expr as target.col is missing
+ assertThrows[AssertionError] {
+ AccuracyMeasure(
+ param.copy(
+ config = Map(Expression -> Seq(Map(SourceColStr -> "b")), TargetSourceStr -> "target")))
+ }
+
+ // Invalid Expr as source.col is missing
+ assertThrows[AssertionError] {
+ AccuracyMeasure(
+ param.copy(
+ config = Map(Expression -> Seq(Map(TargetColStr -> "b")), TargetSourceStr -> "target")))
+ }
+
+ // Invalid Expr as provided source.col is invalid
+ assertThrows[AssertionError] {
+ AccuracyMeasure(
+ param.copy(
+ config = Map(
+ Expression -> Seq(Map(SourceColStr -> "b", TargetColStr -> "b")),
+ TargetSourceStr -> "target")))
+ }
+
+ // Invalid Expr as provided target.col is invalid
+ assertThrows[AssertionError] {
+ AccuracyMeasure(
+ param.copy(
+ config = Map(
+ Expression -> Seq(Map(SourceColStr -> "gender", TargetColStr -> "b")),
+ TargetSourceStr -> "target")))
+ }
+
+ // Validations for Target source
+
+ // Empty
+ assertThrows[AssertionError] {
+ AccuracyMeasure(
+ param.copy(
+ config = Map(Expression -> Seq(Map("a" -> "b")), TargetSourceStr -> StringUtils.EMPTY)))
+ }
+
+ // Incorrect Type
+ assertThrows[AssertionError] {
+ AccuracyMeasure(
+ param.copy(config = Map(Expression -> Seq(Map("a" -> "b")), TargetSourceStr -> 2331)))
+ }
+
+ // Null
+ assertThrows[AssertionError] {
+ AccuracyMeasure(
+ param.copy(config = Map(Expression -> Seq(Map("a" -> "b")), TargetSourceStr -> null)))
+ }
+
+ // Invalid target
+ assertThrows[AssertionError] {
+ AccuracyMeasure(
+ param.copy(config = Map(Expression -> Seq(Map("a" -> "b")), TargetSourceStr -> "jj")))
+ }
+ }
+
+ it should "support metric writing" in {
+ val measure = AccuracyMeasure(param)
+ assertResult(true)(measure.supportsMetricWrite)
+ }
+
+ it should "support record writing" in {
+ val measure = AccuracyMeasure(param)
+ assertResult(true)(measure.supportsRecordWrite)
+ }
+
+ it should "execute defined measure expr" in {
+ val measure = AccuracyMeasure(param)
+ val (recordsDf, metricsDf) = measure.execute(context, None)
+
+ assertResult(recordsDf.schema)(recordDfSchema)
+ assertResult(metricsDf.schema)(metricDfSchema)
+
+ assertResult(recordsDf.count())(source.count())
+ assertResult(metricsDf.count())(1L)
+
+ val row = metricsDf.head()
+ assertResult(param.getDataSource)(row.getAs[String](DataSource))
+ assertResult(param.getName)(row.getAs[String](MeasureName))
+ assertResult(param.getType.toString)(row.getAs[String](MeasureType))
+
+ val metricMap = row.getAs[Map[String, String]](Metrics)
+ assertResult(metricMap(Total))("5")
+ assertResult(metricMap(AccurateStr))("2")
+ assertResult(metricMap(InAccurateStr))("3")
+ }
+
+}
diff --git a/measure/src/test/scala/org/apache/griffin/measure/execution/impl/CompletenessMeasureTest.scala b/measure/src/test/scala/org/apache/griffin/measure/execution/impl/CompletenessMeasureTest.scala
index 1242405..baec723 100644
--- a/measure/src/test/scala/org/apache/griffin/measure/execution/impl/CompletenessMeasureTest.scala
+++ b/measure/src/test/scala/org/apache/griffin/measure/execution/impl/CompletenessMeasureTest.scala
@@ -43,6 +43,10 @@
assertThrows[AssertionError] {
CompletenessMeasure(param.copy(config = Map(Expression -> null)))
}
+
+ assertThrows[AssertionError] {
+ CompletenessMeasure(param.copy(config = Map(Expression -> 22)))
+ }
}
it should "support metric writing" in {
@@ -62,7 +66,7 @@
assertResult(recordsDf.schema)(recordDfSchema)
assertResult(metricsDf.schema)(metricDfSchema)
- assertResult(recordsDf.count())(dataSet.count())
+ assertResult(recordsDf.count())(source.count())
assertResult(metricsDf.count())(1L)
val row = metricsDf.head()
@@ -84,7 +88,7 @@
assertResult(recordsDf.schema)(recordDfSchema)
assertResult(metricsDf.schema)(metricDfSchema)
- assertResult(recordsDf.count())(dataSet.count())
+ assertResult(recordsDf.count())(source.count())
assertResult(metricsDf.count())(1L)
val row = metricsDf.head()
diff --git a/measure/src/test/scala/org/apache/griffin/measure/execution/impl/MeasureTest.scala b/measure/src/test/scala/org/apache/griffin/measure/execution/impl/MeasureTest.scala
index eae90e8..992a549 100644
--- a/measure/src/test/scala/org/apache/griffin/measure/execution/impl/MeasureTest.scala
+++ b/measure/src/test/scala/org/apache/griffin/measure/execution/impl/MeasureTest.scala
@@ -30,11 +30,13 @@
trait MeasureTest extends SparkSuiteBase with Matchers {
var sourceSchema: StructType = _
+ var targetSchema: StructType = _
var recordDfSchema: StructType = _
var metricDfSchema: StructType = _
var context: DQContext = _
- var dataSet: DataFrame = _
+ var source: DataFrame = _
+ var target: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -46,6 +48,8 @@
sourceSchema =
new StructType().add("id", "integer").add("name", "string").add("gender", "string")
+ targetSchema = new StructType().add("gender", "string")
+
recordDfSchema = sourceSchema.add(Status, "string", nullable = false)
metricDfSchema = new StructType()
.add(MeasureName, "string", nullable = false)
@@ -53,7 +57,7 @@
.add(DataSource, "string", nullable = false)
.add(Metrics, MapType(StringType, StringType), nullable = false)
- dataSet = spark
+ source = spark
.createDataset(
Seq(
Row(1, "John Smith", "Male"),
@@ -63,7 +67,10 @@
Row(5, null, null)))(RowEncoder(sourceSchema))
.cache()
- dataSet.createOrReplaceTempView("source")
+ target = spark.createDataset(Seq(Row("Male")))(RowEncoder(targetSchema)).cache()
+
+ source.createOrReplaceTempView("source")
+ target.createOrReplaceTempView("target")
}
}