Merge branch 'develop' for v0.1.1
diff --git a/.gitignore b/.gitignore
index ea4e89d..64fa18b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
-data/sample_movielens_data.txt
manifest.json
+target/
+pio.log
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..9044594
--- /dev/null
+++ b/README.md
@@ -0,0 +1,90 @@
+# Similar Product Template
+
+## Documentation
+
+Please refer to http://docs.prediction.io/templates/similarproduct/quickstart/
+
+## Versions
+
+### develop
+
+### v0.1.1
+
+- Persist RDDs in memory (`.cache()`) in DataSource for better performance.
+- Use local model for faster serving.
+
+### v0.1.0
+
+- Initial version.
+
+
+## Development Notes
+
+### import sample data
+
+```
+$ python data/import_eventserver.py --access_key <your_access_key>
+```
+
+### query
+
+normal:
+
+```
+curl -H "Content-Type: application/json" \
+-d '{ "items": ["i1", "i3", "i10", "i2", "i5", "i31", "i9"], "num": 10}' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
+
+```
+curl -H "Content-Type: application/json" \
+-d '{
+ "items": ["i1", "i3", "i10", "i2", "i5", "i31", "i9"],
+ "num": 10,
+ "categories" : ["c4", "c3"]
+}' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
+
+```
+curl -H "Content-Type: application/json" \
+-d '{
+ "items": ["i1", "i3", "i10", "i2", "i5", "i31", "i9"],
+ "num": 10,
+ "whiteList": ["i21", "i26", "i40"]
+}' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
+
+```
+curl -H "Content-Type: application/json" \
+-d '{
+ "items": ["i1", "i3", "i10", "i2", "i5", "i31", "i9"],
+ "num": 10,
+ "blackList": ["i21", "i26", "i40"]
+}' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
+
+unknown item:
+
+```
+curl -H "Content-Type: application/json" \
+-d '{ "items": ["unk1", "i3", "i10", "i2", "i5", "i31", "i9"], "num": 10}' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
+
+
+all unknown items:
+
+```
+curl -H "Content-Type: application/json" \
+-d '{ "items": ["unk1", "unk2", "unk3", "unk4"], "num": 10}' \
+http://localhost:8000/queries.json \
+-w %{time_connect}:%{time_starttransfer}:%{time_total}
+```
diff --git a/src/main/scala/ALSAlgorithm.scala b/src/main/scala/ALSAlgorithm.scala
index 977edd8..a4dfcbf 100644
--- a/src/main/scala/ALSAlgorithm.scala
+++ b/src/main/scala/ALSAlgorithm.scala
@@ -1,14 +1,11 @@
package org.template.similarproduct
-import io.prediction.controller.PAlgorithm
+import io.prediction.controller.P2LAlgorithm
import io.prediction.controller.Params
-import io.prediction.controller.IPersistentModel
-import io.prediction.controller.IPersistentModelLoader
import io.prediction.data.storage.BiMap
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.{Rating => MLlibRating}
@@ -23,26 +20,15 @@
seed: Option[Long]) extends Params
class ALSModel(
- val productFeatures: RDD[(Int, Array[Double])],
+ val productFeatures: Map[Int, Array[Double]],
val itemStringIntMap: BiMap[String, Int],
val items: Map[Int, Item]
-) extends IPersistentModel[ALSAlgorithmParams] with Serializable {
+) extends Serializable {
@transient lazy val itemIntStringMap = itemStringIntMap.inverse
- def save(id: String, params: ALSAlgorithmParams,
- sc: SparkContext): Boolean = {
-
- productFeatures.saveAsObjectFile(s"/tmp/${id}/productFeatures")
- sc.parallelize(Seq(itemStringIntMap))
- .saveAsObjectFile(s"/tmp/${id}/itemStringIntMap")
- sc.parallelize(Seq(items))
- .saveAsObjectFile(s"/tmp/${id}/items")
- true
- }
-
override def toString = {
- s" productFeatures: [${productFeatures.count()}]" +
+ s" productFeatures: [${productFeatures.size}]" +
s"(${productFeatures.take(2).toList}...)" +
s" itemStringIntMap: [${itemStringIntMap.size}]" +
s"(${itemStringIntMap.take(2).toString}...)]" +
@@ -51,24 +37,11 @@
}
}
-object ALSModel
- extends IPersistentModelLoader[ALSAlgorithmParams, ALSModel] {
- def apply(id: String, params: ALSAlgorithmParams,
- sc: Option[SparkContext]) = {
- new ALSModel(
- productFeatures = sc.get.objectFile(s"/tmp/${id}/productFeatures"),
- itemStringIntMap = sc.get
- .objectFile[BiMap[String, Int]](s"/tmp/${id}/itemStringIntMap").first,
- items = sc.get
- .objectFile[Map[Int, Item]](s"/tmp/${id}/items").first)
- }
-}
-
/**
* Use ALS to build item x feature matrix
*/
class ALSAlgorithm(val ap: ALSAlgorithmParams)
- extends PAlgorithm[PreparedData, ALSModel, Query, PredictedResult] {
+ extends P2LAlgorithm[PreparedData, ALSModel, Query, PredictedResult] {
@transient lazy val logger = Logger[this.type]
@@ -136,7 +109,7 @@
seed = seed)
new ALSModel(
- productFeatures = m.productFeatures,
+ productFeatures = m.productFeatures.collectAsMap.toMap,
itemStringIntMap = itemStringIntMap,
items = items
)
@@ -144,17 +117,16 @@
def predict(model: ALSModel, query: Query): PredictedResult = {
+ val productFeatures = model.productFeatures
+
// convert items to Int index
val queryList: Set[Int] = query.items.map(model.itemStringIntMap.get(_))
.flatten.toSet
- val queryFeatures: Vector[Array[Double]] = queryList.toVector.par
- .map { item =>
- // productFeatures may not contain the requested item
- val qf: Option[Array[Double]] = model.productFeatures
- .lookup(item).headOption
- qf
- }.seq.flatten
+ val queryFeatures: Vector[Array[Double]] = queryList.toVector
+ // productFeatures may not contain the requested item
+ .map { item => productFeatures.get(item) }
+ .flatten
val whiteList: Option[Set[Int]] = query.whiteList.map( set =>
set.map(model.itemStringIntMap.get(_)).flatten
@@ -169,14 +141,15 @@
logger.info(s"No productFeatures vector for query items ${query.items}.")
Array[(Int, Double)]()
} else {
- model.productFeatures
+ productFeatures.par // convert to parallel collection
.mapValues { f =>
queryFeatures.map{ qf =>
cosine(qf, f)
}.reduce(_ + _)
}
.filter(_._2 > 0) // keep items with score > 0
- .collect()
+ .seq // convert back to sequential collection
+ .toArray
}
val filteredScore = indexScores.view.filter { case (i, v) =>
diff --git a/src/main/scala/DataSource.scala b/src/main/scala/DataSource.scala
index b45fe37..bea337d 100644
--- a/src/main/scala/DataSource.scala
+++ b/src/main/scala/DataSource.scala
@@ -40,7 +40,7 @@
}
}
(entityId, user)
- }
+ }.cache()
// create a RDD of (entityID, Item)
val itemsRDD: RDD[(String, Item)] = eventsDb.aggregateProperties(
@@ -58,7 +58,7 @@
}
}
(entityId, item)
- }
+ }.cache()
// get all "user" "view" "item" events
val viewEventsRDD: RDD[ViewEvent] = eventsDb.find(
@@ -85,7 +85,7 @@
}
}
viewEvent
- }
+ }.cache()
new TrainingData(
users = usersRDD,