blob: e5d07c466cc776fdc0c08bc134d39389581842b4 [file] [log] [blame]
package io.prediction.evaluations.commons.trainingtestsplit
import io.prediction.commons.filepath.U2ITrainingTestSplitFile
import java.io.File
import scala.io.Source
import scala.sys.process._
import grizzled.slf4j.Logger
/**
 * Runtime configuration for the [[U2ITrainingTestSplitTime]] wrapper,
 * populated by the scopt command-line parser in `main`.
 *
 * @param hadoop path to the `hadoop` command
 * @param pdioEvalJar path to the PredictionIO Hadoop job JAR
 * @param hdfsRoot PredictionIO root path in HDFS
 * @param localTempRoot local directory for temporary storage
 * @param appid the App ID of this offline evaluation
 * @param engineid the Engine ID of this offline evaluation
 * @param evalid the OfflineEval ID of this offline evaluation
 * @param sequenceNum sequence (iteration) number of the offline evaluation;
 *                    validated to be >= 1 by the parser (1 = first split,
 *                    > 1 = re-split)
 */
case class U2ITrainingTestSplitTimeConfig(
  hadoop: String = "",
  pdioEvalJar: String = "",
  hdfsRoot: String = "",
  localTempRoot: String = "",
  appid: Int = 0,
  engineid: Int = 0,
  evalid: Int = 0,
  sequenceNum: Int = 0)
/**
* Wrapper for Scalding U2ITrainingTestSplitTime job
*
* Args:
* --hadoop <string> hadoop command
* --pdioEvalJar <string> the name of the Scalding U2ITrainingTestSplit job jar
* --sequenceNum: <int> the sequence number (starts from 1 for the 1st iteration and is incremented for later iterations)
*
* --dbType: <string> appdata DB type
* --dbName: <string>
* --dbHost: <string>. optional. (eg. "localhost")
* --dbPort: <int>. optional. (eg. 27017)
*
* --training_dbType: <string> training_appdata DB type
* --training_dbName: <string>
* --training_dbHost: <string>. optional
* --training_dbPort: <int>. optional
*
* --validation_dbType: <string> validation_appdata DB type
* --validation_dbName: <string>
* --validation_dbHost: <string>. optional
* --validation_dbPort: <int>. optional
*
* --test_dbType: <string> test_appdata DB type
* --test_dbName: <string>
* --test_dbHost: <string>. optional
* --test_dbPort: <int>. optional
*
* --hdfsRoot: <string>. Root directory of the HDFS
*
* --appid: <int>
* --engineid: <int>
* --evalid: <int>
*
* --itypes: <string separated by white space>. eg "--itypes type1 type2". If no --itypes specified, then ALL itypes will be used.
*
* --trainingPercent: <double> (0.01 to 1). training set percentage
* --validationPercent: <double> (0.01 to 1). validation set percentage
* --testPercent: <double> (0.01 to 1). test set percentage
* --timeorder: <boolean>. Require total percentage < 1
*
*/
object U2ITrainingTestSplitTime {
  // Object-level logger so that helper methods outside main can use it.
  private lazy val logger = Logger(U2ITrainingTestSplitTime.getClass)

