blob: 641de3b99e44c4732de5adecc81e36cc32e05093 [file] [log] [blame]
package io.prediction.algorithms.scalding.mahout.itemsim
import com.twitter.scalding._
import io.prediction.commons.scalding.appdata.{ Users, Items, U2iActions }
import io.prediction.commons.filepath.DataFile
import io.prediction.commons.appdata.{ Item }
import org.slf4j.{ Logger, LoggerFactory }
/**
 * Source: appdata Users, Items and U2iActions
 * Sink: userIds.tsv, selectedItems.tsv, itemsIndex.tsv, usersIndex.tsv, ratings.csv
 * Description:
 * Prepares input data for the Mahout Item Similarity algorithm.
 * Required args:
 * --dbType: <string> (eg. mongodb) (see --dbHost, --dbPort)
 * --dbName: <string> appdata database name. (eg predictionio_appdata, or predictionio_training_appdata)
 * --hdfsRoot: <string>. Root directory of the HDFS
 * --appid: <int>
 * --engineid: <int>
 * --algoid: <int>
 * --viewParam: <string>. (number 1 to 5, or "ignore")
 * --likeParam: <string>. (number 1 to 5, or "ignore")
 * --dislikeParam: <string>. (number 1 to 5, or "ignore")
 * --conversionParam: <string>. (number 1 to 5, or "ignore")
 * --conflictParam: <string>. (latest/highest/lowest)
 * Optional args:
 * --dbHost: <string> (eg. "localhost")
 * --dbPort: <int> (eg. 27017)
 * --itypes: <string separated by white space>. eg "--itypes type1 type2". If no --itypes specified, then ALL itypes will be used.
 * --evalid: <int>. Offline Evaluation if evalid is specified
 * --debug: <String>. "test" - for testing purpose
 * Example:
 */
class DataPreparatorCommon(args: Args) extends Job(args) {

  /**
   * parse arguments
   */
  val dbTypeArg: String = args("dbType")
  val dbNameArg: String = args("dbName")
  // args.list returns every value given for the key, so both of these are lists
  // (port values parsed to Int), not Options.
  val dbHostArg = args.list("dbHost")
  val dbPortArg = args.list("dbPort") map (x => x.toInt)

  val hdfsRootArg: String = args("hdfsRoot")

  val appidArg: Int = args("appid").toInt
  val engineidArg: Int = args("engineid").toInt
  val algoidArg: Int = args("algoid").toInt
  val evalidArg: Option[Int] = args.optional("evalid") map (x => x.toInt)
  val OFFLINE_EVAL: Boolean = evalidArg.isDefined // offline eval mode

  val preItypesArg = args.list("itypes")
  // No --itypes (or only an empty value) means "use all itypes", represented by None.
  val itypesArg: Option[List[String]] =
    if (preItypesArg.mkString(",").length == 0) None else Some(preItypesArg)

  /**
   * Reads the rating value that an action type is mapped to.
   *
   * @param name name of the command-line argument, e.g. "viewParam"
   * @return None when the argument is "ignore" (actions of that type are then dropped),
   *         otherwise Some(value)
   * @throws IllegalArgumentException if the argument is neither "ignore" nor an integer
   */
  def getActionParam(name: String): Option[Int] = args(name) match {
    case "ignore" => None
    case value =>
      // Fail with a message naming the offending argument instead of a bare NumberFormatException.
      val parsed = scala.util.Try(value.toInt).toOption
      require(parsed.isDefined,
        name + " param " + value + " is not valid. It must be an integer or \"ignore\".")
      parsed
  }

  val viewParamArg: Option[Int] = getActionParam("viewParam")
  val likeParamArg: Option[Int] = getActionParam("likeParam")
  val dislikeParamArg: Option[Int] = getActionParam("dislikeParam")
  val conversionParamArg: Option[Int] = getActionParam("conversionParam")

  // When there are conflicting actions, e.g. a user gives an item a rating 5 but later dislikes it,
  // determine which action will be considered as final preference.
  final val CONFLICT_LATEST: String = "latest" // use latest action
  final val CONFLICT_HIGHEST: String = "highest" // use the one with highest score
  final val CONFLICT_LOWEST: String = "lowest" // use the one with lowest score

  val conflictParamArg: String = args("conflictParam")

  // check if the conflictParam is valid
  require(List(CONFLICT_LATEST, CONFLICT_HIGHEST, CONFLICT_LOWEST).contains(conflictParamArg), "conflict param " + conflictParamArg + " is not valid.")

