| package io.prediction.evaluations.scalding.commons.u2itrainingtestsplit |
| |
| import org.specs2.mutable._ |
| |
| import com.twitter.scalding._ |
| |
| import io.prediction.commons.scalding.appdata.{Users, Items, U2iActions} |
| import io.prediction.commons.filepath.U2ITrainingTestSplitFile |
| import io.prediction.commons.appdata.{User, Item} |
| |
/**
 * Specs for the time-based U2I training/validation/test split jobs.
 *
 * Each scenario runs `U2ITrainingTestSplitTimePrep` once, then `U2ITrainingTestSplitTime` twice.
 * It checks the following:
 *  - the data written by the prep job;
 *  - the size of each generated set;
 *  - that the sets are mutually exclusive;
 *  - time ordering between the sets, when `timeorder` is true;
 *  - that the two split runs do not produce identical sets.
 */
class U2ITrainingTestSplitTimeTest extends Specification with TupleConversions {

  /**
   * Registers all examples for one split configuration.
   *
   * @param itypes             item types handed to the `Items` source (as `Some(itypes)`)
   * @param trainingPercent    fraction of `selectedU2iActions` expected in the training set
   * @param validationPercent  fraction expected in the validation set; 0 skips validation-specific checks
   * @param testPercent        fraction expected in the test set
   * @param timeorder          whether the split is expected to respect the action time order
   * @param appid              app id of the input data
   * @param evalid             evaluation id used as the app id of the generated data
   * @param items              input items of `appid`
   * @param users              input users of `appid`
   * @param u2iActions         input u2i actions of `appid`
   * @param selectedItems      items the prep job is expected to write to the training DB
   * @param selectedUsers      users the prep job is expected to write to the training DB
   * @param selectedU2iActions u2i actions the prep job is expected to write; also the split job's input
   */
  def test(itypes: List[String], trainingPercent: Double, validationPercent: Double, testPercent: Double, timeorder: Boolean,
    appid: Int, evalid: Int,
    items: List[(String, String, String, String, String)],
    users: List[(String, String, String)],
    u2iActions: List[(String, String, String, String, String)],
    selectedItems: List[(String, String, String, String, String)],
    selectedUsers: List[(String, String, String)],
    selectedU2iActions: List[(String, String, String, String, String)]
  ) = {

    val dbType = "file"
    val dbName = "testpath/"
    val dbHost = None
    val dbPort = None

    val training_dbType = "file"
    val training_dbName = "trainingsetpath/"
    val training_dbHost = None
    val training_dbPort = None

    val validation_dbType = "file"
    val validation_dbName = "validationpath/"
    val validation_dbHost = None
    val validation_dbPort = None

    val test_dbType = "file"
    val test_dbName = "testsetpath/"
    val test_dbHost = None
    val test_dbPort = None

    val hdfsRoot = "testroot/"

    val engineid = 4

    // Expected set sizes.
    // Each share is floored.
    // The test set takes whatever remains of evalCount after the training and validation sets.
    val originalCount = selectedU2iActions.size

    val totalPercent = trainingPercent + validationPercent + testPercent
    val evalCount: Int = scala.math.floor(totalPercent * originalCount).toInt
    val trainingCount: Int = scala.math.floor(trainingPercent * originalCount).toInt
    val validationCount: Int = scala.math.floor(validationPercent * originalCount).toInt
    val testCount: Int = evalCount - trainingCount - validationCount

    // Arguments shared verbatim by the prep job and the split job.
    def withCommonArgs(job: JobTest): JobTest =
      job
        .arg("dbType", dbType)
        .arg("dbName", dbName)
        .arg("training_dbType", training_dbType)
        .arg("training_dbName", training_dbName)
        .arg("validation_dbType", validation_dbType)
        .arg("validation_dbName", validation_dbName)
        .arg("test_dbType", test_dbType)
        .arg("test_dbName", test_dbName)
        .arg("hdfsRoot", hdfsRoot)
        .arg("appid", appid.toString)
        .arg("engineid", engineid.toString)
        .arg("evalid", evalid.toString)
        .arg("trainingPercent", trainingPercent.toString)
        .arg("validationPercent", validationPercent.toString)
        .arg("testPercent", testPercent.toString)
        .arg("timeorder", timeorder.toString)

    withCommonArgs(JobTest("io.prediction.evaluations.scalding.commons.u2itrainingtestsplit.U2ITrainingTestSplitTimePrep"))
      .source(Users(appId=appid, dbType=dbType, dbName=dbName, dbHost=dbHost, dbPort=dbPort).getSource, users)
      .source(Items(appId=appid, itypes=Some(itypes), dbType=dbType, dbName=dbName, dbHost=dbHost, dbPort=dbPort).getSource, items)
      .source(U2iActions(appId=appid, dbType=dbType, dbName=dbName, dbHost=dbHost, dbPort=dbPort).getSource, u2iActions)
      .sink[(String, String, String)](Users(appId=evalid, dbType=training_dbType, dbName=training_dbName, dbHost=training_dbHost, dbPort=training_dbPort).getSource) { outputBuffer =>
        "correctly write trainingUsers" in {
          outputBuffer must containTheSameElementsAs(selectedUsers)
        }
      }
      .sink[(String, String, String, String, String)](Items(appId=evalid, itypes=None, dbType=training_dbType, dbName=training_dbName, dbHost=training_dbHost, dbPort=training_dbPort).getSource) { outputBuffer =>
        "correctly write trainingItems" in {
          outputBuffer must containTheSameElementsAs(selectedItems)
        }
      }
      .sink[(String, String, String, String, String)](U2iActions(appId=evalid,
        dbType="file", dbName=U2ITrainingTestSplitFile(hdfsRoot, appid, engineid, evalid, ""), dbHost=None, dbPort=None).getSource) { outputBuffer =>
        "correctly write u2iActions" in {
          outputBuffer must containTheSameElementsAs(selectedU2iActions)
        }
      }
      .sink[Int](Tsv(U2ITrainingTestSplitFile(hdfsRoot, appid, engineid, evalid, "u2iCount.tsv"))) { outputBuffer =>
        "correctly write u2iActions count" in {
          outputBuffer must containTheSameElementsAs(List(originalCount))
        }
      }
      .run
      .finish

    /**
     * Runs the split job once on `selectedU2iActions` and registers examples for the sets it produces.
     *
     * @return the produced sets, keyed by "training", "validation" and "test"
     */
    def splitTest() = {

      val results = scala.collection.mutable.Map[String, List[(String, String, String, String, String)]]()

      // Why each sink callback stores its output before declaring its example:
      //  - JobTest invokes the sink callbacks while the job runs, so `results` is complete
      //    by the time splitTest() returns.
      //  - Example bodies may run later, and concurrently by default in specs2.
      //  - Several examples read `results`, so it must not be filled inside an example body.
      withCommonArgs(JobTest("io.prediction.evaluations.scalding.commons.u2itrainingtestsplit.U2ITrainingTestSplitTime"))
        .arg("totalCount", originalCount.toString)
        .source(U2iActions(appId=evalid,
          dbType="file", dbName=U2ITrainingTestSplitFile(hdfsRoot, appid, engineid, evalid, ""), dbHost=None, dbPort=None).getSource, selectedU2iActions)
        .sink[(String, String, String, String, String)](U2iActions(appId=evalid,
          dbType=training_dbType, dbName=training_dbName, dbHost=training_dbHost, dbPort=training_dbPort).getSource) { outputBuffer =>
          val output = outputBuffer.toList
          results += ("training" -> output)
          "generate training set" in {
            // The selection is random, so the expected members can't be known beforehand.
            // Just check that the original data contains the selected data and that the size is correct.
            // Randomness and time order are checked in later examples.
            selectedU2iActions must containAllOf(output) and
              (output.size must be_==(trainingCount))
          }
        }
        .sink[(String, String, String, String, String)](U2iActions(appId=evalid,
          dbType=validation_dbType, dbName=validation_dbName, dbHost=validation_dbHost, dbPort=validation_dbPort).getSource) { outputBuffer =>
          val output = outputBuffer.toList
          results += ("validation" -> output)
          "generate validation set" in {
            selectedU2iActions must containAllOf(output) and
              (output.size must be_==(validationCount))
          }
        }
        .sink[(String, String, String, String, String)](U2iActions(appId=evalid,
          dbType=test_dbType, dbName=test_dbName, dbHost=test_dbHost, dbPort=test_dbPort).getSource) { outputBuffer =>
          val output = outputBuffer.toList
          results += ("test" -> output)
          "generate test set" in {
            selectedU2iActions must containAllOf(output) and
              (output.size must be_==(testCount))
          }
        }
        .run
        .finish

      "all sets are mutually exclusive" in {
        (results("training") must not(containAnyOf(results("validation")))) and
          (results("training") must not(containAnyOf(results("test")))) and
          (results("validation") must not(containAnyOf(results("test"))))
      }

      // Extracts the time field (4th tuple element) of each u2i action as a Long.
      def getTimeOnly(dataSet: List[(String, String, String, String, String)]): List[Long] =
        dataSet.map(_._4.toLong)

      if (timeorder) {
        // Check time order between the sets.
        // An empty validation set (validationPercent == 0) has no min/max, so skip its checks.
        if (validationPercent != 0) {
          "validation set must be newer than training set" in {
            getTimeOnly(results("validation")).min must be_>=(getTimeOnly(results("training")).max)
          }
          "test set must be newer than validation set" in {
            getTimeOnly(results("test")).min must be_>=(getTimeOnly(results("validation")).max)
          }
        }

        "test set must be newer than training set" in {
          getTimeOnly(results("test")).min must be_>=(getTimeOnly(results("training")).max)
        }
      }

      results
    }

    val firstSplit = splitTest()
    val secondSplit = splitTest()

    // Simple check for randomness.
    if (timeorder) {
      "at least one set of two split is different" in {
        // With timeorder=true, some sets may still be identical after the second split.
        // The input data is small, most of it (say > 90%) is selected, and the split follows
        // time order, so the chance of ending up with the same data in a set is high.
        // So just check that at least one set differs.
        // (Checking every set would work if the input were large and the selected share relatively small.)
        (firstSplit("training") must not(containTheSameElementsAs(secondSplit("training")))) or
          (firstSplit("validation") must not(containTheSameElementsAs(secondSplit("validation")))) or
          (firstSplit("test") must not(containTheSameElementsAs(secondSplit("test"))))
      }
    } else {
      "all sets of two splits are different" in {
        if (validationPercent == 0) {
          // Don't check the validation set since it is empty.
          (firstSplit("training") must not(containTheSameElementsAs(secondSplit("training")))) and
            (firstSplit("test") must not(containTheSameElementsAs(secondSplit("test"))))
        } else {
          (firstSplit("training") must not(containTheSameElementsAs(secondSplit("training")))) and
            (firstSplit("validation") must not(containTheSameElementsAs(secondSplit("validation")))) and
            (firstSplit("test") must not(containTheSameElementsAs(secondSplit("test"))))
        }
      }
    }

  }

