| /** Copyright 2015 TappingStone, Inc. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package io.prediction.tools |
| |
| import java.io.File |
| import java.net.URI |
| |
| import grizzled.slf4j.Logging |
| import io.prediction.data.storage.EngineManifest |
| import io.prediction.tools.console.ConsoleArgs |
| import io.prediction.workflow.WorkflowUtils |
| import org.apache.hadoop.conf.Configuration |
| import org.apache.hadoop.fs.FileSystem |
| import org.apache.hadoop.fs.Path |
| |
| import scala.sys.process._ |
| |
/** Launches `io.prediction.workflow.CreateWorkflow` on Spark for training or
  * evaluation runs, either by building a `spark-submit` command line
  * ([[runWorkflow]]) or by delegating to `Runner.runOnSpark`
  * ([[newRunWorkflow]]).
  */
object RunWorkflow extends Logging {
  /** Name of the Spark application class submitted by both entry points. */
  private val WorkflowClass = "io.prediction.workflow.CreateWorkflow"

  /** Builds a `spark-submit` command line for `CreateWorkflow`, logs it, runs
    * it as a child process and blocks until that process exits.
    *
    * When the pass-through arguments request `--deploy-mode cluster`, the
    * variant JSON is first copied to `PIO_FS_ENGINESDIR/<id>/<version>` on the
    * file system returned by `FileSystem.get` for a default `Configuration`,
    * and the qualified path of that copy is handed to the workflow instead of
    * the local path.
    *
    * @param ca parsed console arguments
    * @param core core assembly file; its canonical path is the main JAR when
    *             the build is not an uber JAR and the deploy mode is not
    *             cluster
    * @param em manifest of the engine to run
    * @param variantJson local engine variant JSON file
    * @return exit code of the `spark-submit` process
    * @throws NoSuchElementException if the manifest lists no file usable as
    *                                the main JAR, or if cluster deploy mode is
    *                                requested while `PIO_FS_ENGINESDIR` is not
    *                                set
    */
  def runWorkflow(
      ca: ConsoleArgs,
      core: File,
      em: EngineManifest,
      variantJson: File): Int = {
    // Collect and serialize PIO_* environmental variables
    val pioEnvVars = sys.env.collect {
      case (name, value) if name.startsWith("PIO_") => s"$name=$value"
    }.mkString(",")

    val sparkHome = ca.common.sparkHome.getOrElse(
      sys.env.getOrElse("SPARK_HOME", "."))

    val hadoopConf = new Configuration
    val hdfs = FileSystem.get(hadoopConf)

    val passThrough = ca.common.sparkPassThrough

    // A user-supplied driver class path is merged with the third-party
    // classpaths and emitted again, as one option, after the pass-through
    // arguments.
    val extraClasspaths =
      passThroughValue(passThrough, "--driver-class-path").toList ++
        WorkflowUtils.thirdPartyClasspaths

    val deployMode =
      passThroughValue(passThrough, "--deploy-mode").getOrElse("client")
    val clusterMode = deployMode == "cluster"

    val extraFiles = WorkflowUtils.thirdPartyConfFiles

    val mainJar =
      if (ca.build.uberJar) {
        if (clusterMode) {
          firstFile(em.files, "file starting with \"hdfs\"")(
            _.startsWith("hdfs"))
        } else {
          firstFile(em.files, "file not starting with \"hdfs\"")(
            !_.startsWith("hdfs"))
        }
      } else if (clusterMode) {
        firstFile(em.files, "file containing \"pio-assembly\"")(
          _.contains("pio-assembly"))
      } else {
        core.getCanonicalPath
      }

    val workMode =
      if (ca.common.evaluation.isDefined) "Evaluation" else "Training"

    val engineVariant =
      if (clusterMode) {
        // PIO_FS_ENGINESDIR is only looked up here, so client-mode runs do
        // not depend on it being set.
        val engineLocation = Seq(
          sys.env("PIO_FS_ENGINESDIR"),
          em.id,
          em.version)
        val dstPath = new Path(engineLocation.mkString(Path.SEPARATOR))
        info("Cluster deploy mode detected. Trying to copy " +
          s"${variantJson.getCanonicalPath} to " +
          s"${hdfs.makeQualified(dstPath).toString}.")
        hdfs.copyFromLocalFile(new Path(variantJson.toURI), dstPath)
        hdfs.makeQualified(new Path(
          (engineLocation :+ variantJson.getName).mkString(Path.SEPARATOR))).
          toString
      } else {
        variantJson.getCanonicalPath
      }

    val sparkSubmit =
      Seq(Seq(sparkHome, "bin", "spark-submit").mkString(File.separator)) ++
      passThrough ++
      Seq(
        "--class",
        WorkflowClass,
        "--name",
        s"PredictionIO $workMode: ${em.id} ${em.version} (${ca.common.batch})") ++
      argsIf(!ca.build.uberJar, "--jars", em.files.mkString(",")) ++
      argsIf(extraFiles.nonEmpty, "--files", extraFiles.mkString(",")) ++
      argsIf(
        extraClasspaths.nonEmpty,
        "--driver-class-path",
        extraClasspaths.mkString(":")) ++
      argsIf(
        ca.common.sparkKryo,
        "--conf",
        "spark.serializer=org.apache.spark.serializer.KryoSerializer") ++
      // The main JAR, followed by the arguments meant for CreateWorkflow.
      Seq(mainJar, "--env", pioEnvVars) ++
      engineArgs(ca, em, engineVariant) ++
      argsIf(clusterMode, "--deploy-mode", "cluster") ++
      workflowArgs(ca)

    info(s"Submission command: ${sparkSubmit.mkString(" ")}")
    Process(
      sparkSubmit,
      None,
      "CLASSPATH" -> "",
      "SPARK_YARN_USER_ENV" -> pioEnvVars).!
  }