  val debugArg = args.list("debug")
  val DEBUG_TEST: Boolean = debugArg.contains("test") // test mode

  // NOTE: if OFFLINE_EVAL, read from training set, and use evalid as appid when read Items and U2iActions
  val trainingAppid: Int = evalidArg.getOrElse(appidArg)

  lazy val logger: Logger = LoggerFactory.getLogger(this.getClass)
}
/**
 * Reads the users and items of trainingAppid from appdata and declares the TSV sinks
 * userIds.tsv and selectedItems.tsv, which DataPreparator later reads back line by line.
 * NOTE(review): the end of this class (the mapTo body and any .write calls) is not
 * visible in this view.
 */
class DataCopy(args: Args) extends DataPreparatorCommon(args) {
* source
// appdata items as Item objects in field 'item; presumably restricted to itypesArg when it is Some — TODO confirm against Items
val items = Items(appId = trainingAppid, itypes = itypesArg,
dbType = dbTypeArg, dbName = dbNameArg, dbHost = dbHostArg, dbPort = dbPortArg).readObj('item)
// appdata user ids in field 'uid
val users = Users(appId = trainingAppid,
dbType = dbTypeArg, dbName = dbNameArg, dbHost = dbHostArg, dbPort = dbPortArg).readData('uid)
* sink
// user ids, read back by DataPreparator as usersIndex
val userIdSink = Tsv(DataFile(hdfsRootArg, appidArg, engineidArg, algoidArg, evalidArg, "userIds.tsv"))
// item rows, read back by DataPreparator as itemsIndex; presumably receives the five fields produced by the mapTo below
val selectedItemSink = Tsv(DataFile(hdfsRootArg, appidArg, engineidArg, algoidArg, evalidArg, "selectedItems.tsv"))
* computation
// Flatten each Item into the fields iidx, itypes, starttime, endtime, inactive.
items.mapTo('item -> ('iidx, 'itypes, 'starttime, 'endtime, 'inactive)) {
item: Item =>
// NOTE: convert List[String] into comma-separated String
// NOTE: endtime is optional
class DataPreparator(args: Args) extends DataPreparatorCommon(args) {
/**
 * constants: action names matched against the 'action field of u2iActions
 */
final val ACTION_RATE: String = "rate"
final val ACTION_LIKE: String = "like"
final val ACTION_DISLIKE: String = "dislike"
final val ACTION_VIEW: String = "view"
// ACTION_VIEWDETAILS ("viewDetails") is disabled, so such actions fall through as unsupported.
final val ACTION_CONVERSION: String = "conversion"

/**
 * source: user-to-item actions of trainingAppid with fields 'action, 'uid, 'iid, 't, 'v
 */
val u2i = U2iActions(
  appId = trainingAppid,
  dbType = dbTypeArg,
  dbName = dbNameArg,
  dbHost = dbHostArg,
  dbPort = dbPortArg
).readData('action, 'uid, 'iid, 't, 'v)
// use byte offset as index for Mahout algo
// Each line of selectedItems.tsv is expected to hold five tab-separated columns:
// iidx, itypes, starttime, endtime, inactive (the fields DataCopy maps items to).
val itemsIndex = TextLine(DataFile(hdfsRootArg, appidArg, engineidArg, algoidArg, evalidArg, "selectedItems.tsv")).read
  .mapTo(('offset, 'line) -> ('iindex, 'iidx, 'itypes, 'starttime, 'endtime, 'inactive)) { fields: (String, String) =>
    val (offset, line) = fields
    val lineArray = line.split("\t")
    val (iidx, itypes, starttime, endtime, inactive) = try {
      (lineArray(0), lineArray(1), lineArray(2), lineArray(3), lineArray(4))
    } catch {
      case e: Exception => {
        assert(false, "Failed to extract iidx and itypes from the line: " + line + ". Exception: " + e)
        // The fallback must be a 5-tuple of Strings. The previous 4-tuple widened every field
        // to Any and threw MatchError when assertions were elided (-Xdisable-assertions).
        ("0", "dummy", "dummy", "dummy", "dummy")
      }
    }
    (offset, iidx, itypes, starttime, endtime, inactive)
  }
// users: the byte offset of each userIds.tsv line serves as that user's index ('uindex)
val usersIndex = TextLine(
  DataFile(hdfsRootArg, appidArg, engineidArg, algoidArg, evalidArg, "userIds.tsv")
).read.rename(('offset, 'line) -> ('uindex, 'uidx))

/**
 * sinks
 */
val itemsIndexSink = Tsv(
  DataFile(hdfsRootArg, appidArg, engineidArg, algoidArg, evalidArg, "itemsIndex.tsv"))
val usersIndexSink = Tsv(
  DataFile(hdfsRootArg, appidArg, engineidArg, algoidArg, evalidArg, "usersIndex.tsv"))
// receives (uindex, iindex, rating) rows as CSV; presumably the input format the Mahout step expects
val ratingsSink = Csv(
  DataFile(hdfsRootArg, appidArg, engineidArg, algoidArg, evalidArg, "ratings.csv"))
* computation
// Build the ratings file: keep actions on selected items, map each one to an Int rating,
// resolve duplicate (uid, iid) pairs, then replace ids with byte-offset indices.
// NOTE(review): several lines of this chain (closing braces, the filter's result value,
// the body of the ACTION_RATE try and the getOrElse fallbacks) are not visible in this view.
// filter and pre-process actions
u2i.joinWithSmaller('iid -> 'iidx, itemsIndex) // only select actions of these items
// Always keep rate; keep like/dislike/view/conversion only when their param is not "ignore" (None).
.filter('action, 'v) { fields: (String, Option[String]) =>
val (action, v) = fields
val keepThis: Boolean = action match {
case ACTION_RATE => true
case ACTION_LIKE => (likeParamArg != None)
case ACTION_DISLIKE => (dislikeParamArg != None)
case ACTION_VIEW => (viewParamArg != None)
case ACTION_CONVERSION => (conversionParamArg != None)
case _ => {
// unknown actions are logged at debug level and dropped
logger.debug(s"Found custom action ${action}")
false // all other unsupported actions
// Produce ('rating, 'tLong): the Int rating and the timestamp t parsed as Long.
.map(('action, 'v, 't) -> ('rating, 'tLong)) { fields: (String, Option[String], String) =>
val (action, v, t) = fields
// convert actions into rating value based on "action" and "v" fields
val rating: Int = action match {
// rate: the rating presumably comes from parsing v as an Int (per the catch message below) — confirm
case ACTION_RATE => try {
} catch {
case e: Exception => {
assert(false, s"Failed to convert v field ${v} to integer for ${action} action. Exception:" + e)
// like/dislike/view/conversion: fixed rating taken from the matching --*Param argument;
// a None here means the filter above was bypassed, which is treated as a bug
case ACTION_LIKE => likeParamArg.getOrElse {
assert(false, "Action type " + action + " should have been filtered out!")
case ACTION_DISLIKE => dislikeParamArg.getOrElse {
assert(false, "Action type " + action + " should have been filtered out!")
case ACTION_VIEW => viewParamArg.getOrElse {
assert(false, "Action type " + action + " should have been filtered out!")
case ACTION_CONVERSION => conversionParamArg.getOrElse {
assert(false, "Action type " + action + " should have been filtered out!")
case _ => { // all other unsupported actions
assert(false, "Action type " + action + " in u2iActions appdata is not supported!")
(rating, t.toLong)
// Keep a single action per (uid, iid) pair according to --conflictParam.
.then(resolveConflict('uid, 'iid, 'tLong, 'rating, conflictParamArg) _)
// Attach the user index and emit (uindex, iindex, rating) to ratings.csv.
.joinWithSmaller('uid -> 'uidx, usersIndex)
.project('uindex, 'iindex, 'rating)
.write(ratingsSink) // write ratings to a file
* function to resolve conflicting actions of same uid-iid pair.
def resolveConflict(uidField: Symbol, iidField: Symbol, tfield: Symbol, ratingField: Symbol, conflictSolution: String)(p: RichPipe): RichPipe = {
// NOTE: sortBy() sort from smallest to largest. use reverse to pick the largest one.
val dataPipe = conflictSolution match {
case CONFLICT_LATEST => p.groupBy(uidField, iidField) { _.sortBy(tfield).reverse.take(1) } // take latest one (largest t)
case CONFLICT_HIGHEST => p.groupBy(uidField, iidField) { _.sortBy(ratingField).reverse.take(1) } // take highest rating
case CONFLICT_LOWEST => p.groupBy(uidField, iidField) { _.sortBy(ratingField).take(1) } // take lowest rating