  // Input data for app `appid`.
  val appid = 2
  val evalid = 101
  val users = List(
    (appid + "_u0", appid.toString, "123456"),
    (appid + "_u1", appid.toString, "23456"),
    (appid + "_u2", appid.toString, "455677"),
    (appid + "_u3", appid.toString, "876563111"))

  val items = List(
    (appid + "_i0", "t1,t2,t3", appid.toString, "2293300", "1266673"),
    (appid + "_i1", "t2,t3", appid.toString, "14526361", "12345135"),
    (appid + "_i2", "t4", appid.toString, "14526361", "23423424"),
    (appid + "_i3", "t3,t4", appid.toString, "1231415", "378462511"))

  val u2iActions = List(
    ("4", appid + "_u0", appid + "_i1", "1234500", "5"),
    ("3", appid + "_u3", appid + "_i0", "1234505", "1"),
    ("4", appid + "_u1", appid + "_i3", "1234501", "3"),
    ("4", appid + "_u1", appid + "_i2", "1234506", "4"),
    ("2", appid + "_u1", appid + "_i0", "1234507", "5"),
    ("3", appid + "_u2", appid + "_i3", "1234502", "2"),
    ("4", appid + "_u0", appid + "_i2", "1234508", "3"),
    ("4", appid + "_u2", appid + "_i0", "1234509", "1"),
    ("4", appid + "_u0", appid + "_i1", "1234503", "2"),
    ("4", appid + "_u3", appid + "_i3", "1234504", "3"),
    ("4", appid + "_u3", appid + "_i3", "1234503", "3"),
    ("4", appid + "_u2", appid + "_i3", "1234504", "3"),
    ("4", appid + "_u1", appid + "_i3", "1234505", "3"),
    ("4", appid + "_u0", appid + "_i3", "1234509", "3"),
    ("view", appid + "_u0", appid + "_i0", "1234509", "PIO_NONE"), // test missing v field case (non-rate action)
    ("like", appid + "_u1", appid + "_i2", "1234509", "PIO_NONE")) // test missing v field case (non-rate action)

