Add ECommAlgorithm to return popular item as default
diff --git a/build.sbt b/build.sbt
index 6de8bac..61c2da9 100644
--- a/build.sbt
+++ b/build.sbt
@@ -9,4 +9,5 @@
libraryDependencies ++= Seq(
"io.prediction" %% "core" % pioVersion.value % "provided",
"org.apache.spark" %% "spark-core" % "1.3.0" % "provided",
- "org.apache.spark" %% "spark-mllib" % "1.3.0" % "provided")
+ "org.apache.spark" %% "spark-mllib" % "1.3.0" % "provided",
+ "org.scalatest" %% "scalatest" % "2.2.1" % "test")
diff --git a/engine.json b/engine.json
index 1da14e1..b2cae07 100644
--- a/engine.json
+++ b/engine.json
@@ -9,7 +9,7 @@
},
"algorithms": [
{
- "name": "als",
+ "name": "ecomm",
"params": {
"appName": "INVALID_APP_NAME",
"unseenOnly": true,
diff --git a/src/main/scala/DataSource.scala b/src/main/scala/DataSource.scala
index 783b2bb..bb83a3d 100644
--- a/src/main/scala/DataSource.scala
+++ b/src/main/scala/DataSource.scala
@@ -59,37 +59,53 @@
(entityId, item)
}.cache()
- // get all "user" "view" "item" events
- val viewEventsRDD: RDD[ViewEvent] = PEventStore.find(
+ val eventsRDD: RDD[Event] = PEventStore.find(
appName = dsp.appName,
entityType = Some("user"),
- eventNames = Some(List("view")),
+ eventNames = Some(List("view", "buy")),
// targetEntityType is optional field of an event.
targetEntityType = Some(Some("item")))(sc)
- // eventsDb.find() returns RDD[Event]
+ .cache()
+
+ val viewEventsRDD: RDD[ViewEvent] = eventsRDD
+ .filter { event => event.event == "view" }
.map { event =>
- val viewEvent = try {
- event.event match {
- case "view" => ViewEvent(
- user = event.entityId,
- item = event.targetEntityId.get,
- t = event.eventTime.getMillis)
- case _ => throw new Exception(s"Unexpected event ${event} is read.")
- }
+ try {
+ ViewEvent(
+ user = event.entityId,
+ item = event.targetEntityId.get,
+ t = event.eventTime.getMillis
+ )
} catch {
- case e: Exception => {
+ case e: Exception =>
logger.error(s"Cannot convert ${event} to ViewEvent." +
s" Exception: ${e}.")
throw e
- }
}
- viewEvent
- }.cache()
+ }
+
+ val buyEventsRDD: RDD[BuyEvent] = eventsRDD
+ .filter { event => event.event == "buy" }
+ .map { event =>
+ try {
+ BuyEvent(
+ user = event.entityId,
+ item = event.targetEntityId.get,
+ t = event.eventTime.getMillis
+ )
+ } catch {
+ case e: Exception =>
+ logger.error(s"Cannot convert ${event} to BuyEvent." +
+ s" Exception: ${e}.")
+ throw e
+ }
+ }
new TrainingData(
users = usersRDD,
items = itemsRDD,
- viewEvents = viewEventsRDD
+ viewEvents = viewEventsRDD,
+ buyEvents = buyEventsRDD
)
}
}
@@ -100,14 +116,18 @@
case class ViewEvent(user: String, item: String, t: Long)
+case class BuyEvent(user: String, item: String, t: Long)
+
class TrainingData(
val users: RDD[(String, User)],
val items: RDD[(String, Item)],
- val viewEvents: RDD[ViewEvent]
+ val viewEvents: RDD[ViewEvent],
+ val buyEvents: RDD[BuyEvent]
) extends Serializable {
override def toString = {
s"users: [${users.count()} (${users.take(2).toList}...)]" +
s"items: [${items.count()} (${items.take(2).toList}...)]" +
- s"viewEvents: [${viewEvents.count()}] (${viewEvents.take(2).toList}...)"
+ s"viewEvents: [${viewEvents.count()}] (${viewEvents.take(2).toList}...)" +
+ s"buyEvents: [${buyEvents.count()}] (${buyEvents.take(2).toList}...)"
}
}
diff --git a/src/main/scala/ECommAlgorithm.scala b/src/main/scala/ECommAlgorithm.scala
new file mode 100644
index 0000000..063fa99
--- /dev/null
+++ b/src/main/scala/ECommAlgorithm.scala
@@ -0,0 +1,502 @@
+package org.template.ecommercerecommendation
+
+import io.prediction.controller.P2LAlgorithm
+import io.prediction.controller.Params
+import io.prediction.data.storage.BiMap
+import io.prediction.data.storage.Event
+import io.prediction.data.store.LEventStore
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.recommendation.ALS
+import org.apache.spark.mllib.recommendation.{Rating => MLlibRating}
+import org.apache.spark.rdd.RDD
+
+import grizzled.slf4j.Logger
+
+import scala.collection.mutable.PriorityQueue
+import scala.concurrent.duration.Duration
+import scala.concurrent.ExecutionContext.Implicits.global
+
+case class ECommAlgorithmParams(
+ appName: String,
+ unseenOnly: Boolean,
+ seenEvents: List[String],
+ rank: Int,
+ numIterations: Int,
+ lambda: Double,
+ seed: Option[Long]
+) extends Params
+
+
+case class ProductModel(
+ item: Item,
+ features: Option[Array[Double]], // features by ALS
+ count: Int // popular count
+)
+
+class ECommModel(
+ val rank: Int,
+ val userFeatures: Map[Int, Array[Double]],
+ val productModels: Map[Int, ProductModel],
+ val userStringIntMap: BiMap[String, Int],
+ val itemStringIntMap: BiMap[String, Int]
+) extends Serializable {
+
+ @transient lazy val itemIntStringMap = itemStringIntMap.inverse
+
+ override def toString = {
+ s" rank: ${rank}" +
+ s" userFeatures: [${userFeatures.size}]" +
+ s"(${userFeatures.take(2).toList}...)" +
+ s" productModels: [${productModels.size}]" +
+ s"(${productModels.take(2).toList}...)" +
+ s" userStringIntMap: [${userStringIntMap.size}]" +
+ s"(${userStringIntMap.take(2).toString}...)]" +
+ s" itemStringIntMap: [${itemStringIntMap.size}]" +
+ s"(${itemStringIntMap.take(2).toString}...)]"
+ }
+}
+
+class ECommAlgorithm(val ap: ECommAlgorithmParams)
+ extends P2LAlgorithm[PreparedData, ECommModel, Query, PredictedResult] {
+
+ @transient lazy val logger = Logger[this.type]
+
+ def train(sc: SparkContext, data: PreparedData): ECommModel = {
+ require(!data.viewEvents.take(1).isEmpty,
+ s"viewEvents in PreparedData cannot be empty." +
+ " Please check if DataSource generates TrainingData" +
+      " and Preparator generates PreparedData correctly.")
+ require(!data.users.take(1).isEmpty,
+ s"users in PreparedData cannot be empty." +
+ " Please check if DataSource generates TrainingData" +
+      " and Preparator generates PreparedData correctly.")
+ require(!data.items.take(1).isEmpty,
+ s"items in PreparedData cannot be empty." +
+ " Please check if DataSource generates TrainingData" +
+      " and Preparator generates PreparedData correctly.")
+ // create User and item's String ID to integer index BiMap
+ val userStringIntMap = BiMap.stringInt(data.users.keys)
+ val itemStringIntMap = BiMap.stringInt(data.items.keys)
+
+ val mllibRatings = data.viewEvents
+ .map { r =>
+ // Convert user and item String IDs to Int index for MLlib
+ val uindex = userStringIntMap.getOrElse(r.user, -1)
+ val iindex = itemStringIntMap.getOrElse(r.item, -1)
+
+ if (uindex == -1)
+ logger.info(s"Couldn't convert nonexistent user ID ${r.user}"
+ + " to Int index.")
+
+ if (iindex == -1)
+ logger.info(s"Couldn't convert nonexistent item ID ${r.item}"
+ + " to Int index.")
+
+ ((uindex, iindex), 1)
+ }.filter { case ((u, i), v) =>
+ // keep events with valid user and item index
+ (u != -1) && (i != -1)
+ }
+ .reduceByKey(_ + _) // aggregate all view events of same user-item pair
+ .map { case ((u, i), v) =>
+ // MLlibRating requires integer index for user and item
+ MLlibRating(u, i, v)
+ }.cache()
+
+ // MLLib ALS cannot handle empty training data.
+ require(!mllibRatings.take(1).isEmpty,
+ s"mllibRatings cannot be empty." +
+ " Please check if your events contain valid user and item ID.")
+
+ // seed for MLlib ALS
+ val seed = ap.seed.getOrElse(System.nanoTime)
+
+ // use ALS to train feature vectors
+ val m = ALS.trainImplicit(
+ ratings = mllibRatings,
+ rank = ap.rank,
+ iterations = ap.numIterations,
+ lambda = ap.lambda,
+ blocks = -1,
+ alpha = 1.0,
+ seed = seed)
+
+ val userFeatures = m.userFeatures.collectAsMap.toMap
+
+ // convert ID to Int index
+ val items = data.items.map { case (id, item) =>
+ (itemStringIntMap(id), item)
+ }
+
+ // join item with the trained productFeatures
+ val productFeatures: Map[Int, (Item, Option[Array[Double]])] =
+ items.leftOuterJoin(m.productFeatures).collectAsMap.toMap
+
+
+ val popularCount = trainDefault(
+ userStringIntMap = userStringIntMap,
+ itemStringIntMap = itemStringIntMap,
+ data = data
+ )
+
+ val productModels: Map[Int, ProductModel] = productFeatures
+ .map { case (index, (item, features)) =>
+ val pm = ProductModel(
+ item = item,
+ features = features,
+ // NOTE: use getOrElse because popularCount may not contain all items.
+ count = popularCount.getOrElse(index, 0)
+ )
+ (index, pm)
+ }
+
+ new ECommModel(
+ rank = m.rank,
+ userFeatures = userFeatures,
+ productModels = productModels,
+ userStringIntMap = userStringIntMap,
+ itemStringIntMap = itemStringIntMap
+ )
+ }
+
+ // train default model
+ private def trainDefault(
+ userStringIntMap: BiMap[String, Int],
+ itemStringIntMap: BiMap[String, Int],
+ data: PreparedData): Map[Int, Int] = {
+ // count number of buys
+ // (item index, count)
+ val buyCountsRDD: RDD[(Int, Int)] = data.buyEvents
+ .map { r =>
+ // Convert user and item String IDs to Int index
+ val uindex = userStringIntMap.getOrElse(r.user, -1)
+ val iindex = itemStringIntMap.getOrElse(r.item, -1)
+
+ if (uindex == -1)
+ logger.info(s"Couldn't convert nonexistent user ID ${r.user}"
+ + " to Int index.")
+
+ if (iindex == -1)
+ logger.info(s"Couldn't convert nonexistent item ID ${r.item}"
+ + " to Int index.")
+
+ (uindex, iindex, 1)
+ }
+ .filter { case (u, i, v) =>
+ // keep events with valid user and item index
+ (u != -1) && (i != -1)
+ }
+ .map { case (u, i, v) => (i, 1) }
+ .reduceByKey{ case (a, b) => a + b }
+
+ buyCountsRDD.collectAsMap.toMap
+ }
+
+ def predict(model: ECommModel, query: Query): PredictedResult = {
+
+ val userFeatures = model.userFeatures
+ val productModels = model.productModels
+
+ // convert whiteList's string ID to integer index
+ val whiteList: Option[Set[Int]] = query.whiteList.map( set =>
+ set.map(model.itemStringIntMap.get(_)).flatten
+ )
+
+ val blackList: Set[String] = query.blackList.getOrElse(Set[String]())
+
+ // if unseenOnly is True, get all seen items
+ val seenItems: Set[String] = if (ap.unseenOnly) {
+
+ // get all user item events which are considered as "seen" events
+ val seenEvents: Iterator[Event] = try {
+ LEventStore.findByEntity(
+ appName = ap.appName,
+ entityType = "user",
+ entityId = query.user,
+ eventNames = Some(ap.seenEvents),
+ targetEntityType = Some(Some("item")),
+ // set time limit to avoid super long DB access
+ timeout = Duration(200, "millis")
+ )
+ } catch {
+ case e: scala.concurrent.TimeoutException =>
+ logger.error(s"Timeout when read seen events." +
+ s" Empty list is used. ${e}")
+ Iterator[Event]()
+ case e: Exception =>
+ logger.error(s"Error when read seen events: ${e}")
+ throw e
+ }
+
+ seenEvents.map { event =>
+ try {
+ event.targetEntityId.get
+ } catch {
+          case e: Throwable => {
+ logger.error(s"Can't get targetEntityId of event ${event}.")
+ throw e
+ }
+ }
+ }.toSet
+ } else {
+ Set[String]()
+ }
+
+ // get the latest constraint unavailableItems $set event
+ val unavailableItems: Set[String] = try {
+ val constr = LEventStore.findByEntity(
+ appName = ap.appName,
+ entityType = "constraint",
+ entityId = "unavailableItems",
+ eventNames = Some(Seq("$set")),
+ limit = Some(1),
+ latest = true,
+ timeout = Duration(200, "millis")
+ )
+ if (constr.hasNext) {
+ constr.next.properties.get[Set[String]]("items")
+ } else {
+ Set[String]()
+ }
+ } catch {
+ case e: scala.concurrent.TimeoutException =>
+ logger.error(s"Timeout when read set unavailableItems event." +
+ s" Empty list is used. ${e}")
+ Set[String]()
+ case e: Exception =>
+ logger.error(s"Error when read set unavailableItems event: ${e}")
+ throw e
+ }
+
+    // combine query's blackList, seenItems and unavailableItems
+    // into final blackList.
+    // convert seen items list from String ID to integer index
+ val finalBlackList: Set[Int] = (blackList ++ seenItems ++
+ unavailableItems).map( x => model.itemStringIntMap.get(x)).flatten
+
+ val userFeature =
+ model.userStringIntMap.get(query.user).map { userIndex =>
+ userFeatures.get(userIndex)
+ }
+ // flatten Option[Option[Array[Double]]] to Option[Array[Double]]
+ .flatten
+
+ val topScores = if (userFeature.isDefined) {
+ // the user has feature vector
+ val uf = userFeature.get
+ val indexScores: Map[Int, Double] =
+ productModels.par // convert to parallel collection
+ .filter { case (i, pm) =>
+ pm.features.isDefined &&
+ isCandidateItem(
+ i = i,
+ item = pm.item,
+ categories = query.categories,
+ whiteList = whiteList,
+ blackList = finalBlackList
+ )
+ }
+ .map { case (i, pm) =>
+ // NOTE: features must be defined, so can call .get
+ val s = dotProduct(uf, pm.features.get)
+ // Can adjust score here
+ (i, s)
+ }
+ .filter(_._2 > 0) // only keep items with score > 0
+ .seq // convert back to sequential collection
+
+ val ord = Ordering.by[(Int, Double), Double](_._2).reverse
+ val topScores = getTopN(indexScores, query.num)(ord).toArray
+
+ topScores
+
+ } else {
+ // the user doesn't have feature vector.
+ // For example, new user is created after model is trained.
+ logger.info(s"No userFeature found for user ${query.user}.")
+ predictNewUser(
+ model = model,
+ query = query,
+ whiteList = whiteList,
+ blackList = finalBlackList
+ )
+
+ }
+
+ val itemScores = topScores.map { case (i, s) =>
+ new ItemScore(
+ // convert item int index back to string ID
+ item = model.itemIntStringMap(i),
+ score = s
+ )
+ }
+
+ new PredictedResult(itemScores)
+ }
+
+ /** Get recently viewed item of the user and return top similar items */
+ private
+ def predictNewUser(
+ model: ECommModel,
+ query: Query,
+ whiteList: Option[Set[Int]],
+ blackList: Set[Int]): Array[(Int, Double)] = {
+
+ val userFeatures = model.userFeatures
+ val productModels = model.productModels
+
+ // get latest 10 user view item events
+ val recentEvents = try {
+ LEventStore.findByEntity(
+ appName = ap.appName,
+ // entityType and entityId is specified for fast lookup
+ entityType = "user",
+ entityId = query.user,
+ eventNames = Some(Seq("view")),
+ targetEntityType = Some(Some("item")),
+ limit = Some(10),
+ latest = true,
+ // set time limit to avoid super long DB access
+ timeout = Duration(200, "millis")
+ )
+ } catch {
+ case e: scala.concurrent.TimeoutException =>
+ logger.error(s"Timeout when read recent events." +
+ s" Empty list is used. ${e}")
+ Iterator[Event]()
+ case e: Exception =>
+ logger.error(s"Error when read recent events: ${e}")
+ throw e
+ }
+
+ val recentItems: Set[String] = recentEvents.map { event =>
+ try {
+ event.targetEntityId.get
+ } catch {
+        case e: Throwable => {
+          logger.error(s"Can't get targetEntityId of event ${event}.")
+ throw e
+ }
+ }
+ }.toSet
+
+ val recentList: Set[Int] = recentItems.map (x =>
+ model.itemStringIntMap.get(x)).flatten
+
+ val recentFeatures: Vector[Array[Double]] = recentList.toVector
+ // productModels may not contain the requested item
+ .map { i =>
+ productModels.get(i).map { pm => pm.features }.flatten
+ }.flatten
+
+ val indexScores: Map[Int, Double] = if (recentFeatures.isEmpty) {
+ logger.info(s"No features vector for recent items ${recentItems}.")
+
+ productModels.map { case (i, pm) =>
+ // Can adjust score here
+ (i, pm.count.toDouble)
+ }
+
+ } else {
+ productModels.par // convert to parallel collection
+ .filter { case (i, pm) => //(item, feature)) =>
+ pm.features.isDefined &&
+ isCandidateItem(
+ i = i,
+ item = pm.item,
+ categories = query.categories,
+ whiteList = whiteList,
+ blackList = blackList
+ )
+ }
+ .map { case (i, pm) => //(item, feature)) =>
+ val s = recentFeatures.map{ rf =>
+ // pm.features must be defined because of filter logic above
+ cosine(rf, pm.features.get)
+ }.reduce(_ + _)
+ // Can adjust score here
+ (i, s)
+ }
+ .filter(_._2 > 0) // keep items with score > 0
+ .seq // convert back to sequential collection
+ }
+
+ val ord = Ordering.by[(Int, Double), Double](_._2).reverse
+ val topScores = getTopN(indexScores, query.num)(ord).toArray
+
+ topScores
+ }
+
+ private
+ def getTopN[T](s: Iterable[T], n: Int)(implicit ord: Ordering[T]): Seq[T] = {
+
+    val q = PriorityQueue[T]()
+
+ for (x <- s) {
+ if (q.size < n)
+ q.enqueue(x)
+ else {
+ // q is full
+ if (ord.compare(x, q.head) < 0) {
+ q.dequeue()
+ q.enqueue(x)
+ }
+ }
+ }
+
+ q.dequeueAll.toSeq.reverse
+ }
+
+ private
+ def dotProduct(v1: Array[Double], v2: Array[Double]): Double = {
+ val size = v1.size
+ var i = 0
+ var d: Double = 0
+ while (i < size) {
+ d += v1(i) * v2(i)
+ i += 1
+ }
+ d
+ }
+
+ private
+ def cosine(v1: Array[Double], v2: Array[Double]): Double = {
+ val size = v1.size
+ var i = 0
+ var n1: Double = 0
+ var n2: Double = 0
+ var d: Double = 0
+ while (i < size) {
+ n1 += v1(i) * v1(i)
+ n2 += v2(i) * v2(i)
+ d += v1(i) * v2(i)
+ i += 1
+ }
+ val n1n2 = (math.sqrt(n1) * math.sqrt(n2))
+ if (n1n2 == 0) 0 else (d / n1n2)
+ }
+
+ private
+ def isCandidateItem(
+ i: Int,
+ item: Item,
+ categories: Option[Set[String]],
+ whiteList: Option[Set[Int]],
+ blackList: Set[Int]
+ ): Boolean = {
+ // can add other custom filtering here
+ whiteList.map(_.contains(i)).getOrElse(true) &&
+ !blackList.contains(i) &&
+ // filter categories
+ categories.map { cat =>
+ item.categories.map { itemCat =>
+        // keep this item if it has overlapping categories with the query
+ !(itemCat.toSet.intersect(cat).isEmpty)
+ }.getOrElse(false) // discard this item if it has no categories
+ }.getOrElse(true)
+
+ }
+
+}
diff --git a/src/main/scala/Engine.scala b/src/main/scala/Engine.scala
index 42ec4d4..47d9896 100644
--- a/src/main/scala/Engine.scala
+++ b/src/main/scala/Engine.scala
@@ -25,7 +25,7 @@
new Engine(
classOf[DataSource],
classOf[Preparator],
- Map("als" -> classOf[ALSAlgorithm]),
+ Map("ecomm" -> classOf[ECommAlgorithm]),
classOf[Serving])
}
}
diff --git a/src/main/scala/PopularAlgorithm.scala b/src/main/scala/PopularAlgorithm.scala
new file mode 100644
index 0000000..c4255bc
--- /dev/null
+++ b/src/main/scala/PopularAlgorithm.scala
@@ -0,0 +1,72 @@
+package org.template.ecommercerecommendation
+/*
+import io.prediction.controller.P2LAlgorithm
+import io.prediction.controller.Params
+
+
+case class PopularAlgorithmParams() extends Params
+
+class PopularModel(
+ val itemModel: Vector[(String, (Item, Int))] // Vector of (item ID, (Item, Count))
+) extends Serializable {
+
+}
+
+class PopularAlgorithm(val ap: PopularAlgorithmParams)
+ extends P2LAlgorithm[PreparedData, PopularModel, Query, PredictedResult] {
+
+
+ def train(sc: SparkContext, data: PreparedData): PopularModel = {
+
+ // calculate number of buys for each item
+ val buyCounts: RDD[(String, Int)] = data.buyEvents
+ .map { buy => (buy.item, 1) }
+ .reduceByKey{ case (a, b) => a + b }
+
+ // combine item data with the count
+ val itemWithCount: RDD[(String, (Item, Int))] = data.items.join(buyCounts)
+
+ // collect to local vector, and sort save as model
+ val itemModel = itemWithCount.collect.toVector
+      .sortBy{ case (id, (item, count)) => count }(Ordering.Int.reverse)
+
+    new PopularModel(
+ itemModel = itemModel
+ )
+
+ }
+
+
+ def predict(model: PopularModel, query: Query): PredictedResult = {
+ model.itemModel.filter {
+ case (id, (item, count)) =>
+ isCandidateItem(
+
+ )
+ )
+ }
+ }
+
+ private
+ def isCandidateItem(
+ i: Int,
+ item: Item,
+ categories: Option[Set[String]],
+ whiteList: Option[Set[Int]],
+ blackList: Set[Int]
+ ): Boolean = {
+ // can add other custom filtering here
+ whiteList.map(_.contains(i)).getOrElse(true) &&
+ !blackList.contains(i) &&
+ // filter categories
+ categories.map { cat =>
+ item.categories.map { itemCat =>
+ // keep this item if has ovelap categories with the query
+ !(itemCat.toSet.intersect(cat).isEmpty)
+ }.getOrElse(false) // discard this item if it has no categories
+ }.getOrElse(true)
+
+ }
+
+}
+*/
diff --git a/src/main/scala/Preparator.scala b/src/main/scala/Preparator.scala
index 4dd45cf..ff82f80 100644
--- a/src/main/scala/Preparator.scala
+++ b/src/main/scala/Preparator.scala
@@ -13,12 +13,14 @@
new PreparedData(
users = trainingData.users,
items = trainingData.items,
- viewEvents = trainingData.viewEvents)
+ viewEvents = trainingData.viewEvents,
+ buyEvents = trainingData.buyEvents)
}
}
class PreparedData(
val users: RDD[(String, User)],
val items: RDD[(String, Item)],
- val viewEvents: RDD[ViewEvent]
+ val viewEvents: RDD[ViewEvent],
+ val buyEvents: RDD[BuyEvent]
) extends Serializable
diff --git a/src/test/scala/ECommAlgorithmTest.scala b/src/test/scala/ECommAlgorithmTest.scala
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/test/scala/ECommAlgorithmTest.scala
diff --git a/src/test/scala/EngineTestSparkContext.scala b/src/test/scala/EngineTestSparkContext.scala
new file mode 100644
index 0000000..2931403
--- /dev/null
+++ b/src/test/scala/EngineTestSparkContext.scala
@@ -0,0 +1,36 @@
+package org.template.ecommercerecommendation
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{BeforeAndAfterAll, Suite}
+
+trait EngineTestSparkContext extends BeforeAndAfterAll {
+ self: Suite =>
+ @transient private var _sc: SparkContext = _
+
+ def sc: SparkContext = _sc
+
+ var conf = new SparkConf(false)
+
+ override def beforeAll() {
+ _sc = new SparkContext("local", "test", conf)
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ LocalSparkContext.stop(_sc)
+
+ _sc = null
+ super.afterAll()
+ }
+}
+
+object LocalSparkContext {
+ def stop(sc: SparkContext) {
+ if (sc != null) {
+ sc.stop()
+ }
+ // To avoid Akka rebinding to the same port, since it doesn't unbind
+ // immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ }
+}
diff --git a/src/test/scala/PreparatorTest.scala b/src/test/scala/PreparatorTest.scala
new file mode 100644
index 0000000..c00d86d
--- /dev/null
+++ b/src/test/scala/PreparatorTest.scala
@@ -0,0 +1,43 @@
+package org.template.ecommercerecommendation
+
+import org.scalatest.FlatSpec
+import org.scalatest.Matchers
+
+class PreparatorTest extends FlatSpec with EngineTestSparkContext with Matchers {
+
+ val preparator = new Preparator()
+ val users = Map(
+ "u0" -> User(),
+ "u1" -> User()
+ )
+
+ val items = Map(
+ "i0" -> Item(categories = Some(List("c0", "c1"))),
+ "i1" -> Item(categories = None)
+ )
+
+ val view = Seq(
+ ViewEvent("u0", "i0", 1000010),
+ ViewEvent("u0", "i1", 1000020),
+ ViewEvent("u1", "i1", 1000030)
+ )
+
+ val buy = Seq(
+ BuyEvent("u0", "i0", 1000020),
+ BuyEvent("u0", "i1", 1000030),
+ BuyEvent("u1", "i1", 1000040)
+ )
+
+ "Preparator" should "prepare PreparedData" in {
+ val trainingData = new TrainingData(
+ users = sc.parallelize(users.toSeq),
+ items = sc.parallelize(items.toSeq),
+ viewEvents = sc.parallelize(view.toSeq),
+ buyEvents = sc.parallelize(buy.toSeq)
+ )
+
+ val preparedData = preparator.prepare(sc, trainingData)
+
+ preparedData.users.collect should contain theSameElementsAs users
+ }
+}