// blob: 0a60f8bc62aed35d1216ba0942b8af7faf94a007 [file] [log] [blame]
package org.template.recommendation
import io.prediction.controller.P2LAlgorithm
import io.prediction.controller.Params
import io.prediction.data.storage.BiMap
import io.prediction.data.storage.Event
import io.prediction.data.storage.Storage
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.{Rating => MLlibRating}
import grizzled.slf4j.Logger
import scala.collection.mutable.PriorityQueue
import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global
/**
 * Tuning parameters for ALSAlgorithm.
 *
 * @param rank          passed to `ALS.trainImplicit` as `rank`
 * @param numIterations passed to `ALS.trainImplicit` as `iterations`
 * @param lambda        passed to `ALS.trainImplicit` as `lambda`
 * @param seed          optional seed for ALS; `ALSAlgorithm.train` falls back to
 *                      `System.nanoTime` when this is `None`
 */
case class ALSAlgorithmParams(
rank: Int,
numIterations: Int,
lambda: Double,
seed: Option[Long]
) extends Params
/**
 * Result of ALS training, held entirely in local (non-RDD) collections.
 *
 * @param rank             rank reported by the trained MLlib model
 * @param userFeatures     user int index -> user feature vector
 * @param productFeatures  item int index -> (item, feature vector); the vector is `None`
 *                         when the join with the trained product features found no match
 * @param userStringIntMap user string ID <-> int index
 * @param itemStringIntMap item string ID <-> int index
 */
