[GRIFFIN-358] Added parallelization to MeasureExecutor
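
This patch makes `MeasureExecutor` execute the measures defined on each data source in
parallel on a fixed-size thread pool. A data source is cached before execution when more
than one measure is defined on it and uncached afterwards, metrics and records are
written independently per measure against a cloned `DQContext`, and data source loading
moves out of the `DQContext` constructor into an explicit `loadDataSources()` call made
by the batch and streaming apps.

Two new Spark configuration keys control the behaviour:
 - `spark.griffin.measure.parallelism`: size of the thread pool, defaults to the number
   of processors available to the driver JVM
 - `spark.griffin.measure.cacheDataSources`: whether to cache data sources before
   execution, defaults to `true`

A minimal sketch of setting these knobs while building the SparkSession (only the two
config keys come from this patch; the app name and pool size are illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession
      .builder()
      .appName("griffin-measure-demo") // hypothetical app name
      .config("spark.griffin.measure.parallelism", "4") // run up to 4 measures at once
      .config("spark.griffin.measure.cacheDataSources", "true") // cache shared sources
      .getOrCreate()
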
diff --git a/measure/src/main/scala/org/apache/griffin/measure/Application.scala b/measure/src/main/scala/org/apache/griffin/measure/Application.scala
index 88bcc16..b57c9c5 100644
--- a/measure/src/main/scala/org/apache/griffin/measure/Application.scala
+++ b/measure/src/main/scala/org/apache/griffin/measure/Application.scala
@@ -119,6 +119,8 @@
if (!success) {
sys.exit(-5)
+ } else {
+ sys.exit(0)
}
}
diff --git a/measure/src/main/scala/org/apache/griffin/measure/context/DQContext.scala b/measure/src/main/scala/org/apache/griffin/measure/context/DQContext.scala
index fa1468c..c72f06c 100644
--- a/measure/src/main/scala/org/apache/griffin/measure/context/DQContext.scala
+++ b/measure/src/main/scala/org/apache/griffin/measure/context/DQContext.scala
@@ -66,15 +66,15 @@
implicit val encoder: Encoder[String] = Encoders.STRING
val functionNames: Seq[String] = sparkSession.catalog.listFunctions.map(_.name).collect.toSeq
- val dataSourceTimeRanges: Map[String, TimeRange] = loadDataSources()
+ var dataSourceTimeRanges: Map[String, TimeRange] = _
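+
+  /**
+   * Loads data for all data sources and captures their time ranges.
+   * The load is no longer triggered during `DQContext` construction; applications
+   * must call this method explicitly (`BatchDQApp` and `StreamingDQApp` now do so).
+   */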
def loadDataSources(): Map[String, TimeRange] = {
- dataSources.map { ds =>
+ dataSourceTimeRanges = dataSources.map { ds =>
(ds.name, ds.loadData(this))
}.toMap
- }
- printTimeRanges()
+ dataSourceTimeRanges
+ }
private val sinkFactory = SinkFactory(sinkParams, name)
private val defaultSinks: Seq[Sink] = createSinks(contextId.timestamp)
@@ -105,16 +105,4 @@
dataFrameCache.clearAllTrashDataFrames()
}
- private def printTimeRanges(): Unit = {
- if (dataSourceTimeRanges.nonEmpty) {
- val timeRangesStr = dataSourceTimeRanges
- .map { pair =>
- val (name, timeRange) = pair
- s"$name -> (${timeRange.begin}, ${timeRange.end}]"
- }
- .mkString(", ")
- println(s"data source timeRanges: $timeRangesStr")
- }
- }
-
}
diff --git a/measure/src/main/scala/org/apache/griffin/measure/execution/Measure.scala b/measure/src/main/scala/org/apache/griffin/measure/execution/Measure.scala
index 0454f38..a087d1b 100644
--- a/measure/src/main/scala/org/apache/griffin/measure/execution/Measure.scala
+++ b/measure/src/main/scala/org/apache/griffin/measure/execution/Measure.scala
@@ -27,23 +27,54 @@
import org.apache.griffin.measure.configuration.dqdefinition.MeasureParam
import org.apache.griffin.measure.utils.ParamUtil._
+/**
+ * Measure
+ *
+ * An abstraction for a data quality measure implementation.
+ */
trait Measure extends Loggable {
import Measure._
+ /**
+ * SparkSession for this Griffin Application.
+ */
val sparkSession: SparkSession
+ /**
+ * Object representation of user defined measure.
+ */
val measureParam: MeasureParam
+ /**
+   * Whether this measure supports record writing.
+ */
val supportsRecordWrite: Boolean
+ /**
+   * Whether this measure supports metric writing.
+ */
val supportsMetricWrite: Boolean
+ /**
+ * Metric values column.
+ */
final val valueColumn = s"${MeasureColPrefix}_${measureParam.getName}"
+ /**
+   * Helper method to get a typed value from the measure configuration for a given key.
+ *
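+   * Example (with a hypothetical key): `getFromConfig[Long]("row.limit", 100L)` returns
+   * the configured value for `row.limit`, or `100L` if the key is absent.
+   *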
+ * @param key given key for which the value needs to be fetched.
+ * @param defValue default value in case of no value.
+ * @tparam T type of value to get.
+ * @return value for given key
+ */
def getFromConfig[T: ClassTag](key: String, defValue: T): T = {
measureParam.getConfig.getAnyRef[T](key, defValue)
}
+ /**
+   * Enriches the metrics dataframe with additional keys.
+ */
// todo add status col to persist blank metrics if the measure fails
def preProcessMetrics(input: DataFrame): DataFrame = {
if (supportsMetricWrite) {
@@ -56,6 +87,9 @@
} else input
}
+ /**
+   * Enriches the records dataframe with a status column marking rows as good or bad based on their values.
+ */
def preProcessRecords(input: DataFrame): DataFrame = {
if (supportsRecordWrite) {
input
@@ -64,38 +98,59 @@
} else input
}
+ /**
+ * Implementation of this measure.
+ *
+ * @return tuple of records dataframe and metric dataframe
+ */
def impl(): (DataFrame, DataFrame)
+ /**
+   * Implementations should define validation checks in this method (if required).
+   * This method must be called explicitly (preferably during measure creation).
+   *
+   * Defaults to a no-op.
+ */
def validate(): Unit = {}
+ /**
+   * Executes this measure's transformation on the input data source.
+   *
+   * @param batchId batch id to append in case of a streaming source.
+ * @return enriched tuple of records dataframe and metric dataframe
+ */
def execute(batchId: Option[Long] = None): (DataFrame, DataFrame) = {
val (recordsDf, metricDf) = impl()
val processedRecordDf = preProcessRecords(recordsDf)
val processedMetricDf = preProcessMetrics(metricDf)
- var batchDetailsOpt = StringUtils.EMPTY
val res = batchId match {
case Some(batchId) =>
implicit val bId: Long = batchId
- batchDetailsOpt = s"for batch id $bId"
(appendBatchIdIfAvailable(processedRecordDf), appendBatchIdIfAvailable(processedMetricDf))
case None => (processedRecordDf, processedMetricDf)
}
- info(
- s"Execution of '${measureParam.getType}' measure " +
- s"with name '${measureParam.getName}' is complete $batchDetailsOpt")
-
res
}
+ /**
+   * Appends the batch id to the given dataframe in case of streaming sources.
+   *
+   * @param input input dataframe
+   * @param batchId batch id to append
+   * @return updated dataframe
+ */
private def appendBatchIdIfAvailable(input: DataFrame)(implicit batchId: Long): DataFrame = {
input.withColumn(BatchId, typedLit[Long](batchId))
}
}
+/**
+ * Measure Constants.
+ */
object Measure {
final val DataSource = "data_source"
diff --git a/measure/src/main/scala/org/apache/griffin/measure/execution/MeasureExecutor.scala b/measure/src/main/scala/org/apache/griffin/measure/execution/MeasureExecutor.scala
index 1b75f2f..1f49328 100644
--- a/measure/src/main/scala/org/apache/griffin/measure/execution/MeasureExecutor.scala
+++ b/measure/src/main/scala/org/apache/griffin/measure/execution/MeasureExecutor.scala
@@ -17,19 +17,79 @@
package org.apache.griffin.measure.execution
-import org.apache.spark.sql.DataFrame
+import java.util.Date
+import java.util.concurrent.Executors
+
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future}
+import scala.util.{Failure, Success}
+
+import org.apache.commons.lang3.StringUtils
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.griffin.measure.Loggable
import org.apache.griffin.measure.configuration.dqdefinition.MeasureParam
import org.apache.griffin.measure.configuration.enums.{MeasureTypes, OutputType}
-import org.apache.griffin.measure.context.DQContext
+import org.apache.griffin.measure.context.{ContextId, DQContext}
import org.apache.griffin.measure.execution.impl._
import org.apache.griffin.measure.step.write.{MetricFlushStep, MetricWriteStep, RecordWriteStep}
+/**
+ * MeasureExecutor
+ *
+ * This acts as the starting point for the execution of different data quality measures
+ * defined by the users in `DQConfig`. Execution of the measures involves the following steps,
+ *  - Create a fixed pool of threads which will be used to execute measures in parallel
+ *  - For each data source with measures defined on it,
+ *    - Cache the data source if necessary
+ *    - In parallel, for each of its measures,
+ *      - Create the measure entity (transformation step)
+ *      - Write metrics if required and supported (metric write step)
+ *      - Write records if required and supported (record write step)
+ *      - Clear internal objects (metric flush step)
+ *    - Uncache the data source if it was cached.
+ *
+ * In contrast to the execution of `GriffinDslDQStepBuilder`, `MeasureExecutor` executes each of
+ * the defined measures independently. This means that the outputs (metrics and records) are written
+ * independently for each measure.
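+ *
+ * Usage: `MeasureExecutor(context).execute(measureParams)`, where `measureParams` are
+ * the measure definitions parsed from the user's `DQConfig`.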
+ *
+ * @param context Instance of `DQContext`
+ */
case class MeasureExecutor(context: DQContext) extends Loggable {
+ /**
+ * SparkSession for this Griffin Application.
+ */
+ private val sparkSession: SparkSession = context.sparkSession
+
+ /**
+ * Enable or disable caching of data sources before execution. Defaults to `true`.
+ */
+ private val cacheDataSources: Boolean = sparkSession.sparkContext.getConf
+ .getBoolean("spark.griffin.measure.cacheDataSources", defaultValue = true)
+
+ /**
+ * Size of thread pool for parallel measure execution.
+ * Defaults to number of processors available to the spark driver JVM.
+ */
+ private val numThreads: Int = sparkSession.sparkContext.getConf
+ .getInt("spark.griffin.measure.parallelism", Runtime.getRuntime.availableProcessors())
+
+ /**
+ * Service to handle threaded execution of tasks (measures).
+ */
+ private implicit val ec: ExecutionContextExecutorService =
+ ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(numThreads))
+
+ /**
+ * Starting point of measure execution.
+ *
+ * @param measureParams Object representation(s) of user defined measure(s).
+ */
def execute(measureParams: Seq[MeasureParam]): Unit = {
- cacheIfNecessary(measureParams)
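+    // Count how many measures are defined per data source; used by
+    // `withCacheIfNecessary` below to decide which sources are worth caching.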
+ implicit val measureCountByDataSource: Map[String, Int] = measureParams
+ .map(_.getDataSource)
+ .groupBy(x => x)
+ .mapValues(_.length)
measureParams
.groupBy(measureParam => measureParam.getDataSource)
@@ -37,50 +97,114 @@
val dataSourceName = measuresForSource._1
val measureParams = measuresForSource._2
- val dataSource = context.sparkSession.read.table(dataSourceName)
+ withCacheIfNecessary(dataSourceName, {
+ val dataSource = sparkSession.read.table(dataSourceName)
- if (dataSource.isStreaming) {
- // todo this is a no op as streaming queries need to be registered.
- dataSource.writeStream
- .foreachBatch((_, batchId) => {
- executeMeasures(measureParams, Some(batchId))
- })
- } else {
- executeMeasures(measureParams)
- }
+ if (dataSource.isStreaming) {
+ // TODO this is a no op as streaming queries need to be registered.
+
+ dataSource.writeStream
+ .foreachBatch((_, batchId) => {
+ executeMeasures(measureParams, Some(batchId))
+ })
+ } else {
+ executeMeasures(measureParams)
+ }
+ })
})
}
- private def cacheIfNecessary(measureParams: Seq[MeasureParam]): Unit = {
- measureParams
- .map(_.getDataSource)
- .groupBy(x => x)
- .mapValues(_.length)
- .filter(_._2 > 1)
- .foreach(source => {
- info(
- s"Caching data source with name '${source._1}'" +
- s" as ${source._2} measures are applied on it.")
- context.sparkSession.catalog.cacheTable(source._1)
- })
+ /**
+   * Performs a function with the data source cached if necessary.
+   * The data source is cached if more than one measure is defined on it and `cacheDataSources` is `true`.
+   * After the function completes, the data source is uncached (if it was cached).
+   *
+   * @param dataSourceName name of the data source
+   * @param f function to perform
+   * @param measureCountByDataSource number of measures for each data source
+   * @tparam T return type of function `f`
+   * @return result of function `f`
+ */
+ private def withCacheIfNecessary[T](dataSourceName: String, f: => T)(
+ implicit measureCountByDataSource: Map[String, Int]): T = {
+ val numMeasures = measureCountByDataSource(dataSourceName)
+ var isCached = false
+ if (cacheDataSources && numMeasures > 1) {
+ info(
+ s"Caching data source with name '$dataSourceName' as $numMeasures measures are applied on it.")
+ sparkSession.catalog.cacheTable(dataSourceName)
+ isCached = true
+ }
+
+    // Return the result of `f`, uncaching the data source even if `f` fails.
+    try f
+    finally {
+      if (isCached) {
+        sparkSession.catalog.uncacheTable(dataSourceName)
+      }
+    }
}
+ /**
+   * Executes the measures defined on a data source. Involves the following steps,
+ * - Transformation
+ * - Persist metrics if required
+ * - Persist records if required
+ *
+ * All measures are executed in parallel.
+ *
+ * @param measureParams Object representation(s) of user defined measure(s).
+   * @param batchId optional batch id to identify micro batches in case of streaming sources.
+ */
private def executeMeasures(
measureParams: Seq[MeasureParam],
batchId: Option[Long] = None): Unit = {
- measureParams.foreach(measureParam => {
- val measure = createMeasure(measureParam)
- val (recordsDf, metricsDf) = measure.execute(batchId)
+ val batchDetailsOpt = batchId.map(bId => s"for batch id $bId").getOrElse(StringUtils.EMPTY)
- persistRecords(measure, recordsDf)
- persistMetrics(measure, metricsDf)
+ // define the tasks
+ val tasks: Map[String, Future[_]] = (for (i <- measureParams.indices)
+ yield {
+ val measureParam = measureParams(i)
+ val measureName = measureParam.getName
- MetricFlushStep().execute(context)
+ (measureName, Future {
+ val currentContext = context.cloneDQContext(ContextId(new Date().getTime))
+ info(s"Started execution of measure with name '$measureName'")
+
+ val measure = createMeasure(measureParam)
+ val (recordsDf, metricsDf) = measure.execute(batchId)
+
+ persistRecords(currentContext, measure, recordsDf)
+ persistMetrics(currentContext, measure, metricsDf)
+
+ MetricFlushStep().execute(currentContext)
+ })
+ }).toMap
+
+ tasks.foreach(task =>
+ task._2.onComplete {
+ case Success(_) =>
+ info(s"Successfully executed measure with name '${task._1}' $batchDetailsOpt")
+ case Failure(e) =>
+ error(s"Error executing measure with name '${task._1}' $batchDetailsOpt", e)
})
+
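+    // Block this thread until every measure task has completed, logging the
+    // still-running measures once per second.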
+ while (!tasks.forall(_._2.isCompleted)) {
+ info(
+ s"Measures with name ${tasks.filterNot(_._2.isCompleted).keys.mkString("['", "', '", "']")} " +
+ s"are still executing.")
+ Thread.sleep(1000)
+ }
+
+ info(
+ s"Completed execution of all measures for data source with name '${measureParams.head.getDataSource}'.")
}
+ /**
+ * Instantiates measure implementations based on the user defined configurations.
+ *
+   * @param measureParam Object representation of a user defined measure.
+   * @return measure implementation instance
+ */
private def createMeasure(measureParam: MeasureParam): Measure = {
- val sparkSession = context.sparkSession
measureParam.getType match {
case MeasureTypes.Completeness => CompletenessMeasure(sparkSession, measureParam)
case MeasureTypes.Duplication => DuplicationMeasure(sparkSession, measureParam)
@@ -95,7 +219,14 @@
}
}
- private def persistRecords(measure: Measure, recordsDf: DataFrame): Unit = {
+ /**
+   * Persists records to one or more sinks based on the user defined measure configuration.
+ *
+ * @param context DQ Context.
+ * @param measure a measure implementation
+ * @param recordsDf records dataframe to persist.
+ */
+ private def persistRecords(context: DQContext, measure: Measure, recordsDf: DataFrame): Unit = {
val measureParam: MeasureParam = measure.measureParam
measureParam.getOutputOpt(OutputType.RecordOutputType) match {
@@ -108,7 +239,14 @@
}
}
- private def persistMetrics(measure: Measure, metricsDf: DataFrame): Unit = {
+ /**
+   * Persists metrics to one or more sinks based on the user defined measure configuration.
+ *
+ * @param context DQ Context.
+ * @param measure a measure implementation
+ * @param metricsDf metrics dataframe to persist
+ */
+ private def persistMetrics(context: DQContext, measure: Measure, metricsDf: DataFrame): Unit = {
val measureParam: MeasureParam = measure.measureParam
measureParam.getOutputOpt(OutputType.MetricOutputType) match {
diff --git a/measure/src/main/scala/org/apache/griffin/measure/launch/batch/BatchDQApp.scala b/measure/src/main/scala/org/apache/griffin/measure/launch/batch/BatchDQApp.scala
index 9c03f98..7e97dc7 100644
--- a/measure/src/main/scala/org/apache/griffin/measure/launch/batch/BatchDQApp.scala
+++ b/measure/src/main/scala/org/apache/griffin/measure/launch/batch/BatchDQApp.scala
@@ -77,6 +77,8 @@
dqContext = DQContext(contextId, metricName, dataSources, sinkParams, BatchProcessType)(
sparkSession)
+ dqContext.loadDataSources()
+
// start id
val applicationId = sparkSession.sparkContext.applicationId
dqContext.getSinks.foreach(_.open(applicationId))
diff --git a/measure/src/main/scala/org/apache/griffin/measure/launch/streaming/StreamingDQApp.scala b/measure/src/main/scala/org/apache/griffin/measure/launch/streaming/StreamingDQApp.scala
index 09554f3..c83acb3 100644
--- a/measure/src/main/scala/org/apache/griffin/measure/launch/streaming/StreamingDQApp.scala
+++ b/measure/src/main/scala/org/apache/griffin/measure/launch/streaming/StreamingDQApp.scala
@@ -98,6 +98,7 @@
val globalContext: DQContext =
DQContext(contextId, metricName, dataSources, sinkParams, StreamingProcessType)(
sparkSession)
+ globalContext.loadDataSources()
// start id
val applicationId = sparkSession.sparkContext.applicationId