  def main(args: Array[String]): Unit = {
    val parser = new scopt.OptionParser[U2ITrainingTestSplitTimeConfig]("u2itrainingtestsplit") {
      head("u2itrainingtestsplit")
      opt[String]("hadoop") required () action { (x, c) =>
        c.copy(hadoop = x)
      } text ("path to the 'hadoop' command")
      opt[String]("pdioEvalJar") required () action { (x, c) =>
        c.copy(pdioEvalJar = x)
      } text ("path to PredictionIO Hadoop job JAR")
      opt[String]("hdfsRoot") required () action { (x, c) =>
        c.copy(hdfsRoot = x)
      } text ("PredictionIO root path in HDFS")
      opt[String]("localTempRoot") required () action { (x, c) =>
        c.copy(localTempRoot = x)
      } text ("local directory for temporary storage")
      opt[Int]("appid") required () action { (x, c) =>
        c.copy(appid = x)
      } text ("the App ID of this offline evaluation")
      opt[Int]("engineid") required () action { (x, c) =>
        c.copy(engineid = x)
      } text ("the Engine ID of this offline evaluation")
      opt[Int]("evalid") required () action { (x, c) =>
        c.copy(evalid = x)
      } text ("the OfflineEval ID of this offline evaluation")
      opt[Int]("sequenceNum") required () action { (x, c) =>
        c.copy(sequenceNum = x)
      } validate { x =>
        if (x >= 1) success else failure("--sequenceNum must be >= 1")
      } text ("sequence (iteration) number of the offline evaluation")
      override def errorOnUnknownArgument = false
    }

    parser.parse(args, U2ITrainingTestSplitTimeConfig()) map { config =>
      run(config, args)
    } getOrElse {
      // Invalid arguments: scopt has already printed the error and usage.
      // Exit nonzero so calling scripts can detect the failure (the original
      // silently exited 0 here).
      sys.exit(1)
    }
  }

  /**
   * Runs the split workflow: on the first iteration, launches the prep job;
   * then fetches the total U2I action count from HDFS to a local temporary
   * file, launches the actual split job parameterized by that count, and
   * finally removes the temporary file.
   *
   * @param config parsed command-line configuration
   * @param args the raw command-line arguments, forwarded verbatim to the
   *             Scalding jobs (they parse their own additional options)
   * @throws RuntimeException if any launched command exits nonzero or the
   *                          count file is empty
   */
  private def run(config: U2ITrainingTestSplitTimeConfig, args: Array[String]): Unit = {
    import config._

    val argsString = args.mkString(" ")
    // sequenceNum == 1 is the first split; later iterations re-split
    // already-prepared data and skip the prep job.
    val resplit = sequenceNum > 1

    if (!resplit) {
      // prep
      val splitPrepCmd = s"$hadoop jar $pdioEvalJar io.prediction.evaluations.scalding.commons.u2itrainingtestsplit.U2ITrainingTestSplitTimePrep $argsString"
      executeCommandAndCheck(splitPrepCmd)
    }

    // copy the count to local tmp
    val hdfsCountPath = U2ITrainingTestSplitFile(hdfsRoot, appid, engineid, evalid, "u2iCount.tsv")
    // File(parent, child) inserts the path separator itself, so this works
    // whether or not localTempRoot carries a trailing slash (the original
    // string concatenation produced a wrong path without one).
    val localCountFile = new File(localTempRoot, s"eval-$evalid-u2iCount.tsv")
    val localCountPath = localCountFile.getPath
    // create parent dir
    localCountFile.getParentFile().mkdirs()
    // delete existing file first
    if (localCountFile.exists()) localCountFile.delete()

    // get the count from hdfs
    executeCommandAndCheck(s"$hadoop fs -getmerge $hdfsCountPath $localCountPath")

    // read the local file and get the count, releasing the file handle even
    // on failure (the original leaked the Source)
    val source = Source.fromFile(localCountPath)
    val count = try {
      val lines = source.getLines()
      if (lines.isEmpty) throw new RuntimeException(s"Count file $localCountPath is empty")
      lines.next().trim
    } finally {
      source.close()
    }

    // split
    val splitCmd = s"$hadoop jar $pdioEvalJar io.prediction.evaluations.scalding.commons.u2itrainingtestsplit.U2ITrainingTestSplitTime $argsString --totalCount $count"
    executeCommandAndCheck(splitCmd)

    // delete local tmp file
    logger.info(s"Deleting temporary file $localCountPath...")
    localCountFile.delete()
  }

  /**
   * Runs a shell command via scala.sys.process.
   *
   * @param cmd the command line to execute
   * @throws RuntimeException if the command exits with a nonzero status
   */
  private def executeCommandAndCheck(cmd: String): Unit = {
    logger.info(s"Executing $cmd...")
    if ((cmd.!) != 0) throw new RuntimeException(s"Failed to execute '$cmd'")
  }
}