/** Copyright 2015 TappingStone, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.prediction.workflow

import com.github.nscala_time.time.Imports._
import com.google.common.io.ByteStreams
import grizzled.slf4j.Logging
import io.prediction.controller.Engine
import io.prediction.controller.Utils
import io.prediction.controller.Workflow
import io.prediction.controller.WorkflowParams
import io.prediction.core.BaseEngine
import io.prediction.data.storage.EngineInstance
import io.prediction.data.storage.EvaluationInstance
import io.prediction.data.storage.Storage
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization.write

import scala.language.existentials

object CreateWorkflow extends Logging {
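
  // Command-line options for a workflow run, populated by the scopt parser
  // below.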
  case class WorkflowConfig(
    deployMode: String = "",
    batch: String = "",
    engineId: String = "",
    engineVersion: String = "",
    engineVariant: String = "",
    engineFactory: String = "",
    engineParamsKey: String = "",
    evaluationClass: Option[String] = None,
    engineParamsGeneratorClass: Option[String] = None,
    env: Option[String] = None,
    skipSanityCheck: Boolean = false,
    stopAfterRead: Boolean = false,
    stopAfterPrepare: Boolean = false,
    verbosity: Int = 0,
    verbose: Boolean = false,
    debug: Boolean = false,
    logFile: Option[String] = None)
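
  // An algorithm's name paired with its parameters as raw JSON; handled by
  // the NameParamsSerializer registered in `formats` below.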
  case class AlgorithmParams(name: String, params: JValue)

  implicit lazy val formats = Utils.json4sDefaultFormats +
    new NameParamsSerializer

  val hadoopConf = new Configuration
  val hdfs = FileSystem.get(hadoopConf)
  val localfs = FileSystem.getLocal(hadoopConf)
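
  // Read a file from the given file system (HDFS by default) and return its
  // contents as a string, aborting the workflow on any I/O error.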
  private def stringFromFile(
      basePath: String,
      filePath: String,
      fs: FileSystem = hdfs): String = {
    try {
      val p =
        if (basePath == "") {
          new Path(filePath)
        } else {
          new Path(basePath + Path.SEPARATOR + filePath)
        }
      new String(ByteStreams.toByteArray(fs.open(p)).map(_.toChar))
    } catch {
      case e: java.io.IOException =>
        error(s"Error reading from file: ${e.getMessage}. Aborting workflow.")
        sys.exit(1)
    }
  }
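
  // Maps command-line arguments onto WorkflowConfig.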
  val parser = new scopt.OptionParser[WorkflowConfig]("CreateWorkflow") {
    opt[String]("batch") action { (x, c) =>
      c.copy(batch = x)
    } text("Batch label of the workflow run.")
    opt[String]("engine-id") required() action { (x, c) =>
      c.copy(engineId = x)
    } text("Engine's ID.")
    opt[String]("engine-version") required() action { (x, c) =>
      c.copy(engineVersion = x)
    } text("Engine's version.")
    opt[String]("engine-variant") required() action { (x, c) =>
      c.copy(engineVariant = x)
    } text("Engine variant JSON.")
    opt[String]("evaluation-class") action { (x, c) =>
      c.copy(evaluationClass = Some(x))
    } text("Class name of the run's evaluator.")
opt[String]("engine-params-generator-class") action { (x, c) =>
c.copy(engineParamsGeneratorClass = Some(x))
} text("Path to evaluator parameters")
opt[String]("env") action { (x, c) =>
c.copy(env = Some(x))
} text("Comma-separated list of environmental variables (in 'FOO=BAR' " +
"format) to pass to the Spark execution environment.")
opt[Unit]("verbose") action { (x, c) =>
c.copy(verbose = true)
} text("Enable verbose output.")
opt[Unit]("debug") action { (x, c) =>
c.copy(debug = true)
} text("Enable debug output.")
opt[Unit]("skip-sanity-check") action { (x, c) =>
c.copy(skipSanityCheck = true)
}
opt[Unit]("stop-after-read") action { (x, c) =>
c.copy(stopAfterRead = true)
}
opt[Unit]("stop-after-prepare") action { (x, c) =>
c.copy(stopAfterPrepare = true)
}
opt[String]("deploy-mode") action { (x, c) =>
c.copy(deployMode = x)
}
opt[Int]("verbosity") action { (x, c) =>
c.copy(verbosity = x)
}
opt[String]("engine-factory") action { (x, c) =>
c.copy(engineFactory = x)
}
opt[String]("engine-params-key") action { (x, c) =>
c.copy(engineParamsKey = x)
}
opt[String]("log-file") action { (x, c) =>
c.copy(logFile = Some(x))
}
}

  def main(args: Array[String]): Unit = {
    val wfcOpt = parser.parse(args, WorkflowConfig())
    if (wfcOpt.isEmpty) {
      logger.error("WorkflowConfig is empty. Quitting")
      return
    }
    val wfc = wfcOpt.get
    WorkflowUtils.modifyLogging(wfc.verbose)
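
    // In cluster deploy mode the engine variant JSON lives on HDFS; otherwise
    // it is read from the local file system.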
    val targetfs = if (wfc.deployMode == "cluster") hdfs else localfs
    val variantJson = parse(stringFromFile("", wfc.engineVariant, targetfs))
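
    // Take the engine factory class name from the command line if given,
    // falling back to the "engineFactory" field of the variant JSON.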
    val engineFactory = if (wfc.engineFactory == "") {
      variantJson \ "engineFactory" match {
        case JString(s) => s
        case _ =>
          error("Unable to read engine factory class name from " +
            s"${wfc.engineVariant}. Aborting.")
          sys.exit(1)
      }
    } else wfc.engineFactory

    val variantId = variantJson \ "id" match {
      case JString(s) => s
      case _ =>
        error("Unable to read engine variant ID from " +
          s"${wfc.engineVariant}. Aborting.")
        sys.exit(1)
    }
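
    // Instantiate the engine factory via the class loader, aborting if the
    // class or its factory method cannot be found.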
    val (engineLanguage, engineFactoryObj) = try {
      WorkflowUtils.getEngine(engineFactory, getClass.getClassLoader)
    } catch {
      case e @ (_: ClassNotFoundException | _: NoSuchMethodException) =>
        error(s"Unable to obtain engine: ${e.getMessage}. Aborting workflow.")
        sys.exit(1)
    }
    val engine: BaseEngine[_, _, _, _] = engineFactoryObj()
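
    // Optionally load the evaluation and engine parameters generator objects,
    // if their class names were supplied on the command line.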
    val evaluation = wfc.evaluationClass.map { ec =>
      try {
        WorkflowUtils.getEvaluation(ec, getClass.getClassLoader)._2
      } catch {
        case e @ (_: ClassNotFoundException | _: NoSuchMethodException) =>
          error(s"Unable to obtain evaluation ${ec}. Aborting workflow.", e)
          sys.exit(1)
      }
    }
    val engineParamsGenerator = wfc.engineParamsGeneratorClass.map { epg =>
      try {
        WorkflowUtils.getEngineParamsGenerator(epg, getClass.getClassLoader)._2
      } catch {
        case e @ (_: ClassNotFoundException | _: NoSuchMethodException) =>
          error(s"Unable to obtain engine parameters generator ${epg}. " +
            "Aborting workflow.", e)
          sys.exit(1)
      }
    }
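
    // Parse the --env value ("FOO=BAR,BAZ=QUX") into a map; malformed entries
    // are dropped.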
    val pioEnvVars = wfc.env.map(e =>
      e.split(',').flatMap(p =>
        p.split('=') match {
          case Array(k, v) => List(k -> v)
          case _ => Nil
        }
      ).toMap
    ).getOrElse(Map())
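
    // Merge any custom Spark configuration from the variant JSON into the
    // default Spark environment.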
    val customSparkConf = WorkflowUtils.extractSparkConf(variantJson)
    val workflowParams = WorkflowParams(
      verbose = wfc.verbosity,
      skipSanityCheck = wfc.skipSanityCheck,
      stopAfterRead = wfc.stopAfterRead,
      stopAfterPrepare = wfc.stopAfterPrepare,
      sparkEnv = WorkflowParams().sparkEnv ++ customSparkConf)

    if (evaluation.isEmpty) {
      // Evaluator not specified; do training.
      if (!engine.isInstanceOf[Engine[_, _, _, _, _, _]]) {
        throw new NoSuchMethodException(s"Engine $engine is not trainable")
      }
      val trainableEngine = engine.asInstanceOf[Engine[_, _, _, _, _, _]]
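
      // Engine parameters come from the variant JSON unless a named parameter
      // set is requested from the engine factory via --engine-params-key.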
      val engineParams = if (wfc.engineParamsKey == "") {
        trainableEngine.jValueToEngineParams(variantJson)
      } else {
        engineFactoryObj.engineParams(wfc.engineParamsKey)
      }
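
      // Record this run in metadata storage before kicking off training.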
      val engineInstance = EngineInstance(
        id = "",
        status = "INIT",
        startTime = DateTime.now,
        endTime = DateTime.now,
        engineId = wfc.engineId,
        engineVersion = wfc.engineVersion,
        engineVariant = variantId,
        engineFactory = engineFactory,
        evaluatorClass = wfc.evaluationClass.getOrElse(""),
        batch = (if (wfc.batch == "") engineFactory else wfc.batch),
        env = pioEnvVars,
        sparkConf = workflowParams.sparkEnv,
        dataSourceParams = write(engineParams.dataSourceParams),
        preparatorParams = write(engineParams.preparatorParams),
        algorithmsParams = write(engineParams.algorithmParamsList),
        servingParams = write(engineParams.servingParams),
        evaluatorParams = "",
        evaluatorResults = "",
        evaluatorResultsHTML = "",
        evaluatorResultsJSON = "")
      val engineInstanceId = Storage.getMetaDataEngineInstances.insert(
        engineInstance)
      CoreWorkflow.runTrain(
        env = pioEnvVars,
        params = workflowParams,
        engine = trainableEngine,
        engineParams = engineParams,
        engineInstance = engineInstance.copy(id = engineInstanceId))
    } else {
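      // An evaluation class was supplied: run evaluation instead of training.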
      val evaluationInstance = EvaluationInstance(
        evaluationClass = wfc.evaluationClass.get,
        engineParamsGeneratorClass = wfc.engineParamsGeneratorClass.get,
        batch = wfc.batch,
        env = pioEnvVars,
        sparkConf = workflowParams.sparkEnv)
      Workflow.runEvaluation(
        evaluation = evaluation.get,
        engineParamsGenerator = engineParamsGenerator.get,
        evaluationInstance = evaluationInstance,
        params = workflowParams)
    }
  }
}