/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.griffin.measure.execution.impl

import java.util.Locale

import io.netty.util.internal.StringUtil
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

import org.apache.griffin.measure.configuration.dqdefinition.MeasureParam
import org.apache.griffin.measure.execution.Measure
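
/**
 * Accuracy measure: compares the records of a source data set against a
 * target ("source of truth") data set and reports how many of them match.
 *
 * Expected configuration, as read below from `MeasureParam`:
 *  - "target.source": name of the target table, which must exist in the catalog
 *  - the `Expression` entry (key name defined in `Measure`): a sequence of
 *    maps, each pairing a "source.col" with the "target.col" it is compared
 *    against, e.g. (an illustrative sketch, not taken from this file):
 *    [ { "source.col": "id", "target.col": "id" } ]
 */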
case class AccuracyMeasure(measureParam: MeasureParam) extends Measure {

  case class AccuracyExpr(sourceCol: String, targetCol: String)

  import Measure._

  private final val TargetSourceStr: String = "target.source"
  private final val SourceColStr: String = "source.col"
  private final val TargetColStr: String = "target.col"

  private final val Total: String = "total"
  private final val AccurateStr: String = "accurate"
  private final val InAccurateStr: String = "inaccurate"

  override val supportsRecordWrite: Boolean = true

  override val supportsMetricWrite: Boolean = true

  val targetSource: String = getFromConfig[String](TargetSourceStr, null)

  // Wrap the raw config value in Option before mapping: when the Expression
  // entry is absent, getFromConfig returns the null default, and calling
  // .map(toAccuracyExpr) directly on it would throw a NullPointerException.
  val exprOpt: Option[Seq[AccuracyExpr]] =
    Option(getFromConfig[Seq[Map[String, String]]](Expression, null))
      .map(_.map(toAccuracyExpr).distinct)

  validate()
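
  /**
   * Joins the source data set with the target data set on the configured
   * column pairs and returns two data frames: a record-level frame in which
   * `valueColumn` is 1 for rows whose compared columns differ and 0 for rows
   * that match, and a one-row metric frame with the total, accurate and
   * inaccurate counts.
   */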
  override def impl(sparkSession: SparkSession): (DataFrame, DataFrame) = {
    import org.apache.griffin.measure.step.builder.ConstantColumns

    datasetValidations(sparkSession)

    val dataSource = sparkSession.read.table(measureParam.getDataSource)
    val targetDataSource = sparkSession.read.table(targetSource).drop(ConstantColumns.tmst)

    exprOpt match {
      case Some(accuracyExpr) =>
        import org.apache.spark.sql.Column
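
        // A source row and a target row are joined only when every configured
        // source/target column pair is equal.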
        val joinExpr =
          accuracyExpr.map(e => col(e.sourceCol) === col(e.targetCol)).reduce(_ and _)
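
        // A record is flagged inaccurate when any compared pair differs.
        // Nulls are coalesced to empty strings, so rows left unmatched by the
        // outer join are flagged as well (with the quirk that a null on one
        // side compares equal to an empty string on the other).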
        val indicatorExpr =
          accuracyExpr
            .map(e =>
              coalesce(col(e.sourceCol), lit("")) notEqual coalesce(col(e.targetCol), lit("")))
            .reduce(_ or _)
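
        // The outer join keeps unmatched records from both sides; valueColumn
        // holds 1 for inaccurate records and 0 for accurate ones, and only the
        // source-side columns are kept in the record-level output.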
        val recordsDf = targetDataSource
          .join(dataSource, joinExpr, "outer")
          .withColumn(valueColumn, when(indicatorExpr, 1).otherwise(0))
          .selectExpr(s"${measureParam.getDataSource}.*", valueColumn)
        val selectCols = Seq(Total, AccurateStr, InAccurateStr).flatMap(e => Seq(lit(e), col(e)))
        val metricColumn: Column = map(selectCols: _*).as(valueColumn)

        val metricDf = recordsDf
          .withColumn(Total, lit(1))
          .agg(sum(Total).as(Total), sum(valueColumn).as(InAccurateStr))
          .withColumn(AccurateStr, col(Total) - col(InAccurateStr))
          .select(metricColumn)

        (recordsDf, metricDf)
      case None =>
        throw new IllegalArgumentException(s"'$Expression' must be defined.")
    }
  }
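
  /**
   * Asserts that the expression option and the target source name were
   * supplied, failing fast at construction time otherwise.
   */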
  private def validate(): Unit = {
    assert(exprOpt.isDefined, s"'$Expression' must be defined.")
    // On an Option, `nonEmpty` is the same check as `isDefined`; the inner
    // sequence is what must be non-empty here.
    assert(exprOpt.exists(_.nonEmpty), s"'$Expression' must not be empty.")

    assert(
      !StringUtil.isNullOrEmpty(targetSource),
      s"'$TargetSourceStr' must not be null or empty.")
  }
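
  /**
   * Converts one configured map into an AccuracyExpr, requiring both the
   * "source.col" and "target.col" keys to be present.
   */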
  private def toAccuracyExpr(map: Map[String, String]): AccuracyExpr = {
    assert(map.contains(SourceColStr), s"'$SourceColStr' must be defined.")
    assert(map.contains(TargetColStr), s"'$TargetColStr' must be defined.")

    AccuracyExpr(map(SourceColStr), map(TargetColStr))
  }
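
  /**
   * Verifies that the target table exists in the catalog and that every
   * configured column is present (case-insensitively) in its data set.
   */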
  private def datasetValidations(sparkSession: SparkSession): Unit = {
    assert(
      sparkSession.catalog.tableExists(targetSource),
      s"Target source with name '$targetSource' does not exist.")

    val datasourceName = measureParam.getDataSource

    val dataSourceCols =
      sparkSession.read.table(datasourceName).columns.map(_.toLowerCase(Locale.ROOT)).toSet
    val targetDataSourceCols =
      sparkSession.read.table(targetSource).columns.map(_.toLowerCase(Locale.ROOT)).toSet

    // Lower-case the configured names as well, so the existence check is
    // genuinely case-insensitive and matches the lower-cased column sets.
    val accuracyExpr = exprOpt.get
    val (forDataSource, forTarget) =
      accuracyExpr
        .map(
          e =>
            (
              (e.sourceCol, dataSourceCols.contains(e.sourceCol.toLowerCase(Locale.ROOT))),
              (e.targetCol, targetDataSourceCols.contains(e.targetCol.toLowerCase(Locale.ROOT)))))
        .unzip

    val invalidColsDataSource = forDataSource.filterNot(_._2)
    val invalidColsTarget = forTarget.filterNot(_._2)

    assert(
      invalidColsDataSource.isEmpty,
      s"Column(s) [${invalidColsDataSource.map(_._1).mkString(", ")}] " +
        s"do not exist in data set with name '$datasourceName'")

    assert(
      invalidColsTarget.isEmpty,
      s"Column(s) [${invalidColsTarget.map(_._1).mkString(", ")}] " +
        s"do not exist in target data set with name '$targetSource'")
  }
}