blob: ab771e5ee8f1287d4bf679d36328fd2246f72e66 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.griffin.measure.execution
import org.apache.spark.sql.DataFrame
import org.apache.griffin.measure.Loggable
import org.apache.griffin.measure.configuration.dqdefinition.MeasureParam
import org.apache.griffin.measure.configuration.enums.{MeasureTypes, OutputType}
import org.apache.griffin.measure.context.DQContext
import org.apache.griffin.measure.execution.impl._
import org.apache.griffin.measure.step.write.{MetricFlushStep, MetricWriteStep, RecordWriteStep}
/**
 * Executes Griffin data-quality measures against their configured data sources.
 *
 * Measures are grouped by data source so that each source table is read once
 * per group; any source referenced by more than one measure is cached up front
 * to avoid repeated scans.
 *
 * @param context the data-quality context carrying the active SparkSession
 */
case class MeasureExecutor(context: DQContext) extends Loggable {

  /**
   * Runs all given measures, grouped by their data source.
   *
   * NOTE(review): for a streaming source the writer built below is never
   * started, so that branch is effectively a no-op — see the inline todo.
   *
   * @param measureParams definitions of the measures to execute
   */
  def execute(measureParams: Seq[MeasureParam]): Unit = {
    cacheIfNecessary(measureParams)

    measureParams
      .groupBy(_.getDataSource)
      .foreach {
        case (sourceName, paramsForSource) =>
          val sourceDf = context.sparkSession.read.table(sourceName)

          if (sourceDf.isStreaming) {
            // todo this is a no op as streaming queries need to be registered.
            sourceDf.writeStream
              .foreachBatch((_, batchId) => {
                executeMeasures(paramsForSource, Some(batchId))
              })
          } else {
            executeMeasures(paramsForSource)
          }
      }
  }

  /**
   * Caches every data source used by more than one measure, so its table is
   * scanned once instead of once per measure.
   *
   * @param measureParams the full set of measures about to be executed
   */
  private def cacheIfNecessary(measureParams: Seq[MeasureParam]): Unit = {
    val multiUseSources = measureParams
      .map(_.getDataSource)
      .groupBy(identity)
      .collect {
        case (source, occurrences) if occurrences.length > 1 =>
          (source, occurrences.length)
      }

    multiUseSources.foreach {
      case (source, usageCount) =>
        info(
          s"Caching data source with name '$source'" +
            s" as $usageCount measures are applied on it.")
        context.sparkSession.catalog.cacheTable(source)
    }
  }

  /**
   * Executes each measure sequentially: instantiates it, runs it, persists its
   * record and metric outputs, then flushes the accumulated metrics.
   *
   * @param measureParams measures targeting a single data source
   * @param batchId       micro-batch id when invoked from a streaming query
   */
  private def executeMeasures(
      measureParams: Seq[MeasureParam],
      batchId: Option[Long] = None): Unit = {
    for (param <- measureParams) {
      val measure = createMeasure(param)
      val (recordsDf, metricsDf) = measure.execute(context, batchId)

      persistRecords(measure, recordsDf)
      persistMetrics(measure, metricsDf)
      MetricFlushStep().execute(context)
    }
  }

  /**
   * Instantiates the concrete measure implementation for the given parameter.
   *
   * @throws NotImplementedError if the measure type has no implementation
   */
  private def createMeasure(measureParam: MeasureParam): Measure = {
    measureParam.getType match {
      case MeasureTypes.Completeness => CompletenessMeasure(measureParam)
      case MeasureTypes.Duplication => DuplicationMeasure(measureParam)
      case MeasureTypes.Profiling => ProfilingMeasure(measureParam)
      case MeasureTypes.Accuracy => AccuracyMeasure(measureParam)
      case MeasureTypes.SparkSQL => SparkSQLMeasure(measureParam)
      case unsupportedType =>
        val errorMsg = s"Measure type '$unsupportedType' is not supported."
        val exception = new NotImplementedError(errorMsg)
        error(errorMsg, exception)
        throw exception
    }
  }

  /**
   * Writes the measure's record output, if record output is configured.
   * Measures that do not support record writes only log a warning.
   */
  private def persistRecords(measure: Measure, recordsDf: DataFrame): Unit = {
    val measureParam: MeasureParam = measure.measureParam

    measureParam.getOutputOpt(OutputType.RecordOutputType).foreach { _ =>
      if (measure.supportsRecordWrite) {
        recordsDf.createOrReplaceTempView("recordsDf")
        RecordWriteStep(measureParam.getName, "recordsDf").execute(context)
      } else warn(s"Measure with name '${measureParam.getName}' doesn't support record write")
    }
  }

  /**
   * Writes the measure's metric output, if metric output is configured.
   * Measures that do not support metric writes only log a warning.
   */
  private def persistMetrics(measure: Measure, metricsDf: DataFrame): Unit = {
    val measureParam: MeasureParam = measure.measureParam

    measureParam.getOutputOpt(OutputType.MetricOutputType).foreach { output =>
      if (measure.supportsMetricWrite) {
        metricsDf.createOrReplaceTempView("metricsDf")
        MetricWriteStep(measureParam.getName, "metricsDf", output.getFlatten)
          .execute(context)
      } else warn(s"Measure with name '${measureParam.getName}' doesn't support metric write")
    }
  }
}