class ALSModel(
  val rank: Int,
  val userFeatures: Map[Int, Array[Double]],
  val productFeatures: Map[Int, (Item, Option[Array[Double]])],
  val userStringIntMap: BiMap[String, Int],
  val itemStringIntMap: BiMap[String, Int]
) extends Serializable {

  /** Inverse of `itemStringIntMap`; transient, so it is rebuilt lazily on first access. */
  @transient lazy val itemIntStringMap = itemStringIntMap.inverse

  /** Short summary: the rank plus the size and the first two entries of every map. */
  override def toString: String = {
    val userFeatureSample = userFeatures.take(2).toList
    val productFeatureSample = productFeatures.take(2).toList
    val userIdSample = userStringIntMap.take(2).toString
    val itemIdSample = itemStringIntMap.take(2).toString

    Seq(
      s" rank: ${rank}",
      s" userFeatures: [${userFeatures.size}]",
      s"(${userFeatureSample}...)",
      s" productFeatures: [${productFeatures.size}]",
      s"(${productFeatureSample}...)",
      s" userStringIntMap: [${userStringIntMap.size}]",
      s"(${userIdSample}...)]",
      s" itemStringIntMap: [${itemStringIntMap.size}]",
      s"(${itemIdSample}...)]"
    ).mkString
  }
}
/**
* Use ALS to build item x feature matrix
*/
class ALSAlgorithm(val ap: ALSAlgorithmParams)
extends P2LAlgorithm[PreparedData, ALSModel, Query, PredictedResult] {
@transient lazy val logger = Logger[this.type]
/**
 * Trains an implicit-feedback ALS model from the view events in `data`.
 *
 * Steps:
 *  1. verify that view events, users and items are all non-empty;
 *  2. translate user/item string IDs to int indices; events that reference an
 *     unknown user or item are logged and dropped;
 *  3. count the views per (user, item) pair and use that count as the rating;
 *  4. run `ALS.trainImplicit` and collect the resulting factors into local maps.
 *
 * @param data prepared users, items and view events
 * @return the trained model together with the ID <-> index mappings
 * @throws IllegalArgumentException if any input, or the set of usable ratings, is empty
 */
def train(data: PreparedData): ALSModel = {
  require(!data.viewEvents.take(1).isEmpty,
    s"viewEvents in PreparedData cannot be empty." +
    " Please check if DataSource generates TrainingData" +
    " and Preparator generates PreparedData correctly.")
  require(!data.users.take(1).isEmpty,
    s"users in PreparedData cannot be empty." +
    " Please check if DataSource generates TrainingData" +
    " and Preparator generates PreparedData correctly.")
  require(!data.items.take(1).isEmpty,
    s"items in PreparedData cannot be empty." +
    " Please check if DataSource generates TrainingData" +
    " and Preparator generates PreparedData correctly.")

  // create User and item's String ID to integer index BiMap
  val userStringIntMap = BiMap.stringInt(data.users.keys)
  val itemStringIntMap = BiMap.stringInt(data.items.keys)

  val mllibRatings = data.viewEvents
    .map { r =>
      // Convert user and item String IDs to Int index for MLlib
      val uindex = userStringIntMap.getOrElse(r.user, -1)
      val iindex = itemStringIntMap.getOrElse(r.item, -1)

      if (uindex == -1)
        logger.info(s"Couldn't convert nonexistent user ID ${r.user}"
          + " to Int index.")

      if (iindex == -1)
        logger.info(s"Couldn't convert nonexistent item ID ${r.item}"
          + " to Int index.")

      ((uindex, iindex), 1)
    }
    .filter { case ((u, i), _) =>
      // keep events with valid user and item index
      (u != -1) && (i != -1)
    }
    .reduceByKey(_ + _) // aggregate all view events of same user-item pair
    .map { case ((u, i), v) =>
      // MLlibRating requires integer index for user and item
      MLlibRating(u, i, v)
    }
    .cache()

  try {
    // MLLib ALS cannot handle empty training data.
    require(!mllibRatings.take(1).isEmpty,
      s"mllibRatings cannot be empty." +
      " Please check if your events contain valid user and item ID.")

    // seed for MLlib ALS
    val seed = ap.seed.getOrElse(System.nanoTime)

    val m = ALS.trainImplicit(
      ratings = mllibRatings,
      rank = ap.rank,
      iterations = ap.numIterations,
      lambda = ap.lambda,
      blocks = -1,
      alpha = 1.0,
      seed = seed)

    val userFeatures = m.userFeatures.collectAsMap.toMap

    // convert ID to Int index
    val items = data.items.map { case (id, item) =>
      (itemStringIntMap(id), item)
    }

    // join item with the trained productFeatures; items without trained
    // features are kept with None as their feature vector
    val productFeatures = items.leftOuterJoin(m.productFeatures)
      .collectAsMap.toMap

    new ALSModel(
      rank = m.rank,
      userFeatures = userFeatures,
      productFeatures = productFeatures,
      userStringIntMap = userStringIntMap,
      itemStringIntMap = itemStringIntMap
    )
  } finally {
    // Everything needed has been collected to the driver by now (or training failed),
    // so release the cached ratings instead of leaving them pinned in executor storage.
    mllibRatings.unpersist()
  }
}
/**
 * Scores the candidate items for the user named in `query` and returns the
 * `query.num` best ones, highest score first.
 *
 * An item is scored only if it has a trained feature vector and passes
 * `isCandidateItem`; items whose score is not strictly positive are discarded.
 * When the model holds no feature vector for the user, the result is empty.
 *
 * @param model trained model
 * @param query user, number of items wanted and optional category/white/black lists
 * @return the selected items with their scores
 */
def predict(model: ALSModel, query: Query): PredictedResult = {
  // White list translated from item string IDs to int indices; unknown IDs are dropped.
  val whiteList: Option[Set[Int]] = query.whiteList.map { ids =>
    ids.flatMap(id => model.itemStringIntMap.get(id))
  }

  // Black list translated the same way; a missing list means nothing is excluded.
  val finalBlackList: Set[Int] = query.blackList
    .getOrElse(Set[String]())
    .flatMap(id => model.itemStringIntMap.get(id))

  // Feature vector of the querying user, if the user is known and has one.
  val userFeature = for {
    userIndex <- model.userStringIntMap.get(query.user)
    feature <- model.userFeatures.get(userIndex)
  } yield feature

  val topScores: Array[(Int, Double)] = userFeature match {
    case Some(uf) =>
      val indexScores: Map[Int, Double] =
        model.productFeatures.par // score the items in parallel
          .collect {
            case (i, (item, Some(itemFeature)))
              if isCandidateItem(
                i = i,
                item = item,
                categories = query.categories,
                whiteList = whiteList,
                blackList = finalBlackList
              ) =>
              // Can adjust score here
              (i, dotProduct(uf, itemFeature))
          }
          .filter(_._2 > 0) // only keep items with score > 0
          .seq // back to a sequential collection

      // Reversed so that the highest score is the "smallest" element for getTopN.
      val ord = Ordering.by[(Int, Double), Double](_._2).reverse
      getTopN(indexScores, query.num)(ord).toArray

    case None =>
      // For example, the user was created after the model was trained.
      logger.info(s"No userFeature found for user ${query.user}.")
      Array[(Int, Double)]()
  }

  val itemScores = topScores.map { case (itemIndex, itemScore) =>
    new ItemScore(
      // convert item int index back to string ID
      item = model.itemIntStringMap(itemIndex),
      score = itemScore
    )
  }

  new PredictedResult(itemScores)
}
/**
 * Returns the `n` smallest elements of `s` according to `ord`, smallest first.
 *
 * A priority queue bounded to `n` elements keeps the best candidates seen so far.
 * Its head is the greatest (i.e. worst) kept element, so a new element only enters
 * the queue when it is strictly smaller than the head.
 *
 * @param s   elements to select from
 * @param n   maximum number of elements to return; zero or negative yields an empty result
 * @param ord ordering that defines "smallest"
 * @return at most `n` elements in ascending `ord` order
 */
private def getTopN[T](s: Iterable[T], n: Int)(implicit ord: Ordering[T]): Seq[T] = {
  if (n <= 0) {
    // Nothing was requested. Without this guard the loop below would call
    // q.head on an empty queue and throw NoSuchElementException.
    Seq.empty[T]
  } else {
    val q = PriorityQueue.empty[T]

    for (x <- s) {
      if (q.size < n) {
        q.enqueue(x)
      } else if (ord.compare(x, q.head) < 0) {
        // q is full and x beats the worst kept element: swap them
        q.dequeue()
        q.enqueue(x)
      }
    }

    // dequeueAll yields greatest-first; reverse to get smallest-first
    q.dequeueAll.toSeq.reverse
  }
}
/**
 * Inner product of `v1` and `v2`, taken over the indices of `v1`.
 *
 * Elements of `v2` beyond the length of `v1` are ignored; a `v2` shorter than
 * `v1` results in an out-of-bounds array access.
 *
 * @param v1 first vector; its length determines how many terms are summed
 * @param v2 second vector
 * @return sum of `v1(k) * v2(k)` for every index `k` of `v1`
 */
private def dotProduct(v1: Array[Double], v2: Array[Double]): Double = {
  val length = v1.size

  // Tail recursion compiles down to a plain loop over primitives, and the terms are
  // added in index order starting from zero.
  @scala.annotation.tailrec
  def accumulate(index: Int, sum: Double): Double =
    if (index < length) accumulate(index + 1, sum + v1(index) * v2(index))
    else sum

  accumulate(0, 0.0)
}
/**
 * Decides whether the item with index `i` may be recommended for the current query.
 *
 * The item is accepted when all of the following hold:
 *  - there is no white list, or the white list contains `i`;
 *  - the black list does not contain `i`;
 *  - there is no category filter, or the item has categories and at least one of
 *    them is among the requested ones.
 *
 * @param i          int index of the item
 * @param item       the item itself (only its categories are inspected)
 * @param categories optional category filter from the query
 * @param whiteList  optional set of allowed item indices
 * @param blackList  set of excluded item indices
 * @return `true` if the item passes every filter
 */
private def isCandidateItem(
  i: Int,
  item: Item,
  categories: Option[Set[String]],
  whiteList: Option[Set[Int]],
  blackList: Set[Int]
): Boolean = {
  // An absent white list admits every item.
  def whiteListed: Boolean = whiteList.forall(_.contains(i))

  def blackListed: Boolean = blackList.contains(i)

  // An absent category filter admits every item; with a filter present, an item
  // that has no categories at all is discarded.
  def categoryMatches: Boolean = categories.forall { wanted =>
    item.categories.exists { owned =>
      // keep this item if it shares at least one category with the query
      owned.toSet.intersect(wanted).nonEmpty
    }
  }

  // can add other custom filtering here
  whiteListed && !blackListed && categoryMatches
}
}