blob: 4e321df9dee94d7a213ff5d895a5525f008a85fd [file] [log] [blame]
package io.prediction.evaluations.scalding.commons.u2itrainingtestsplit
import com.twitter.scalding._
import io.prediction.commons.scalding.appdata.{ Users, Items, U2iActions }
import io.prediction.commons.filepath.U2ITrainingTestSplitFile
import io.prediction.commons.appdata.{ User, Item }
/**
 * Description:
 * Split u2i actions into training, validation and test sets.
 *
 * A random sample of `evaluationCount` actions is taken from the data
 * generated at the prep stage and cut into consecutive chunks: the first
 * `trainingCount` actions go to the training sink, the next `validationCount`
 * to the validation sink and the remaining `testCount` to the test sink.
 * When `timeorderArg` is set, every chunk is cut from the sample sorted by 't
 * (smallest first, i.e. oldest first), so training receives the oldest
 * actions and test the newest.
 *
 * Args:
 * same as U2ITrainingTestSplitCommon, plus additional args:
 * --totalCount <int> total u2i actions count
 */
class U2ITrainingTestSplitTime(args: Args) extends U2ITrainingTestSplitCommon(args) {
  val totalCountArg = args("totalCount").toInt // total u2i count

  // evaluationPercent is sum of trainingPercentArg + validationPercentArg + testPercentArg
  val evaluationCount: Int = (scala.math.floor(evaluationPercent * totalCountArg)).toInt
  val trainingCount: Int = (scala.math.floor(trainingPercentArg * totalCountArg)).toInt
  val validationCount: Int = (scala.math.floor(validationPercentArg * totalCountArg)).toInt
  val trainingValidationCount: Int = trainingCount + validationCount
  // The test set gets whatever remains of the evaluation sample after
  // training + validation (each count above is floored independently).
  val testCount: Int = evaluationCount - trainingValidationCount

  require((trainingCount >= 1), "Not enough data for training set. trainingCount = " + trainingCount)
  // A validation set is only required when a validation percentage was requested.
  if (validationPercentArg != 0) {
    require((validationCount >= 1), "Not enough data for validation set. validationCount = " + validationCount)
  }
  require((testCount >= 1), "Not enough data for test set. testCount = " + testCount)

  /**
   * source
   */
  // data generated at prep stage
  val u2iSource = U2iActions(appId = evalidArg,
    dbType = "file", dbName = U2ITrainingTestSplitFile(hdfsRootArg, appidArg, engineidArg, evalidArg, ""), dbHost = Seq(), dbPort = Seq())

  /**
   * sink
   */
  val trainingU2iSink = U2iActions(appId = evalidArg,
    dbType = training_dbTypeArg, dbName = training_dbNameArg, dbHost = training_dbHostArg, dbPort = training_dbPortArg)

  val validationU2iSink = U2iActions(appId = evalidArg,
    dbType = validation_dbTypeArg, dbName = validation_dbNameArg, dbHost = validation_dbHostArg, dbPort = validation_dbPortArg)

  // sink for the test set (test_appdata)
  val testU2iSink = U2iActions(appId = evalidArg,
    dbType = test_dbTypeArg, dbName = test_dbNameArg, dbHost = test_dbHostArg, dbPort = test_dbPortArg)

  /**
   * computation
   */
  // Random sample of evaluationCount actions: shuffle, then take.
  // Time ordering is NOT applied here. Each later groupAll is a new shuffle,
  // and Hadoop does not guarantee value order inside a group, so a sortBy in
  // an earlier groupAll is not guaranteed to reach the split steps. Instead,
  // each split step below sorts in the same groupAll as its take/drop.
  val randomU2i = u2iSource.readData('action, 'uid, 'iid, 't, 'v)
    .shuffle(11)
    .groupAll(_.take(evaluationCount))

  // split
  // NOTE: when timeorderArg is set, sortBy('t) is small to largest (oldest
  // first), so the training set is taken first.
  // NOTE(review): take and drop still run as separate groupAll steps. If the
  // value order differs between them (equal 't values, or the non-time-ordered
  // case), a row near a boundary could land in both splits or in neither.
  // Consider labelling rows in a single groupAll if exact disjointness matters.
  val trainingOrValidation = randomU2i.groupAll { g =>
    (if (timeorderArg) g.sortBy('t) else g).take(trainingValidationCount)
  }

  trainingOrValidation
    .groupAll { g => (if (timeorderArg) g.sortBy('t) else g).take(trainingCount) }
    .then(trainingU2iSink.writeData('action, 'uid, 'iid, 't, 'v, evalidArg) _) // NOTE: appid is replaced by evalid

  trainingOrValidation
    .groupAll { g => (if (timeorderArg) g.sortBy('t) else g).drop(trainingCount) }
    .then(validationU2iSink.writeData('action, 'uid, 'iid, 't, 'v, evalidArg) _) // NOTE: appid is replaced by evalid

  randomU2i
    .groupAll { g => (if (timeorderArg) g.sortBy('t) else g).drop(trainingValidationCount) }
    .then(testU2iSink.writeData('action, 'uid, 'iid, 't, 'v, evalidArg) _) // NOTE: appid is replaced by evalid
}