blob: 7c0ddb9da164911e6079ee25e41fcee967796234 [file] [log] [blame]
package org.template.similarproduct
import io.prediction.controller.P2LAlgorithm
import io.prediction.controller.Params
import io.prediction.data.storage.BiMap
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.{Rating => MLlibRating}
import grizzled.slf4j.Logger
import scala.collection.mutable.PriorityQueue
case class ALSAlgorithmParams(
rank: Int,
numIterations: Int,
lambda: Double,
seed: Option[Long]) extends Params
class ALSModel(
val productFeatures: Map[Int, Array[Double]],
val itemStringIntMap: BiMap[String, Int],
val items: Map[Int, Item]
) extends Serializable {
@transient lazy val itemIntStringMap = itemStringIntMap.inverse
override def toString = {
s" productFeatures: [${productFeatures.size}]" +
s"(${productFeatures.take(2).toList}...)" +
s" itemStringIntMap: [${itemStringIntMap.size}]" +
s"(${itemStringIntMap.take(2).toString}...)]" +
s" items: [${items.size}]" +
s"(${items.take(2).toString}...)]"
}
}
/**
* Use ALS to build item x feature matrix
*/
class ALSAlgorithm(val ap: ALSAlgorithmParams)
extends P2LAlgorithm[PreparedData, ALSModel, Query, PredictedResult] {
@transient lazy val logger = Logger[this.type]
def train(sc: SparkContext, data: PreparedData): ALSModel = {
require(!data.viewEvents.take(1).isEmpty,
s"viewEvents in PreparedData cannot be empty." +
" Please check if DataSource generates TrainingData" +
" and Preprator generates PreparedData correctly.")
require(!data.users.take(1).isEmpty,
s"users in PreparedData cannot be empty." +
" Please check if DataSource generates TrainingData" +
" and Preprator generates PreparedData correctly.")
require(!data.items.take(1).isEmpty,
s"items in PreparedData cannot be empty." +
" Please check if DataSource generates TrainingData" +
" and Preprator generates PreparedData correctly.")
// create User and item's String ID to integer index BiMap
val userStringIntMap = BiMap.stringInt(data.users.keys)
val itemStringIntMap = BiMap.stringInt(data.items.keys)
// collect Item as Map and convert ID to Int index
val items: Map[Int, Item] = data.items.map { case (id, item) =>
(itemStringIntMap(id), item)
}.collectAsMap.toMap
val mllibRatings = data.viewEvents
.map { r =>
// Convert user and item String IDs to Int index for MLlib
val uindex = userStringIntMap.getOrElse(r.user, -1)
val iindex = itemStringIntMap.getOrElse(r.item, -1)
if (uindex == -1)
logger.info(s"Couldn't convert nonexistent user ID ${r.user}"
+ " to Int index.")
if (iindex == -1)
logger.info(s"Couldn't convert nonexistent item ID ${r.item}"
+ " to Int index.")
((uindex, iindex), 1)
}.filter { case ((u, i), v) =>
// keep events with valid user and item index
(u != -1) && (i != -1)
}.reduceByKey(_ + _) // aggregate all view events of same user-item pair
.map { case ((u, i), v) =>
// MLlibRating requires integer index for user and item
MLlibRating(u, i, v)
}
.cache()
// MLLib ALS cannot handle empty training data.
require(!mllibRatings.take(1).isEmpty,
s"mllibRatings cannot be empty." +
" Please check if your events contain valid user and item ID.")
// seed for MLlib ALS
val seed = ap.seed.getOrElse(System.nanoTime)
val m = ALS.trainImplicit(
ratings = mllibRatings,
rank = ap.rank,
iterations = ap.numIterations,
lambda = ap.lambda,
blocks = -1,
alpha = 1.0,
seed = seed)
new ALSModel(
productFeatures = m.productFeatures.collectAsMap.toMap,
itemStringIntMap = itemStringIntMap,
items = items
)
}
def predict(model: ALSModel, query: Query): PredictedResult = {
val productFeatures = model.productFeatures
// convert items to Int index
val queryList: Set[Int] = query.items.map(model.itemStringIntMap.get(_))
.flatten.toSet
val queryFeatures: Vector[Array[Double]] = queryList.toVector
// productFeatures may not contain the requested item
.map { item => productFeatures.get(item) }
.flatten
val whiteList: Option[Set[Int]] = query.whiteList.map( set =>
set.map(model.itemStringIntMap.get(_)).flatten
)
val blackList: Option[Set[Int]] = query.blackList.map ( set =>
set.map(model.itemStringIntMap.get(_)).flatten
)
val ord = Ordering.by[(Int, Double), Double](_._2).reverse
val indexScores: Array[(Int, Double)] = if (queryFeatures.isEmpty) {
logger.info(s"No productFeatures vector for query items ${query.items}.")
Array[(Int, Double)]()
} else {
productFeatures.par // convert to parallel collection
.mapValues { f =>
queryFeatures.map{ qf =>
cosine(qf, f)
}.reduce(_ + _)
}
.filter(_._2 > 0) // keep items with score > 0
.seq // convert back to sequential collection
.toArray
}
val filteredScore = indexScores.view.filter { case (i, v) =>
isCandidateItem(
i = i,
items = model.items,
categories = query.categories,
queryList = queryList,
whiteList = whiteList,
blackList = blackList
)
}
val topScores = getTopN(filteredScore, query.num)(ord).toArray
val itemScores = topScores.map { case (i, s) =>
new ItemScore(
item = model.itemIntStringMap(i),
score = s
)
}
new PredictedResult(itemScores)
}
private
def getTopN[T](s: Seq[T], n: Int)(implicit ord: Ordering[T]): Seq[T] = {
val q = PriorityQueue()
for (x <- s) {
if (q.size < n)
q.enqueue(x)
else {
// q is full
if (ord.compare(x, q.head) < 0) {
q.dequeue()
q.enqueue(x)
}
}
}
q.dequeueAll.toSeq.reverse
}
private
def cosine(v1: Array[Double], v2: Array[Double]): Double = {
val size = v1.size
var i = 0
var n1: Double = 0
var n2: Double = 0
var d: Double = 0
while (i < size) {
n1 += v1(i) * v1(i)
n2 += v2(i) * v2(i)
d += v1(i) * v2(i)
i += 1
}
val n1n2 = (math.sqrt(n1) * math.sqrt(n2))
if (n1n2 == 0) 0 else (d / n1n2)
}
private
def isCandidateItem(
i: Int,
items: Map[Int, Item],
categories: Option[Set[String]],
queryList: Set[Int],
whiteList: Option[Set[Int]],
blackList: Option[Set[Int]]
): Boolean = {
whiteList.map(_.contains(i)).getOrElse(true) &&
blackList.map(!_.contains(i)).getOrElse(true) &&
// discard items in query as well
(!queryList.contains(i)) &&
// filter categories
categories.map { cat =>
items(i).categories.map { itemCat =>
// keep this item if has ovelap categories with the query
!(itemCat.toSet.intersect(cat).isEmpty)
}.getOrElse(false) // discard this item if it has no categories
}.getOrElse(true)
}
}