  // Expected data re-keyed under `evalid`.
  val selectedUsers = List(
    (evalid + "_u0", evalid.toString, "123456"),
    (evalid + "_u1", evalid.toString, "23456"),
    (evalid + "_u2", evalid.toString, "455677"),
    (evalid + "_u3", evalid.toString, "876563111"))

  val selectedItemsAll = List(
    (evalid + "_i0", "t1,t2,t3", evalid.toString, "2293300", "1266673"),
    (evalid + "_i1", "t2,t3", evalid.toString, "14526361", "12345135"),
    (evalid + "_i2", "t4", evalid.toString, "14526361", "23423424"),
    (evalid + "_i3", "t3,t4", evalid.toString, "1231415", "378462511"))

  val selectedU2iActions = List(
    ("4", evalid + "_u0", evalid + "_i1", "1234500", "5"),
    ("3", evalid + "_u3", evalid + "_i0", "1234505", "1"),
    ("4", evalid + "_u1", evalid + "_i3", "1234501", "3"),
    ("4", evalid + "_u1", evalid + "_i2", "1234506", "4"),
    ("2", evalid + "_u1", evalid + "_i0", "1234507", "5"),
    ("3", evalid + "_u2", evalid + "_i3", "1234502", "2"),
    ("4", evalid + "_u0", evalid + "_i2", "1234508", "3"),
    ("4", evalid + "_u2", evalid + "_i0", "1234509", "1"),
    ("4", evalid + "_u0", evalid + "_i1", "1234503", "2"),
    ("4", evalid + "_u3", evalid + "_i3", "1234504", "3"),
    ("4", evalid + "_u3", evalid + "_i3", "1234503", "3"),
    ("4", evalid + "_u2", evalid + "_i3", "1234504", "3"),
    ("4", evalid + "_u1", evalid + "_i3", "1234505", "3"),
    ("4", evalid + "_u0", evalid + "_i3", "1234509", "3"),
    ("view", evalid + "_u0", evalid + "_i0", "1234509", "PIO_NONE"),
    ("like", evalid + "_u1", evalid + "_i2", "1234509", "PIO_NONE"))

  "U2ITrainingTestSplitTimeTest with timeorder=true" should {
    test(List(""), 0.4, 0.3, 0.2, true, appid, evalid,
      items,
      users,
      u2iActions,
      selectedItemsAll,
      selectedUsers,
      selectedU2iActions
    )

  }

  "U2ITrainingTestSplitTimeTest with timeorder=false" should {
    test(List(""), 0.3, 0.2, 0.3, false, appid, evalid,
      items,
      users,
      u2iActions,
      selectedItemsAll,
      selectedUsers,
      selectedU2iActions
    )
  }

  "U2ITrainingTestSplitTimeTest with timeorder=true and validation=0" should {
    test(List(""), 0.6, 0, 0.1, true, appid, evalid,
      items,
      users,
      u2iActions,
      selectedItemsAll,
      selectedUsers,
      selectedU2iActions
    )
  }

  "U2ITrainingTestSplitTimeTest with timeorder=false and validation=0" should {
    test(List(""), 0.6, 0, 0.4, false, appid, evalid,
      items,
      users,
      u2iActions,
      selectedItemsAll,
      selectedUsers,
      selectedU2iActions
    )
  }

}