Use RDDs in the model, add custom persistence logic, and use RDDs in some prediction calculations
diff --git a/src/main/java/org/template/recommendation/Algorithm.java b/src/main/java/org/template/recommendation/Algorithm.java
index b7abb25..0fd4c59 100644
--- a/src/main/java/org/template/recommendation/Algorithm.java
+++ b/src/main/java/org/template/recommendation/Algorithm.java
@@ -44,37 +44,48 @@
TrainingData data = preparedData.getTrainingData();
// user stuff
- JavaPairRDD<String, Long> userIndexRDD = data.getUsers().map(new Function<Tuple2<String, User>, String>() {
+ JavaPairRDD<String, Integer> userIndexRDD = data.getUsers().map(new Function<Tuple2<String, User>, String>() {
@Override
public String call(Tuple2<String, User> idUser) throws Exception {
return idUser._1();
}
- }).zipWithIndex();
- final Map<String, Long> userIndexMap = userIndexRDD.collectAsMap();
+ }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
+ return new Tuple2<>(element._1(), element._2().intValue());
+ }
+ });
+ final Map<String, Integer> userIndexMap = userIndexRDD.collectAsMap();
// item stuff
- JavaPairRDD<String, Long> itemIndexRDD = data.getItems().map(new Function<Tuple2<String, Item>, String>() {
+ JavaPairRDD<String, Integer> itemIndexRDD = data.getItems().map(new Function<Tuple2<String, Item>, String>() {
@Override
public String call(Tuple2<String, Item> idItem) throws Exception {
return idItem._1();
}
- }).zipWithIndex();
- final Map<String, Long> itemIndexMap = itemIndexRDD.collectAsMap();
- final Map<Long, String> indexItemMap = itemIndexRDD.mapToPair(new PairFunction<Tuple2<String, Long>, Long, String>() {
+ }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
@Override
- public Tuple2<Long, String> call(Tuple2<String, Long> element) throws Exception {
+ public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
+ return new Tuple2<>(element._1(), element._2().intValue());
+ }
+ });
+ final Map<String, Integer> itemIndexMap = itemIndexRDD.collectAsMap();
+ JavaPairRDD<Integer, String> indexItemRDD = itemIndexRDD.mapToPair(new PairFunction<Tuple2<String, Integer>, Integer, String>() {
+ @Override
+ public Tuple2<Integer, String> call(Tuple2<String, Integer> element) throws Exception {
return element.swap();
}
- }).collectAsMap();
+ });
+ final Map<Integer, String> indexItemMap = indexItemRDD.collectAsMap();
// ratings stuff
JavaRDD<Rating> ratings = data.getViewEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
@Override
public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent viewEvent) throws Exception {
- Long userIndex = userIndexMap.get(viewEvent.getUser());
- Long itemIndex = itemIndexMap.get(viewEvent.getItem());
+ Integer userIndex = userIndexMap.get(viewEvent.getUser());
+ Integer itemIndex = itemIndexMap.get(viewEvent.getItem());
- return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex.intValue(), itemIndex.intValue()), 1);
+ return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
}
}).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
@Override
@@ -98,17 +109,27 @@
// MLlib ALS stuff
MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(JavaRDD.toRDD(ratings), ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed());
- Map<Object, double[]> userFeatures = JavaPairRDD.fromJavaRDD(matrixFactorizationModel.userFeatures().toJavaRDD()).collectAsMap();
- Map<Object, double[]> productFeatures = JavaPairRDD.fromJavaRDD(matrixFactorizationModel.productFeatures().toJavaRDD()).collectAsMap();
+ JavaPairRDD<Integer, double[]> userFeatures = matrixFactorizationModel.userFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
+ @Override
+ public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
+ return new Tuple2<>((Integer) element._1(), element._2());
+ }
+ });
+ JavaPairRDD<Integer, double[]> productFeaturesRDD = matrixFactorizationModel.productFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
+ @Override
+ public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
+ return new Tuple2<>((Integer) element._1(), element._2());
+ }
+ });
// popularity scores
JavaRDD<ItemScore> itemPopularityScore = data.getBuyEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
@Override
public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent buyEvent) throws Exception {
- Long userIndex = userIndexMap.get(buyEvent.getUser());
- Long itemIndex = itemIndexMap.get(buyEvent.getItem());
+ Integer userIndex = userIndexMap.get(buyEvent.getUser());
+ Integer itemIndex = itemIndexMap.get(buyEvent.getItem());
- return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex.intValue(), itemIndex.intValue()), 1);
+ return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
}
}).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
@Override
@@ -123,31 +144,39 @@
}).map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, ItemScore>() {
@Override
public ItemScore call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
- return new ItemScore(indexItemMap.get(element._1()._2().longValue()), element._2().doubleValue());
+ return new ItemScore(indexItemMap.get(element._1()._2()), element._2().doubleValue());
}
});
- return new Model(userFeatures, productFeatures, userIndexRDD, indexItemMap, itemIndexRDD, itemPopularityScore, data.getItems());
+ JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD);
+
+ return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems());
}
@Override
public PredictedResult predict(Model model, final Query query) {
- JavaPairRDD<String, Long> matchedUser = model.getUserIndexRDD().filter(new Function<Tuple2<String, Long>, Boolean>() {
+ final JavaPairRDD<String, Integer> matchedUser = model.getUserIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() {
@Override
- public Boolean call(Tuple2<String, Long> userIndex) throws Exception {
+ public Boolean call(Tuple2<String, Integer> userIndex) throws Exception {
return userIndex._1().equals(query.getUserEntityId());
}
});
double[] userFeature = null;
if (!matchedUser.isEmpty()) {
- userFeature = model.getUserFeatures().get(matchedUser.first()._2());
+ final Integer matchedUserIndex = matchedUser.first()._2();
+ // lookup() returns an empty list when the user has no feature vector, so userFeature
+ // stays null and the fallbacks below still run; filter(...).first() would throw instead.
+ final List<double[]> matchedFeatures = model.getUserFeatures().lookup(matchedUserIndex);
+ if (!matchedFeatures.isEmpty()) {
+ userFeature = matchedFeatures.get(0);
+ }
}
if (userFeature != null) {
return new PredictedResult(topItemsForUser(userFeature, model, query));
} else {
- List<double[]> recentProductFeatures = getRecentProductFeatures(query, model.getProductFeatures(), model.getItemIndexRDD());
+ List<double[]> recentProductFeatures = getRecentProductFeatures(query, model);
if (recentProductFeatures.isEmpty()) {
return new PredictedResult(mostPopularItems(model, query));
} else {
@@ -156,7 +185,7 @@
}
}
- private List<double[]> getRecentProductFeatures(Query query, Map<Object, double[]> productFeatures, JavaPairRDD<String, Long> itemIndexRDD) {
+ private List<double[]> getRecentProductFeatures(Query query, Model model) {
try {
List<double[]> result = new ArrayList<>();
@@ -177,15 +206,22 @@
for (final Event event : events) {
if (event.targetEntityId().isDefined()) {
- JavaPairRDD<String, Long> filtered = itemIndexRDD.filter(new Function<Tuple2<String, Long>, Boolean>() {
+ JavaPairRDD<String, Integer> filtered = model.getItemIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() {
@Override
- public Boolean call(Tuple2<String, Long> element) throws Exception {
+ public Boolean call(Tuple2<String, Integer> element) throws Exception {
return element._1().equals(event.targetEntityId().get());
}
});
if (!filtered.isEmpty()) {
- result.add(productFeatures.get(filtered.first()._2().intValue()));
+ // Resolve the index only after the emptiness check: first() throws on an empty RDD.
+ final Integer itemIndex = filtered.first()._2();
+ result.add(model.getIndexItemFeatures().filter(new Function<Tuple2<Integer, Tuple2<String, double[]>>, Boolean>() {
+ @Override
+ public Boolean call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
+ return itemIndex.equals(element._1());
+ }
+ }).first()._2()._2());
}
}
}
@@ -198,14 +234,14 @@
}
private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) {
- DoubleMatrix userMatrix = new DoubleMatrix(userFeature);
- Set<Object> keys = model.getProductFeatures().keySet();
- List<ItemScore> itemScores = new ArrayList<>(model.getProductFeatures().size());
- for (Object key : keys) {
- long longKey = ((Integer) key).longValue();
- double score = userMatrix.dot(new DoubleMatrix(model.getProductFeatures().get(key)));
- itemScores.add(new ItemScore(model.getIndexItemMap().get(longKey), score));
- }
+ final DoubleMatrix userMatrix = new DoubleMatrix(userFeature);
+
+ List<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
+ @Override
+ public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
+ return new ItemScore(element._2()._1(), userMatrix.dot(new DoubleMatrix(element._2()._2())));
+ }
+ }).collect();
itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems());
Collections.sort(itemScores, Collections.reverseOrder());
@@ -213,18 +249,18 @@
return itemScores.subList(0, Math.min(query.getNumber(), itemScores.size()));
}
- private List<ItemScore> similarItems(List<double[]> recentProductFeatures, Model model, Query query) {
- Set<Object> productKeys = model.getProductFeatures().keySet();
- List<ItemScore> itemScores = new ArrayList<>(model.getProductFeatures().size());
- for (Object key : productKeys) {
- long longKey = ((Integer) key).longValue();
- double[] feature = model.getProductFeatures().get(key);
- double similarity = 0.0;
- for (double[] recentFeature : recentProductFeatures) {
- similarity += cosineSimilarity(feature, recentFeature);
+ private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) {
+ List<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
+ @Override
+ public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
+ double similarity = 0.0;
+ for (double[] recentFeature : recentProductFeatures) {
+ similarity += cosineSimilarity(element._2()._2(), recentFeature);
+ }
+
+ return new ItemScore(element._2()._1(), similarity);
}
- itemScores.add(new ItemScore(model.getIndexItemMap().get(longKey), similarity));
- }
+ }).collect();
itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems());
Collections.sort(itemScores, Collections.reverseOrder());
diff --git a/src/main/java/org/template/recommendation/Model.java b/src/main/java/org/template/recommendation/Model.java
index 13b09c4..d9bf1dc 100644
--- a/src/main/java/org/template/recommendation/Model.java
+++ b/src/main/java/org/template/recommendation/Model.java
@@ -5,50 +5,45 @@
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import scala.Tuple2;
import java.io.Serializable;
-import java.util.Map;
public class Model implements Serializable, PersistentModel<AlgorithmParams> {
private static final Logger logger = LoggerFactory.getLogger(Model.class);
- private final Map<Object, double[]> userFeatures;
- private final Map<Object, double[]> productFeatures;
- private final JavaPairRDD<String, Long> userIndexRDD;
- private final Map<Long, String> indexItemMap;
- private final JavaPairRDD<String, Long> itemIndexRDD;
+ private final JavaPairRDD<Integer, double[]> userFeatures;
+ private final JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures;
+ private final JavaPairRDD<String, Integer> userIndex;
+ private final JavaPairRDD<String, Integer> itemIndex;
private final JavaRDD<ItemScore> itemPopularityScore;
private final JavaPairRDD<String, Item> items;
- public Model(Map<Object, double[]> userFeatures, Map<Object, double[]> productFeatures, JavaPairRDD<String, Long> userIndexRDD, Map<Long, String> itemIndexMap, JavaPairRDD<String, Long> itemIndexRDD, JavaRDD<ItemScore> itemPopularityScore, JavaPairRDD<String, Item> items) {
+ public Model(JavaPairRDD<Integer, double[]> userFeatures, JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures, JavaPairRDD<String, Integer> userIndex, JavaPairRDD<String, Integer> itemIndex, JavaRDD<ItemScore> itemPopularityScore, JavaPairRDD<String, Item> items) {
this.userFeatures = userFeatures;
- this.productFeatures = productFeatures;
- this.userIndexRDD = userIndexRDD;
- this.indexItemMap = itemIndexMap;
- this.itemIndexRDD = itemIndexRDD;
+ this.indexItemFeatures = indexItemFeatures;
+ this.userIndex = userIndex;
+ this.itemIndex = itemIndex;
this.itemPopularityScore = itemPopularityScore;
this.items = items;
}
- public Map<Object, double[]> getUserFeatures() {
+ public JavaPairRDD<Integer, double[]> getUserFeatures() {
return userFeatures;
}
- public Map<Object, double[]> getProductFeatures() {
- return productFeatures;
+ public JavaPairRDD<Integer, Tuple2<String, double[]>> getIndexItemFeatures() {
+ return indexItemFeatures;
}
- public JavaPairRDD<String, Long> getUserIndexRDD() {
- return userIndexRDD;
+ public JavaPairRDD<String, Integer> getUserIndex() {
+ return userIndex;
}
- public Map<Long, String> getIndexItemMap() {
- return indexItemMap;
- }
-
- public JavaPairRDD<String, Long> getItemIndexRDD() {
- return itemIndexRDD;
+ public JavaPairRDD<String, Integer> getItemIndex() {
+ return itemIndex;
}
public JavaRDD<ItemScore> getItemPopularityScore() {
@@ -61,14 +56,27 @@
@Override
public boolean save(String id, AlgorithmParams params, SparkContext sc) {
+ userFeatures.saveAsObjectFile("/tmp/" + id + "/userFeatures");
+ indexItemFeatures.saveAsObjectFile("/tmp/" + id + "/indexItemFeatures");
+ userIndex.saveAsObjectFile("/tmp/" + id + "/userIndex");
+ itemIndex.saveAsObjectFile("/tmp/" + id + "/itemIndex");
+ itemPopularityScore.saveAsObjectFile("/tmp/" + id + "/itemPopularityScore");
+ items.saveAsObjectFile("/tmp/" + id + "/items");
- logger.info("saved model");
- return false;
+ logger.info("Saved model to /tmp/" + id);
+ return true;
}
public static Model load(String id, Params params, SparkContext sc) {
+ JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);
+ JavaPairRDD<Integer, double[]> userFeatures = JavaPairRDD.fromJavaRDD(jsc.<Tuple2<Integer, double[]>>objectFile("/tmp/" + id + "/userFeatures"));
+ JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = JavaPairRDD.fromJavaRDD(jsc.<Tuple2<Integer, Tuple2<String, double[]>>>objectFile("/tmp/" + id + "/indexItemFeatures"));
+ JavaPairRDD<String, Integer> userIndex = JavaPairRDD.fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/userIndex"));
+ JavaPairRDD<String, Integer> itemIndex = JavaPairRDD.fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/itemIndex"));
+ JavaRDD<ItemScore> itemPopularityScore = jsc.objectFile("/tmp/" + id + "/itemPopularityScore");
+ JavaPairRDD<String, Item> items = JavaPairRDD.fromJavaRDD(jsc.<Tuple2<String, Item>>objectFile("/tmp/" + id + "/items"));
logger.info("loaded model");
- return null;
+ return new Model(userFeatures, indexItemFeatures, userIndex, itemIndex, itemPopularityScore, items);
}
}