/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.griffin.measure.execution
import java.util.Date
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future}
import scala.util.{Failure, Success}
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.griffin.measure.Loggable
import org.apache.griffin.measure.configuration.dqdefinition.MeasureParam
import org.apache.griffin.measure.configuration.enums.{MeasureTypes, OutputType}
import org.apache.griffin.measure.context.{ContextId, DQContext}
import org.apache.griffin.measure.execution.impl._
import org.apache.griffin.measure.step.write.{MetricFlushStep, MetricWriteStep, RecordWriteStep}
/**
* MeasureExecutor
*
* This acts as the starting point for the execution of different data quality measures
* defined by the users in `DQConfig`. Execution of the measures involves the following steps,
 *  - Create a fixed pool of threads which will be used to execute measures in parallel
* - For each measure defined per data source,
* - Caching data source(s) if necessary
* - In parallel do the following,
* - Create Measure entity (transformation step)
* - Write Metrics if required and if supported (metric write step)
* - Write Records if required and if supported (record write step)
* - Clear internal objects (metric flush step)
 *    - Un-caching data source(s) if they were cached.
*
* In contrast to the execution of `GriffinDslDQStepBuilder`, `MeasureExecutor` executes each of
* the defined measures independently. This means that the outputs (metrics and records) are written
* independently for each measure.
*
* @param context Instance of `DQContext`
*/
case class MeasureExecutor(context: DQContext) extends Loggable {
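  // A minimal usage sketch (hypothetical wiring; normally the Griffin application builds the
  // DQContext and parses the measures from the user's DQConfig):
  //   MeasureExecutor(dqContext).execute(measureParams) // measureParams: Seq[MeasureParam]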
/**
* SparkSession for this Griffin Application.
*/
private val sparkSession: SparkSession = context.sparkSession
/**
* Enable or disable caching of data sources before execution. Defaults to `true`.
*/
private val cacheDataSources: Boolean = sparkSession.sparkContext.getConf
.getBoolean("spark.griffin.measure.cacheDataSources", defaultValue = true)
/**
* Size of thread pool for parallel measure execution.
   * Defaults to the number of processors available to the Spark driver JVM.
*/
private val numThreads: Int = sparkSession.sparkContext.getConf
.getInt("spark.griffin.measure.parallelism", Runtime.getRuntime.availableProcessors())
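  // These can be tuned when submitting the Griffin job, for example (illustrative values):
  //   spark-submit \
  //     --conf spark.griffin.measure.cacheDataSources=false \
  //     --conf spark.griffin.measure.parallelism=4 \
  //     ...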
/**
* Service to handle threaded execution of tasks (measures).
*/
private implicit val ec: ExecutionContextExecutorService =
ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(numThreads))
/**
* Starting point of measure execution.
*
* @param measureParams Object representation(s) of user defined measure(s).
*/
def execute(measureParams: Seq[MeasureParam]): Unit = {
implicit val measureCountByDataSource: Map[String, Int] = measureParams
.map(_.getDataSource)
.groupBy(x => x)
.mapValues(_.length)
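    // e.g. two measures on data source "source_a" and one on "source_b" would yield
    // Map("source_a" -> 2, "source_b" -> 1) (illustrative data source names)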
measureParams
.groupBy(measureParam => measureParam.getDataSource)
      .foreach(measuresForSource => {
        val (dataSourceName, paramsForSource) = measuresForSource
        withCacheIfNecessary(dataSourceName, {
          val dataSource = sparkSession.read.table(dataSourceName)
          if (dataSource.isStreaming) {
            // TODO this is a no-op as the streaming query still needs to be registered (started).
            dataSource.writeStream
              .foreachBatch((_, batchId) => {
                executeMeasures(paramsForSource, Some(batchId))
              })
          } else {
            executeMeasures(paramsForSource)
          }
})
})
}
/**
   * Performs a function with a cached data source if necessary.
   * The data source is cached if more than one measure is defined for it and `cacheDataSources` is `true`.
   * After the function completes, the data source is uncached.
   *
   * @param dataSourceName name of the data source
   * @param f function to perform
   * @param measureCountByDataSource number of measures defined for each data source
*/
private def withCacheIfNecessary(dataSourceName: String, f: => Unit)(
implicit measureCountByDataSource: Map[String, Int]): Unit = {
    val numMeasures = measureCountByDataSource(dataSourceName)
    val shouldCache = cacheDataSources && numMeasures > 1
    if (shouldCache) {
      info(
        s"Caching data source with name '$dataSourceName' as $numMeasures measures are applied on it.")
      sparkSession.catalog.cacheTable(dataSourceName)
    }
    try f
    finally {
      // Uncache even if a measure fails, so the table is not left cached indefinitely.
      if (shouldCache) {
        sparkSession.catalog.uncacheTable(dataSourceName)
      }
    }
}
/**
   * Executes measures for a data source. This involves the following steps,
* - Transformation
* - Persist metrics if required
* - Persist records if required
*
* All measures are executed in parallel.
*
* @param measureParams Object representation(s) of user defined measure(s).
   * @param batchId Optional batch id to identify micro batches in case of streaming sources.
*/
private def executeMeasures(
measureParams: Seq[MeasureParam],
batchId: Option[Long] = None): Unit = {
    val batchDetails = batchId.map(bId => s"for batch id $bId").getOrElse(StringUtils.EMPTY)
    // define the tasks, one future per measure
    val tasks: Map[String, Future[_]] = measureParams.map { measureParam =>
      val measureName = measureParam.getName
      measureName -> Future {
        val currentContext = context.cloneDQContext(ContextId(new Date().getTime))
        info(s"Started execution of measure with name '$measureName'")
        val measure = createMeasure(measureParam)
        val (recordsDf, metricsDf) = measure.execute(batchId)
        persistRecords(currentContext, measure, recordsDf)
        persistMetrics(currentContext, measure, metricsDf)
        MetricFlushStep().execute(currentContext)
      }
    }.toMap
    tasks.foreach {
      case (measureName, task) =>
        task.onComplete {
          case Success(_) =>
            info(s"Successfully executed measure with name '$measureName' $batchDetails")
          case Failure(e) =>
            error(s"Error executing measure with name '$measureName' $batchDetails", e)
        }
    }
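    // Block until every measure future for this data source has completed,
    // logging the measures that are still running once per second.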
while (!tasks.forall(_._2.isCompleted)) {
info(
s"Measures with name ${tasks.filterNot(_._2.isCompleted).keys.mkString("['", "', '", "']")} " +
s"are still executing.")
Thread.sleep(1000)
}
info(
s"Completed execution of all measures for data source with name '${measureParams.head.getDataSource}'.")
}
/**
* Instantiates measure implementations based on the user defined configurations.
*
   * @param measureParam Object representation of a user defined measure.
   * @return the `Measure` implementation corresponding to the configured measure type
*/
private def createMeasure(measureParam: MeasureParam): Measure = {
measureParam.getType match {
case MeasureTypes.Completeness => CompletenessMeasure(sparkSession, measureParam)
case MeasureTypes.Duplication => DuplicationMeasure(sparkSession, measureParam)
case MeasureTypes.Profiling => ProfilingMeasure(sparkSession, measureParam)
case MeasureTypes.Accuracy => AccuracyMeasure(sparkSession, measureParam)
case MeasureTypes.SparkSQL => SparkSQLMeasure(sparkSession, measureParam)
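      // Any other measure type falls through to an error; new types can be supported by adding
      // an implementation in `org.apache.griffin.measure.execution.impl` and a case here.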
case _ =>
val errorMsg = s"Measure type '${measureParam.getType}' is not supported."
val exception = new NotImplementedError(errorMsg)
error(errorMsg, exception)
throw exception
}
}
/**
   * Persists records to one or more sinks based on the user defined measure configuration.
*
* @param context DQ Context.
* @param measure a measure implementation
* @param recordsDf records dataframe to persist.
*/
private def persistRecords(context: DQContext, measure: Measure, recordsDf: DataFrame): Unit = {
val measureParam: MeasureParam = measure.measureParam
measureParam.getOutputOpt(OutputType.RecordOutputType) match {
case Some(_) =>
if (measure.supportsRecordWrite) {
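          // Register the records as a temp view so the record write step can look them up by name.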
recordsDf.createOrReplaceTempView("recordsDf")
RecordWriteStep(measureParam.getName, "recordsDf").execute(context)
} else warn(s"Measure with name '${measureParam.getName}' doesn't support record write")
case None =>
}
}
/**
   * Persists metrics to one or more sinks based on the user defined measure configuration.
*
* @param context DQ Context.
* @param measure a measure implementation
* @param metricsDf metrics dataframe to persist
*/
private def persistMetrics(context: DQContext, measure: Measure, metricsDf: DataFrame): Unit = {
val measureParam: MeasureParam = measure.measureParam
measureParam.getOutputOpt(OutputType.MetricOutputType) match {
case Some(o) =>
if (measure.supportsMetricWrite) {
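          // Register the metrics as a temp view; the flatten mode comes from the user's output configuration.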
metricsDf.createOrReplaceTempView("metricsDf")
MetricWriteStep(measureParam.getName, "metricsDf", o.getFlatten)
.execute(context)
} else warn(s"Measure with name '${measureParam.getName}' doesn't support metric write")
case None =>
}
}
}