package org.apache.griffin.measure.context
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

import org.apache.griffin.measure.configuration.dqdefinition._
import org.apache.griffin.measure.configuration.enums._
import org.apache.griffin.measure.datasource._
import org.apache.griffin.measure.sink.{Sink, SinkFactory}
/**
 * DQ context: the context of a single calculation.
 * Each calculation gets a unique context id; all contexts access the same
 * Spark session that this application created.
 */
case class DQContext(contextId: ContextId, // unique per calculation; its timestamp selects the default sink
name: String, // passed to MetricWrapper and SinkFactory
dataSources: Seq[DataSource], // loaded eagerly while this context is constructed
sinkParams: Seq[SinkParam], // sink configurations handed to SinkFactory
procType: ProcessType // drives the default write mode and how sinks are created
)(@transient implicit val sparkSession: SparkSession) { // @transient: not part of the serialized state
// NOTE(review): the statements in this class body run in declaration order
// (table registration, data source loading, sink creation), so reordering
// members changes initialization behavior.
// Register of compile-time table names; data source names are added to it
// during construction.
val compileTableRegister: CompileTableRegister = CompileTableRegister()
// Register of run-time tables, built with the shared Spark session.
val runTimeTableRegister: RunTimeTableRegister = RunTimeTableRegister(sparkSession)
val dataFrameCache: DataFrameCache = DataFrameCache()
// Built from this context's name and the Spark application id.
val metricWrapper: MetricWrapper = MetricWrapper(name, sparkSession.sparkContext.applicationId)
// Default write mode derived from the process type.
val writeMode = WriteMode.defaultMode(procType)
/**
 * Names of all data sources, sorted so that the first data source flagged as
 * baseline comes first. All other names keep their declared order; a second
 * baseline, if any, is treated like a regular data source.
 */
val dataSourceNames: Seq[String] = {
  val (baselineOpt, others) =
    dataSources.foldLeft((None: Option[String], Nil: Seq[String])) { (ret, ds) =>
      val (opt, seq) = ret
      // Only the first baseline is pulled out; everything else is appended in order.
      if (opt.isEmpty && ds.isBaseline) (Some(ds.name), seq) else (opt, seq :+ ds.name)
    }
  baselineOpt.fold(others)(_ +: others)
}
// Constructor statement: every data source name becomes a compile-time table.
// The loop variable is named dsName so it does not shadow the `name` field.
for (dsName <- dataSourceNames) compileTableRegister.registerTable(dsName)
/**
 * Name of the data source at `index` in `dataSourceNames`; a baseline data
 * source, if present, is at index 0.
 *
 * @param index zero-based position in `dataSourceNames`
 * @return the name at that position, or "" when `index` is out of range
 *         (negative or not smaller than the number of data sources)
 */
def getDataSourceName(index: Int): String =
  // lift yields None for any out-of-range index, so negative indices return ""
  // just like too-large ones instead of throwing IndexOutOfBoundsException.
  dataSourceNames.lift(index).getOrElse("")
// String encoder placed in implicit scope, presumably for Spark Dataset[String]
// operations in this class (e.g. building functionNames) — TODO confirm.
// The type is declared explicitly: Scala 2.13 warns about inferred types on
// implicit definitions, and Scala 3 requires them.
implicit val encoder: Encoder[String] = Encoders.STRING
// Names of functions available to this context; presumably listed through the
// Spark session's catalog using the String encoder above — TODO confirm.
// NOTE(review): the right-hand side of this definition is missing in this copy
// of the file; restore it from upstream before compiling.
val functionNames: Seq[String] =
// One TimeRange per loaded data source, as returned by loadDataSources().
// NOTE(review): this runs during construction, so each data source receives
// `this` before later fields (sinkFactory, defaultSink) are initialized —
// confirm that DataSource.loadData never calls getSink.
val dataSourceTimeRanges: Map[String, TimeRange] = loadDataSources()
/**
 * Loads every data source against this context.
 *
 * @return the TimeRange returned by each data source's `loadData`, keyed by
 *         data source name; if two data sources share a name, the later one
 *         wins (standard `toMap` semantics)
 */
def loadDataSources(): Map[String, TimeRange] =
  dataSources.map(ds => ds.name -> ds.loadData(this)).toMap
// Factory for the configured sinks, built from sinkParams and this context's name.
private val sinkFactory = SinkFactory(sinkParams, name)
// Sink for this context's own timestamp, created once and returned by getSink().
// Must stay below sinkFactory: createSink reads sinkFactory while this field
// is being initialized.
private val defaultSink: Sink = createSink(contextId.timestamp)
/**
 * Sink for `timestamp`: the cached default sink when it equals this context's
 * timestamp, otherwise a newly created sink for that timestamp.
 */
def getSink(timestamp: Long): Sink =
  if (timestamp != contextId.timestamp) createSink(timestamp)
  else getSink()

/** The default sink, created once for this context's timestamp. */
def getSink(): Sink = {
  defaultSink
}
/**
 * Builds the sinks for timestamp `t` via the sink factory. The boolean flag
 * passed to the factory is true for batch and false for streaming processing.
 */
private def createSink(t: Long): Sink = {
  // A ProcessType not listed here fails with a MatchError.
  val isBatch = procType match {
    case BatchProcessType => true
    case StreamingProcessType => false
  }
  sinkFactory.getSinks(t, isBatch)
}
/**
 * Creates a context identical to this one except for its context id.
 *
 * Constructing the copy re-runs this class's initializers. The new context
 * therefore registers its data source names again, reloads the data sources,
 * and creates its own default sink.
 */
def cloneDQContext(newContextId: ContextId): DQContext =
  copy(contextId = newContextId)(sparkSession)
// NOTE(review): the body of clean() is missing in this copy of the file
// (nothing follows the opening brace). Presumably it unregisters the tables in
// compileTableRegister / runTimeTableRegister and releases dataFrameCache —
// restore it from upstream and confirm.
def clean(): Unit = {
/**
 * Prints every data source's loaded time range to stdout as
 * `name -> (begin, end]`; prints nothing when no time ranges were loaded.
 */
private def printTimeRanges(): Unit = {
  if (dataSourceTimeRanges.nonEmpty) {
    // dsName avoids shadowing this context's `name` field.
    val timeRangesStr = dataSourceTimeRanges.map { case (dsName, timeRange) =>
      s"${dsName} -> (${timeRange.begin}, ${timeRange.end}]"
    }.mkString(", ")
    println(s"data source timeRanges: ${timeRangesStr}")
  }
}