  /** Runs `CreateWorkflow` through `Runner.runOnSpark`, passing the engine's
    * manifest files as JAR URIs and the variant JSON from the console
    * arguments as a URI string.
    *
    * @param ca parsed console arguments
    * @param em manifest of the engine to run
    * @return the value returned by `Runner.runOnSpark`
    */
  def newRunWorkflow(ca: ConsoleArgs, em: EngineManifest): Int = {
    val jarFiles = em.files.map(new URI(_))
    val args =
      engineArgs(ca, em, ca.common.variantJson.toURI.toString) ++
      workflowArgs(ca)

    Runner.runOnSpark(
      WorkflowClass,
      args,
      ca,
      jarFiles)
  }

  /** Returns `args` when `cond` holds and an empty sequence otherwise. */
  private def argsIf(cond: Boolean, args: String*): Seq[String] =
    if (cond) args else Seq.empty[String]

  /** Returns the token following the first occurrence of `option` in `args`,
    * or `None` when `option` is absent or is the last token.
    */
  private def passThroughValue(
      args: Seq[String],
      option: String): Option[String] = {
    val index = args.indexOf(option)
    if (index == -1) None else args.lift(index + 1)
  }

  /** Returns the first entry of `files` satisfying `accept`.
    *
    * @throws NoSuchElementException naming `description` when nothing matches
    */
  private def firstFile(
      files: Seq[String],
      description: String)(accept: String => Boolean): String =
    files.find(accept).getOrElse(
      throw new NoSuchElementException(
        s"Engine manifest does not list any $description " +
        s"(files: ${files.mkString(", ")})"))

  /** Arguments that identify the engine, its variant and the verbosity. */
  private def engineArgs(
      ca: ConsoleArgs,
      em: EngineManifest,
      engineVariant: String): Seq[String] =
    Seq(
      "--engine-id",
      em.id,
      "--engine-version",
      em.version,
      "--engine-variant",
      engineVariant,
      "--verbosity",
      ca.common.verbosity.toString) ++
    ca.common.engineFactory.map(
      x => Seq("--engine-factory", x)).getOrElse(Seq()) ++
    ca.common.engineParamsKey.map(
      x => Seq("--engine-params-key", x)).getOrElse(Seq())

  /** Workflow switches shared by both entry points. `--batch` is emitted at
    * most once.
    */
  private def workflowArgs(ca: ConsoleArgs): Seq[String] =
    argsIf(ca.common.batch != "", "--batch", ca.common.batch) ++
    argsIf(ca.common.verbose, "--verbose") ++
    argsIf(ca.common.skipSanityCheck, "--skip-sanity-check") ++
    argsIf(ca.common.stopAfterRead, "--stop-after-read") ++
    argsIf(ca.common.stopAfterPrepare, "--stop-after-prepare") ++
    ca.common.evaluation.map(x => Seq("--evaluation-class", x)).
      getOrElse(Seq()) ++
    // If engineParamsGenerator is specified, it overrides the evaluation.
    ca.common.engineParamsGenerator.orElse(ca.common.evaluation)
      .map(x => Seq("--engine-params-generator-class", x))
      .getOrElse(Seq()) ++
    Seq("--json-extractor", ca.common.jsonExtractor.toString)
}