Initial version
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..64fa18b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+manifest.json
+target/
+pio.log
diff --git a/README.md b/README.md
index e1e0199..f992275 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,75 @@
-# template-scala-parallel-ecommercerecommendation
-PredictionIO E-Commerce Recommendation engine template (Scala-based parallelized engine)
+# E-Commerce Recommendation Template
+
+## Documentation
+
+Please refer to http://docs.prediction.io/templates/ecommercerecommendation/quickstart/
+
+## Versions
+
+### develop
+
+
+## Development Notes
+
+### import sample data
+
+```
+$ python data/import_eventserver.py --access_key <your_access_key>
+```
+
+### query
+
+normal:
+
+```
+curl -H "Content-Type: application/json" \
+-d '{
+ "user" : "u1",
+ "num" : 10 }' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
+
+```
+curl -H "Content-Type: application/json" \
+-d '{
+ "user" : "u1",
+ "num": 10,
+ "categories" : ["c4", "c3"]
+}' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
+
+```
+curl -H "Content-Type: application/json" \
+-d '{
+ "user" : "u1",
+ "num": 10,
+ "whiteList": ["i21", "i26", "i40"]
+}' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
+
+```
+curl -H "Content-Type: application/json" \
+-d '{
+ "user" : "u1",
+ "num": 10,
+ "blackList": ["i21", "i26", "i40"]
+}' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
+
+unknown user:
+
+```
+curl -H "Content-Type: application/json" \
+-d '{
+ "user" : "unk1",
+ "num": 10}' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
diff --git a/build.sbt b/build.sbt
new file mode 100644
index 0000000..575ccf2
--- /dev/null
+++ b/build.sbt
@@ -0,0 +1,12 @@
+import AssemblyKeys._
+
+assemblySettings
+
+name := "template-scala-parallel-ecommercerecommendation"
+
+organization := "io.prediction"
+
+libraryDependencies ++= Seq(
+ "io.prediction" %% "core" % "0.8.7-SNAPSHOT" % "provided",
+ "org.apache.spark" %% "spark-core" % "1.2.0" % "provided",
+ "org.apache.spark" %% "spark-mllib" % "1.2.0" % "provided")
diff --git a/data/import_eventserver.py b/data/import_eventserver.py
new file mode 100644
index 0000000..f931a5b
--- /dev/null
+++ b/data/import_eventserver.py
@@ -0,0 +1,73 @@
+"""
+Import sample data for E-Commerce Recommendation Engine Template
+"""
+
+import predictionio
+import argparse
+import random
+
+SEED = 3
+
+def import_events(client):
+ random.seed(SEED)
+ count = 0
+ print client.get_status()
+ print "Importing data..."
+
+ # generate 10 users, with user ids u1,u2,....,u10
+ user_ids = ["u%s" % i for i in range(1, 11)]
+ for user_id in user_ids:
+ print "Set user", user_id
+ client.create_event(
+ event="$set",
+ entity_type="user",
+ entity_id=user_id
+ )
+ count += 1
+
+ # generate 50 items, with item ids i1,i2,....,i50
+ # random assign 1 to 4 categories among c1-c6 to items
+ categories = ["c%s" % i for i in range(1, 7)]
+ item_ids = ["i%s" % i for i in range(1, 51)]
+ for item_id in item_ids:
+ print "Set item", item_id
+ client.create_event(
+ event="$set",
+ entity_type="item",
+ entity_id=item_id,
+ properties={
+ "categories" : random.sample(categories, random.randint(1, 4))
+ }
+ )
+ count += 1
+
+ # each user randomly viewed 10 items
+ for user_id in user_ids:
+ for viewed_item in random.sample(item_ids, 10):
+ print "User", user_id ,"views item", viewed_item
+ client.create_event(
+ event="view",
+ entity_type="user",
+ entity_id=user_id,
+ target_entity_type="item",
+ target_entity_id=viewed_item
+ )
+ count += 1
+
+ print "%s events are imported." % count
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description="Import sample data for e-commerce recommendation engine")
+ parser.add_argument('--access_key', default='invald_access_key')
+ parser.add_argument('--url', default="http://localhost:7070")
+
+ args = parser.parse_args()
+ print args
+
+ client = predictionio.EventClient(
+ access_key=args.access_key,
+ url=args.url,
+ threads=5,
+ qsize=500)
+ import_events(client)
diff --git a/data/send_query.py b/data/send_query.py
new file mode 100644
index 0000000..b0eb651
--- /dev/null
+++ b/data/send_query.py
@@ -0,0 +1,7 @@
+"""
+Send sample query to prediction engine
+"""
+
+import predictionio
+engine_client = predictionio.EngineClient(url="http://localhost:8000")
+print engine_client.send_query({"user": "u1", "num": 4})
diff --git a/engine.json b/engine.json
new file mode 100644
index 0000000..bf02c23
--- /dev/null
+++ b/engine.json
@@ -0,0 +1,21 @@
+{
+ "id": "default",
+ "description": "Default settings",
+ "engineFactory": "org.template.ecommercerecommendation.ECommerceRecommendationEngine",
+ "datasource": {
+ "params" : {
+ "appId": 6
+ }
+ },
+ "algorithms": [
+ {
+ "name": "als",
+ "params": {
+ "rank": 10,
+ "numIterations" : 20,
+ "lambda": 0.01,
+ "seed": 3
+ }
+ }
+ ]
+}
diff --git a/project/assembly.sbt b/project/assembly.sbt
new file mode 100644
index 0000000..54c3252
--- /dev/null
+++ b/project/assembly.sbt
@@ -0,0 +1 @@
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")
diff --git a/src/main/scala/ALSAlgorithm.scala b/src/main/scala/ALSAlgorithm.scala
new file mode 100644
index 0000000..395a678
--- /dev/null
+++ b/src/main/scala/ALSAlgorithm.scala
@@ -0,0 +1,235 @@
+package org.template.ecommercerecommendation
+
+import io.prediction.controller.P2LAlgorithm
+import io.prediction.controller.Params
+import io.prediction.data.storage.BiMap
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.recommendation.ALS
+import org.apache.spark.mllib.recommendation.{Rating => MLlibRating}
+
+import grizzled.slf4j.Logger
+
+import scala.collection.mutable.PriorityQueue
+
+import scala.collection.immutable.HashMap
+
+case class ALSAlgorithmParams(
+ rank: Int,
+ numIterations: Int,
+ lambda: Double,
+ seed: Option[Long]) extends Params
+
+class ALSModel(
+ val rank: Int,
+ val userFeatures: Map[Int, Array[Double]],
+ val productFeatures: Map[Int, Array[Double]],
+ val userStringIntMap: BiMap[String, Int],
+ val itemStringIntMap: BiMap[String, Int],
+ val items: Map[Int, Item]
+) extends Serializable {
+
+ @transient lazy val itemIntStringMap = itemStringIntMap.inverse
+
+ override def toString = {
+ s" rank: ${rank}" +
+ s" userFeatures: [${userFeatures.size}]" +
+ s"(${userFeatures.take(2).toList}...)" +
+ s" productFeatures: [${productFeatures.size}]" +
+ s"(${productFeatures.take(2).toList}...)" +
+ s" userStringIntMap: [${userStringIntMap.size}]" +
+ s"(${userStringIntMap.take(2).toString}...)]" +
+ s" itemStringIntMap: [${itemStringIntMap.size}]" +
+ s"(${itemStringIntMap.take(2).toString}...)]" +
+ s" items: [${items.size}]" +
+ s"(${items.take(2).toString}...)]"
+ }
+}
+
+/**
+ * Use ALS to build item x feature matrix
+ */
+class ALSAlgorithm(val ap: ALSAlgorithmParams)
+ extends P2LAlgorithm[PreparedData, ALSModel, Query, PredictedResult] {
+
+ @transient lazy val logger = Logger[this.type]
+
+ def train(data: PreparedData): ALSModel = {
+ require(!data.viewEvents.take(1).isEmpty,
+ s"viewEvents in PreparedData cannot be empty." +
+ " Please check if DataSource generates TrainingData" +
+ " and Preprator generates PreparedData correctly.")
+ require(!data.users.take(1).isEmpty,
+ s"users in PreparedData cannot be empty." +
+ " Please check if DataSource generates TrainingData" +
+ " and Preprator generates PreparedData correctly.")
+ require(!data.items.take(1).isEmpty,
+ s"items in PreparedData cannot be empty." +
+ " Please check if DataSource generates TrainingData" +
+ " and Preprator generates PreparedData correctly.")
+ // create User and item's String ID to integer index BiMap
+ val userStringIntMap = BiMap.stringInt(data.users.keys)
+ val itemStringIntMap = BiMap.stringInt(data.items.keys)
+
+ // collect Item as Map and convert ID to Int index
+ val items: Map[Int, Item] = data.items.map { case (id, item) =>
+ (itemStringIntMap(id), item)
+ }.collectAsMap.toMap
+
+ val mllibRatings = data.viewEvents
+ .map { r =>
+ // Convert user and item String IDs to Int index for MLlib
+ val uindex = userStringIntMap.getOrElse(r.user, -1)
+ val iindex = itemStringIntMap.getOrElse(r.item, -1)
+
+ if (uindex == -1)
+ logger.info(s"Couldn't convert nonexistent user ID ${r.user}"
+ + " to Int index.")
+
+ if (iindex == -1)
+ logger.info(s"Couldn't convert nonexistent item ID ${r.item}"
+ + " to Int index.")
+
+ ((uindex, iindex), 1)
+ }.filter { case ((u, i), v) =>
+ // keep events with valid user and item index
+ (u != -1) && (i != -1)
+ }.reduceByKey(_ + _) // aggregate all view events of same user-item pair
+ .map { case ((u, i), v) =>
+ // MLlibRating requires integer index for user and item
+ MLlibRating(u, i, v)
+ }.cache()
+
+ // MLLib ALS cannot handle empty training data.
+ require(!mllibRatings.take(1).isEmpty,
+ s"mllibRatings cannot be empty." +
+ " Please check if your events contain valid user and item ID.")
+
+ // seed for MLlib ALS
+ val seed = ap.seed.getOrElse(System.nanoTime)
+
+ val m = ALS.trainImplicit(
+ ratings = mllibRatings,
+ rank = ap.rank,
+ iterations = ap.numIterations,
+ lambda = ap.lambda,
+ blocks = -1,
+ alpha = 1.0,
+ seed = seed)
+
+ new ALSModel(
+ rank = m.rank,
+ userFeatures = m.userFeatures.collectAsMap.toMap,
+ productFeatures = m.productFeatures.collectAsMap.toMap,
+ userStringIntMap = userStringIntMap,
+ itemStringIntMap = itemStringIntMap,
+ items = items
+ )
+ }
+
+ def predict(model: ALSModel, query: Query): PredictedResult = {
+
+ val userFeatures = model.userFeatures
+ val productFeatures = model.productFeatures
+ val items = model.items
+
+ val whiteList: Option[Set[Int]] = query.whiteList.map( set =>
+ set.map(model.itemStringIntMap.get(_)).flatten
+ )
+ val blackList: Option[Set[Int]] = query.blackList.map ( set =>
+ set.map(model.itemStringIntMap.get(_)).flatten
+ )
+
+ val indexScores: Map[Int, Double] =
+ model.userStringIntMap.get(query.user).map { userIndex =>
+ userFeatures.get(userIndex)
+ }
+ // flatten Option[Option[Array[Double]]] to Option[Array[Double]]
+ .flatten
+ .map { uf =>
+ productFeatures.par // convert to parallel collection
+ .filter { case (i, v) =>
+ isCandidateItem(
+ i = i,
+ items = items,
+ categories = query.categories,
+ whiteList = whiteList,
+ blackList = blackList
+ )
+ }
+ .map { case (i, f) =>
+ (i, dotProduct(uf, f))
+ }
+ .seq // convert back to sequential collection
+
+ }.getOrElse{
+ logger.info(s"No userFeature found for user ${query.user}.")
+ Map[Int, Double]()
+ }
+
+ val ord = Ordering.by[(Int, Double), Double](_._2).reverse
+ val topScores = getTopN(indexScores, query.num)(ord).toArray
+
+ val itemScores = topScores.map { case (i, s) =>
+ new ItemScore(
+ item = model.itemIntStringMap(i),
+ score = s
+ )
+ }
+
+ new PredictedResult(itemScores)
+ }
+
+ private
+ def getTopN[T](s: Iterable[T], n: Int)(implicit ord: Ordering[T]): Seq[T] = {
+
+ val q = PriorityQueue()
+
+ for (x <- s) {
+ if (q.size < n)
+ q.enqueue(x)
+ else {
+ // q is full
+ if (ord.compare(x, q.head) < 0) {
+ q.dequeue()
+ q.enqueue(x)
+ }
+ }
+ }
+
+ q.dequeueAll.toSeq.reverse
+ }
+
+ private
+ def dotProduct(v1: Array[Double], v2: Array[Double]): Double = {
+ val size = v1.size
+ var i = 0
+ var d: Double = 0
+ while (i < size) {
+ d += v1(i) * v2(i)
+ i += 1
+ }
+ d
+ }
+
+ private
+ def isCandidateItem(
+ i: Int,
+ items: Map[Int, Item],
+ categories: Option[Set[String]],
+ whiteList: Option[Set[Int]],
+ blackList: Option[Set[Int]]
+ ): Boolean = {
+ whiteList.map(_.contains(i)).getOrElse(true) &&
+ blackList.map(!_.contains(i)).getOrElse(true) &&
+ // filter categories
+ categories.map { cat =>
+ items(i).categories.map { itemCat =>
+ // keep this item if has ovelap categories with the query
+ !(itemCat.toSet.intersect(cat).isEmpty)
+ }.getOrElse(false) // discard this item if it has no categories
+ }.getOrElse(true)
+ }
+
+}
diff --git a/src/main/scala/DataSource.scala b/src/main/scala/DataSource.scala
new file mode 100644
index 0000000..c102b72
--- /dev/null
+++ b/src/main/scala/DataSource.scala
@@ -0,0 +1,114 @@
+package org.template.ecommercerecommendation
+
+import io.prediction.controller.PDataSource
+import io.prediction.controller.EmptyEvaluationInfo
+import io.prediction.controller.EmptyActualResult
+import io.prediction.controller.Params
+import io.prediction.data.storage.Event
+import io.prediction.data.storage.Storage
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.RDD
+
+import grizzled.slf4j.Logger
+
+case class DataSourceParams(appId: Int) extends Params
+
+class DataSource(val dsp: DataSourceParams)
+ extends PDataSource[TrainingData,
+ EmptyEvaluationInfo, Query, EmptyActualResult] {
+
+ @transient lazy val logger = Logger[this.type]
+
+ override
+ def readTraining(sc: SparkContext): TrainingData = {
+ val eventsDb = Storage.getPEvents()
+
+ // create a RDD of (entityID, User)
+ val usersRDD: RDD[(String, User)] = eventsDb.aggregateProperties(
+ appId = dsp.appId,
+ entityType = "user"
+ )(sc).map { case (entityId, properties) =>
+ val user = try {
+ User()
+ } catch {
+ case e: Exception => {
+ logger.error(s"Failed to get properties ${properties} of" +
+ s" user ${entityId}. Exception: ${e}.")
+ throw e
+ }
+ }
+ (entityId, user)
+ }.cache()
+
+ // create a RDD of (entityID, Item)
+ val itemsRDD: RDD[(String, Item)] = eventsDb.aggregateProperties(
+ appId = dsp.appId,
+ entityType = "item"
+ )(sc).map { case (entityId, properties) =>
+ val item = try {
+ // Assume categories is optional property of item.
+ Item(categories = properties.getOpt[List[String]]("categories"))
+ } catch {
+ case e: Exception => {
+ logger.error(s"Failed to get properties ${properties} of" +
+ s" item ${entityId}. Exception: ${e}.")
+ throw e
+ }
+ }
+ (entityId, item)
+ }.cache()
+
+ // get all "user" "view" "item" events
+ val viewEventsRDD: RDD[ViewEvent] = eventsDb.find(
+ appId = dsp.appId,
+ entityType = Some("user"),
+ eventNames = Some(List("view")),
+ // targetEntityType is optional field of an event.
+ targetEntityType = Some(Some("item")))(sc)
+ // eventsDb.find() returns RDD[Event]
+ .map { event =>
+ val viewEvent = try {
+ event.event match {
+ case "view" => ViewEvent(
+ user = event.entityId,
+ item = event.targetEntityId.get,
+ t = event.eventTime.getMillis)
+ case _ => throw new Exception(s"Unexpected event ${event} is read.")
+ }
+ } catch {
+ case e: Exception => {
+ logger.error(s"Cannot convert ${event} to ViewEvent." +
+ s" Exception: ${e}.")
+ throw e
+ }
+ }
+ viewEvent
+ }.cache()
+
+ new TrainingData(
+ users = usersRDD,
+ items = itemsRDD,
+ viewEvents = viewEventsRDD
+ )
+ }
+}
+
+case class User()
+
+case class Item(categories: Option[List[String]])
+
+case class ViewEvent(user: String, item: String, t: Long)
+
+class TrainingData(
+ val users: RDD[(String, User)],
+ val items: RDD[(String, Item)],
+ val viewEvents: RDD[ViewEvent]
+) extends Serializable {
+ override def toString = {
+ s"users: [${users.count()} (${users.take(2).toList}...)]" +
+ s"items: [${items.count()} (${items.take(2).toList}...)]" +
+ s"viewEvents: [${viewEvents.count()}] (${viewEvents.take(2).toList}...)"
+ }
+}
diff --git a/src/main/scala/Engine.scala b/src/main/scala/Engine.scala
new file mode 100644
index 0000000..42ec4d4
--- /dev/null
+++ b/src/main/scala/Engine.scala
@@ -0,0 +1,31 @@
+package org.template.ecommercerecommendation
+
+import io.prediction.controller.IEngineFactory
+import io.prediction.controller.Engine
+
+case class Query(
+ user: String,
+ num: Int,
+ categories: Option[Set[String]],
+ whiteList: Option[Set[String]],
+ blackList: Option[Set[String]]
+) extends Serializable
+
+case class PredictedResult(
+ itemScores: Array[ItemScore]
+) extends Serializable
+
+case class ItemScore(
+ item: String,
+ score: Double
+) extends Serializable
+
+object ECommerceRecommendationEngine extends IEngineFactory {
+ def apply() = {
+ new Engine(
+ classOf[DataSource],
+ classOf[Preparator],
+ Map("als" -> classOf[ALSAlgorithm]),
+ classOf[Serving])
+ }
+}
diff --git a/src/main/scala/Preparator.scala b/src/main/scala/Preparator.scala
new file mode 100644
index 0000000..4dd45cf
--- /dev/null
+++ b/src/main/scala/Preparator.scala
@@ -0,0 +1,24 @@
+package org.template.ecommercerecommendation
+
+import io.prediction.controller.PPreparator
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.RDD
+
+class Preparator
+ extends PPreparator[TrainingData, PreparedData] {
+
+ def prepare(sc: SparkContext, trainingData: TrainingData): PreparedData = {
+ new PreparedData(
+ users = trainingData.users,
+ items = trainingData.items,
+ viewEvents = trainingData.viewEvents)
+ }
+}
+
+class PreparedData(
+ val users: RDD[(String, User)],
+ val items: RDD[(String, Item)],
+ val viewEvents: RDD[ViewEvent]
+) extends Serializable
diff --git a/src/main/scala/Serving.scala b/src/main/scala/Serving.scala
new file mode 100644
index 0000000..21cf2df
--- /dev/null
+++ b/src/main/scala/Serving.scala
@@ -0,0 +1,13 @@
+package org.template.ecommercerecommendation
+
+import io.prediction.controller.LServing
+
+class Serving
+ extends LServing[Query, PredictedResult] {
+
+ override
+ def serve(query: Query,
+ predictedResults: Seq[PredictedResult]): PredictedResult = {
+ predictedResults.head
+ }
+}