package org.example.recommendation
import org.apache.predictionio.controller.PAlgorithm
import org.apache.predictionio.controller.Params
import org.apache.predictionio.data.storage.BiMap
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.{Rating => MLlibRating}
import org.apache.spark.mllib.recommendation.ALSModel
import grizzled.slf4j.Logger
case class ALSAlgorithmParams(
rank: Int,
numIterations: Int,
lambda: Double,
seed: Option[Long]) extends Params
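// These parameters are typically supplied through engine.json. A minimal,
// illustrative algorithm entry (the values below are examples, and "als" must
// match the key registered for ALSAlgorithm in the engine factory):
//
//   "algorithms": [
//     {
//       "name": "als",
//       "params": {
//         "rank": 10,
//         "numIterations": 20,
//         "lambda": 0.01,
//         "seed": 3
//       }
//     }
//   ]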
class ALSAlgorithm(val ap: ALSAlgorithmParams)
extends PAlgorithm[PreparedData, ALSModel, Query, PredictedResult] {
@transient lazy val logger = Logger[this.type]
if (ap.numIterations > 30) {
logger.warn(
s"ALSAlgorithmParams.numIterations > 30, current: ${ap.numIterations}. " +
s"There is a chance of running to StackOverflowException." +
s"To remedy it, set lower numIterations or checkpoint parameters.")
}
override
def train(sc: SparkContext, data: PreparedData): ALSModel = {
    // MLlib ALS cannot handle empty training data.
require(!data.ratings.take(1).isEmpty,
s"RDD[Rating] in PreparedData cannot be empty." +
" Please check if DataSource generates TrainingData" +
" and Preparator generates PreparedData correctly.")
// Convert user and item String IDs to Int index for MLlib
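    // BiMap.stringInt assigns each distinct String ID a consecutive Int index
    // and keeps the mapping in both directions, so it can be inverted back to
    // String IDs at prediction time.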
val userStringIntMap = BiMap.stringInt(data.ratings.map(_.user))
val itemStringIntMap = BiMap.stringInt(data.ratings.map(_.item))
val mllibRatings = data.ratings.map( r =>
// MLlibRating requires integer index for user and item
MLlibRating(userStringIntMap(r.user), itemStringIntMap(r.item), r.rating)
)
// seed for MLlib ALS
val seed = ap.seed.getOrElse(System.nanoTime)
    // Set a checkpoint directory to enable checkpointing (this truncates the
    // long RDD lineage that can cause the StackOverflowError warned about above):
    // sc.setCheckpointDir("checkpoint")
    // If you have only one type of implicit event (e.g. "view" events only),
    // set implicitPrefs to true.
val implicitPrefs = false
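    // When implicitPrefs is true, MLlib ALS treats the rating values as
    // implicit-feedback confidence levels rather than explicit ratings, and
    // the alpha parameter set below controls the confidence scaling.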
val als = new ALS()
als.setUserBlocks(-1)
als.setProductBlocks(-1)
als.setRank(ap.rank)
als.setIterations(ap.numIterations)
als.setLambda(ap.lambda)
als.setImplicitPrefs(implicitPrefs)
als.setAlpha(1.0)
als.setSeed(seed)
als.setCheckpointInterval(10)
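    // Note: the checkpoint interval only takes effect when a checkpoint
    // directory has been set on the SparkContext (see setCheckpointDir above);
    // otherwise intermediate RDDs are not checkpointed.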
val m = als.run(mllibRatings)
new ALSModel(
rank = m.rank,
userFeatures = m.userFeatures,
productFeatures = m.productFeatures,
userStringIntMap = userStringIntMap,
itemStringIntMap = itemStringIntMap)
}
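  // Serves a single query. For example, a query like {"user": "u1", "num": 4}
  // (illustrative values) asks for the top 4 recommended items for user "u1";
  // the result is a PredictedResult holding up to `num` ItemScore entries.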
override
def predict(model: ALSModel, query: Query): PredictedResult = {
    // Convert String ID to Int index for MLlib
model.userStringIntMap.get(query.user).map { userInt =>
// create inverse view of itemStringIntMap
val itemIntStringMap = model.itemStringIntMap.inverse
// recommendProducts() returns Array[MLlibRating], which uses item Int
// index. Convert it to String ID for returning PredictedResult
val itemScores = model.recommendProducts(userInt, query.num)
.map (r => ItemScore(itemIntStringMap(r.product), r.rating))
PredictedResult(itemScores)
}.getOrElse{
logger.info(s"No prediction for unknown user ${query.user}.")
PredictedResult(Array.empty)
}
}
  // This function is used by the evaluation module, where a batch of queries
  // is sent to this engine for evaluation purposes.
override def batchPredict(model: ALSModel, queries: RDD[(Long, Query)]): RDD[(Long, PredictedResult)] = {
val userIxQueries: RDD[(Int, (Long, Query))] = queries
.map { case (ix, query) => {
        // If the user is not found, use -1 as the index
val userIx = model.userStringIntMap.get(query.user).getOrElse(-1)
(userIx, (ix, query))
}}
// Cross product of all valid users from the queries and products in the model.
val usersProducts: RDD[(Int, Int)] = userIxQueries
.keys
.filter(_ != -1)
.cartesian(model.productFeatures.map(_._1))
    // Call MLlib ALS's predict function.
val ratings: RDD[MLlibRating] = model.predict(usersProducts)
    // The following code constructs predicted results from MLlib's ratings.
    // This is not an optimal implementation: instead of groupBy, combineByKey
    // with a bounded priority queue should be used (see the sketch below).
val userRatings: RDD[(Int, Iterable[MLlibRating])] = ratings.groupBy(_.user)
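    // A sketch of the bounded alternative mentioned above, using aggregateByKey
    // with a sorted list instead of combineByKey with a PriorityQueue (not used
    // here; `maxNum` is a hypothetical upper bound on query.num in the batch):
    //
    //   val topPerUser: RDD[(Int, List[MLlibRating])] = ratings
    //     .map(r => (r.user, r))
    //     .aggregateByKey(List.empty[MLlibRating])(
    //       (acc, r) => (r :: acc).sortBy(-_.rating).take(maxNum),
    //       (a, b) => (a ++ b).sortBy(-_.rating).take(maxNum))
    //
    // This keeps at most maxNum ratings per user per partition instead of
    // materializing all of a user's ratings as groupBy does.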
userIxQueries.leftOuterJoin(userRatings)
.map {
// When there are ratings
case (userIx, ((ix, query), Some(ratings))) => {
val topItemScores: Array[ItemScore] = ratings
.toArray
.sortBy(_.rating)(Ordering.Double.reverse) // note: from large to small ordering
.take(query.num)
.map { rating => ItemScore(
model.itemStringIntMap.inverse(rating.product),
rating.rating) }
(ix, PredictedResult(itemScores = topItemScores))
}
// When user doesn't exist in training data
case (userIx, ((ix, query), None)) => {
require(userIx == -1)
(ix, PredictedResult(itemScores = Array.empty))
}
}
}
}