blob: d0ec835c37ba05f6f21174b7f8cff68c93bfde19 [file] [log] [blame]
package io.prediction.algorithms.graphchi.itemsim
import grizzled.slf4j.Logger
import breeze.linalg._
import com.twitter.scalding.Args
import scala.io.Source
import io.prediction.commons.Config
import io.prediction.commons.modeldata.{ ItemSimScore }
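/**
 * Builds ItemSim model data from the output of a GraphChi item-similarity run.
 *
 * Input files (read from inputDir):
 * - itemsIndex.tsv (iindex iid itypes): all items; itypes is a comma-separated list
 * - validItemsIndex.tsv (iindex): valid candidate items that may be recommended
 * - ratings.mm-topk (iindex1 iindex2 score): item-pair similarity scores generated by GraphChi
 *
 * Output: one ItemSimScore record per item that appears in ratings.mm-topk.
 */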
object GraphChiModelConstructor {

  /* global */
  val logger = Logger(GraphChiModelConstructor.getClass)
  val commonsConfig = new Config

  /**
   * Arguments of this job.
   *
   * @param inputDir directory holding the input files. It is used as a plain string
   *   prefix (no separator is inserted), so it must end with a path separator.
   * @param appid app ID stored with each record when not in offline evaluation
   * @param algoid algorithm ID stored with each record
   * @param evalid when defined, the job runs in offline-evaluation mode: records go to
   *   the training model data store and evalid is stored in place of appid
   * @param modelSet model set flag stored with each record
   * @param numSimilarItems maximum number of similar items kept per item
   */
  case class JobArg(
    inputDir: String,
    appid: Int,
    algoid: Int,
    evalid: Option[Int],
    modelSet: Boolean,
    numSimilarItems: Int)

  /**
   * Entry point: parses the named arguments (inputDir, appid, algoid, optional evalid,
   * modelSet, numSimilarItems) with Scalding's Args, builds the model data and cleans up.
   */
  def main(cmdArgs: Array[String]): Unit = {
    logger.info("Running model constructor for GraphChi ItemSim ...")
    logger.info(cmdArgs.mkString(","))

    val args = Args(cmdArgs)
    val arg = JobArg(
      inputDir = args("inputDir"),
      appid = args("appid").toInt,
      algoid = args("algoid").toInt,
      evalid = args.optional("evalid").map(_.toInt),
      modelSet = args("modelSet").toBoolean,
      numSimilarItems = args("numSimilarItems").toInt)

    modelCon(arg)
    cleanUp(arg)
  }

  /**
   * Reads every line of the file at `path`, applies `parse` to each one and always
   * closes the file afterwards. The result is materialized eagerly (a Vector) so that
   * no line is pulled from the source after it has been closed.
   */
  private def readParsedLines[T](path: String)(parse: String => T): Vector[T] = {
    val source = Source.fromFile(path)
    try source.getLines().map(parse).toVector
    finally source.close()
  }

  /**
   * Converts GraphChi's similarity output into ItemSimScore records.
   *
   * Similarities are treated as symmetric: each (iindex1, iindex2, score) line is also
   * used reversed. For every item, only candidates listed in validItemsIndex.tsv are
   * kept; they are sorted by descending score and truncated to numSimilarItems. The
   * result is then inserted as one ItemSimScore record.
   *
   * @throws RuntimeException if a line of an input file cannot be parsed (the
   *   underlying parse error is attached as the cause)
   * @throws NoSuchElementException if the similarity file references an iindex that is
   *   missing from itemsIndex.tsv
   */
  def modelCon(arg: JobArg): Unit = {
    // In offline evaluation, write to the training model data and use evalid as appid.
    val offlineEval = arg.evalid.isDefined
    val modeldataDb =
      if (offlineEval) commonsConfig.getModeldataTrainingItemSimScores
      else commonsConfig.getModeldataItemSimScores
    val appid = arg.evalid.getOrElse(arg.appid)

    case class ItemData(iid: String, itypes: Seq[String])

    // itemsIndex.tsv: iindex \t iid \t itypes (comma-separated); iindex -> ItemData
    val itemsMap: Map[Int, ItemData] =
      readParsedLines(s"${arg.inputDir}itemsIndex.tsv") { line =>
        try {
          val fields = line.split("\t")
          (fields(0).toInt, ItemData(iid = fields(1), itypes = fields(2).split(",").toList))
        } catch {
          case e: Exception =>
            throw new RuntimeException(s"Cannot get item info in line: ${line}. ${e}", e)
        }
      }.toMap

    // validItemsIndex.tsv: only the first tab-separated field (iindex) is used
    val validItemsSet: Set[Int] =
      readParsedLines(s"${arg.inputDir}validItemsIndex.tsv") { line =>
        try line.split("\t")(0).toInt
        catch {
          case e: Exception =>
            throw new RuntimeException(s"Cannot get item info in line: ${line}. ${e}", e)
        }
      }.toSet

    // ratings.mm-topk: iindex1 iindex2 score, whitespace-separated
    val simScores: Vector[(Int, Int, Double)] =
      readParsedLines(s"${arg.inputDir}ratings.mm-topk") { line =>
        try {
          val fields = line.split("""\s+""")
          (fields(0).toInt, fields(1).toInt, fields(2).toDouble)
        } catch {
          case e: Exception =>
            throw new RuntimeException(
              s"Cannot read item index and score from the line: ${line}. ${e}", e)
        }
      }

    // Add every pair reversed as well, so both items see each other as a candidate.
    val similarities =
      simScores ++ simScores.map { case (iindex1, iindex2, score) => (iindex2, iindex1, score) }

    // iindex1 -> Seq[(iindex1, iindex2, score)]
    similarities.groupBy(_._1).foreach { case (iindex1, scoresSeq) =>
      // only recommend items in validItems
      val topScores = scoresSeq
        .filter { case (_, iindex2, _) => validItemsSet(iindex2) }
        .sortBy(_._3)(Ordering[Double].reverse)
        .take(arg.numSimilarItems)
      logger.debug(s"${iindex1}: ${topScores.toList}")
      modeldataDb.insert(ItemSimScore(
        iid = itemsMap(iindex1).iid,
        simiids = topScores.map(x => itemsMap(x._2).iid),
        scores = topScores.map(_._3),
        itypes = topScores.map(x => itemsMap(x._2).itypes),
        appid = appid,
        algoid = arg.algoid,
        modelset = arg.modelSet))
    }
  }

  /** Post-run cleanup hook; currently a no-op. */
  def cleanUp(arg: JobArg): Unit = {}
}