blob: 0963bac1b489bb2425e32ed2425fb871f1b78bc2 [file] [log] [blame]
package io.prediction.metrics.scalding.itemsim.ismap
import com.twitter.scalding._
import cascading.pipe.joiner.LeftJoin
import cascading.pipe.joiner.RightJoin
import io.prediction.commons.filepath.OfflineMetricFile
import io.prediction.commons.scalding.settings.OfflineEvalResults
/**
* Source:
* relevantUsers.tsv
* iid uid
* i0 u0
* i0 u1
* i0 u2
* relevantItems.tsv
* uid iid
* u0 i0
* u0 i1
* u0 i2
* topKItems.tsv
* iid simiid score
* i0 i1 3.2
* i0 i4 2.5
* i0 i5 1.4
*
* Sink:
* offlineEvalResults DB
* averagePrecision.tsv
* iid ap
* i0 0.03
*
*
* Description:
* Calculate Item Similarity Mean Average Precision @ k score
* There is an assumption that the number of missing similar items per item is always equal to k.
*
* Required args:
* --dbType: <string> The OfflineEvalResults DB Type (eg. mongodb) (see --dbHost, --dbPort)
* --dbName: <string>
*
* --hdfsRoot: <string>. Root directory of the HDFS
*
* --appid: <int>
* --engineid: <int>
* --evalid: <int>
* --metricid: <int>
* --algoid: <int>
* --iteration: <int>
* --splitset: <string>
*
* --kParam: <int>
*
* Optional args:
* --dbHost: <string> (eg. "127.0.0.1")
* --dbPort: <int> (eg. 27017)
*
* --debug: <String>. "test" - for testing purpose
*
* Example:
* scald.rb --hdfs-local io.prediction.metrics.scalding.itemsim.ismap.ISMAPAtK --dbType mongodb --dbName predictionio --dbHost 127.0.0.1 --dbPort 27017 --hdfsRoot hdfs/predictionio/ --appid 34 --engineid 3 --evalid 15 --metricid 10 --algoid 9 --kParam 30
*/
class ISMAPAtK(args: Args) extends Job(args) {
/** parse args */
val dbTypeArg = args("dbType")
val dbNameArg = args("dbName")
val dbHostArg = args.list("dbHost")
val dbPortArg = args.list("dbPort") map (x => x.toInt)
val hdfsRootArg = args("hdfsRoot")
val appidArg = args("appid").toInt
val engineidArg = args("engineid").toInt
val evalidArg = args("evalid").toInt
val metricidArg = args("metricid").toInt
val algoidArg = args("algoid").toInt
val iterationArg = args.getOrElse("iteration", "1").toInt
val splitsetArg = args.getOrElse("splitset", "")
val kParamArg = args("kParam").toInt
val debugArg = args.list("debug")
val DEBUG_TEST = debugArg.contains("test") // test mode
/** sources */
val relevantUsers = TypedTsv[(String, String)](OfflineMetricFile(hdfsRootArg, appidArg, engineidArg, evalidArg, metricidArg, algoidArg, "relevantUsers.tsv"), ('ruiid, 'ruuid))
val relevantItems = TypedTsv[(String, String)](OfflineMetricFile(hdfsRootArg, appidArg, engineidArg, evalidArg, metricidArg, algoidArg, "relevantItems.tsv"), ('riuid, 'riiid))
val relevantItemsAsList = relevantItems.groupBy('riuid) { _.toList[String]('riiid -> 'riiids) }
val topKItems = TypedTsv[(String, String, Double)](OfflineMetricFile(hdfsRootArg, appidArg, engineidArg, evalidArg, metricidArg, algoidArg, "topKItems.tsv"), ('iid, 'simiid, 'score))
/** sinks */
val averagePrecisionSink = Tsv(OfflineMetricFile(hdfsRootArg, appidArg, engineidArg, evalidArg, metricidArg, algoidArg, "averagePrecision.tsv"))
val offlineEvalResultsSink = OfflineEvalResults(dbType = dbTypeArg, dbName = dbNameArg, dbHost = dbHostArg, dbPort = dbPortArg)
/** computation */
val itemsMapAtK = topKItems
.joinWithSmaller('iid -> 'ruiid, relevantUsers)
.joinWithSmaller('ruuid -> 'riuid, relevantItemsAsList, joiner = new LeftJoin)
.groupBy('iid, 'ruuid) {
_.sortBy('score).reverse.scanLeft(('simiid, 'riiids) -> ('precision, 'hit, 'count))((0.0, 0, 0)) {
(newFields: (Double, Int, Int), fields: (String, List[String])) =>
val (simiid, riiids) = fields
val (precision, hit, count) = newFields
Option(riiids) map { r =>
if (r.contains(simiid)) {
((hit + 1).toDouble / (count + 1).toDouble, hit + 1, count + 1)
} else {
(0.0, hit, count + 1)
}
} getOrElse {
(0.0, hit, count + 1)
}
}
}
.filter('count) { count: Int => count > 0 }
.groupBy('iid, 'ruuid) { _.average('precision) }
.groupBy('iid) { _.average('precision) }
itemsMapAtK.write(averagePrecisionSink)
itemsMapAtK.groupAll { _.average('precision) }
.mapTo('precision -> ('evalid, 'metricid, 'algoid, 'score, 'iteration, 'splitset)) { precision: Double =>
(evalidArg, metricidArg, algoidArg, precision, iterationArg, splitsetArg)
}
.then(offlineEvalResultsSink.writeData('evalid, 'metricid, 'algoid, 'score, 'iteration, 'splitset) _)
}