Parallelize the filtering of valid item scores.
diff --git a/src/main/java/org/template/recommendation/Algorithm.java b/src/main/java/org/template/recommendation/Algorithm.java
index a9af2af..24b4e5c 100644
--- a/src/main/java/org/template/recommendation/Algorithm.java
+++ b/src/main/java/org/template/recommendation/Algorithm.java
@@ -157,7 +157,7 @@
JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD);
- return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems());
+ return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems().collectAsMap());
}
@Override
@@ -260,21 +260,19 @@
private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) {
final DoubleMatrix userMatrix = new DoubleMatrix(userFeature);
- List<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
+ JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
@Override
public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
return new ItemScore(element._2()._1(), userMatrix.dot(new DoubleMatrix(element._2()._2())));
}
- }).collect();
+ });
itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
- Collections.sort(itemScores, Collections.reverseOrder());
-
- return new ArrayList<>(itemScores.subList(0, Math.min(query.getNumber(), itemScores.size())));
+ return sortAndTake(itemScores, query.getNumber());
}
private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) {
- List<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
+ JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
@Override
public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
double similarity = 0.0;
@@ -284,18 +282,15 @@
return new ItemScore(element._2()._1(), similarity);
}
- }).collect();
+ });
itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
- Collections.sort(itemScores, Collections.reverseOrder());
-
- return new ArrayList<>(itemScores.subList(0, Math.min(query.getNumber(), itemScores.size())));
+ return sortAndTake(itemScores, query.getNumber());
}
private List<ItemScore> mostPopularItems(Model model, Query query) {
- List<ItemScore> itemScores = validScores(model.getItemPopularityScore().collect(), query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
- Collections.sort(itemScores, Collections.reverseOrder());
- return new ArrayList<>(itemScores.subList(0, Math.min(query.getNumber(), itemScores.size())));
+ JavaRDD<ItemScore> itemScores = validScores(model.getItemPopularityScore(), query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
+ return sortAndTake(itemScores, query.getNumber());
}
private double cosineSimilarity(double[] a, double[] b) {
@@ -305,31 +300,32 @@
return matrixA.dot(matrixB) / (matrixA.norm2() * matrixB.norm2());
}
- private List<ItemScore> validScores(List<ItemScore> all, Set<String> whitelist, Set<String> blacklist, Set<String> categories, JavaPairRDD<String, Item> items, String userEntityId) {
- List<ItemScore> result = new ArrayList<>();
- Set<String> seenItemEntityIds = seenItemEntityIds(userEntityId);
- Set<String> unavailableItemEntityIds = unavailableItemEntityIds();
- for (final ItemScore itemScore : all) {
- JavaPairRDD<String, Item> possibleItems = items.filter(new Function<Tuple2<String, Item>, Boolean>() {
- @Override
- public Boolean call(Tuple2<String, Item> element) throws Exception {
- return element._1().equals(itemScore.getItemEntityId());
- }
- });
+ private List<ItemScore> sortAndTake(JavaRDD<ItemScore> all, int number) {
+ return all.sortBy(new Function<ItemScore, Double>() {
+ @Override
+ public Double call(ItemScore itemScore) throws Exception {
+ return itemScore.getScore();
+ }
+ }, false, all.partitions().size()).take(number);
+ }
- if (!possibleItems.isEmpty()) {
- Item item = possibleItems.first()._2();
- if (passWhitelistCriteria(whitelist, item.getEntityId())
+ private JavaRDD<ItemScore> validScores(JavaRDD<ItemScore> all, final Set<String> whitelist, final Set<String> blacklist, final Set<String> categories, final Map<String, Item> items, String userEntityId) {
+ final Set<String> seenItemEntityIds = seenItemEntityIds(userEntityId);
+ final Set<String> unavailableItemEntityIds = unavailableItemEntityIds();
+
+ return all.filter(new Function<ItemScore, Boolean>() {
+ @Override
+ public Boolean call(ItemScore itemScore) throws Exception {
+ Item item = items.get(itemScore.getItemEntityId());
+
+ return (item != null
+ && passWhitelistCriteria(whitelist, item.getEntityId())
&& passBlacklistCriteria(blacklist, item.getEntityId())
&& passCategoryCriteria(categories, item)
&& passUnseenCriteria(seenItemEntityIds, item.getEntityId())
- && passAvailabilityCriteria(unavailableItemEntityIds, item.getEntityId())) {
- result.add(itemScore);
- }
+ && passAvailabilityCriteria(unavailableItemEntityIds, item.getEntityId()));
}
- }
-
- return result;
+ });
}
private boolean passWhitelistCriteria(Set<String> whitelist, String itemEntityId) {
diff --git a/src/main/java/org/template/recommendation/Model.java b/src/main/java/org/template/recommendation/Model.java
index bb8c541..129cb2a 100644
--- a/src/main/java/org/template/recommendation/Model.java
+++ b/src/main/java/org/template/recommendation/Model.java
@@ -11,6 +11,8 @@
import scala.Tuple2;
import java.io.Serializable;
+import java.util.Collections;
+import java.util.Map;
public class Model implements Serializable, PersistentModel<AlgorithmParams> {
private static final Logger logger = LoggerFactory.getLogger(Model.class);
@@ -19,9 +21,9 @@
private final JavaPairRDD<String, Integer> userIndex;
private final JavaPairRDD<String, Integer> itemIndex;
private final JavaRDD<ItemScore> itemPopularityScore;
- private final JavaPairRDD<String, Item> items;
+ private final Map<String, Item> items;
- public Model(JavaPairRDD<Integer, double[]> userFeatures, JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures, JavaPairRDD<String, Integer> userIndex, JavaPairRDD<String, Integer> itemIndex, JavaRDD<ItemScore> itemPopularityScore, JavaPairRDD<String, Item> items) {
+ public Model(JavaPairRDD<Integer, double[]> userFeatures, JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures, JavaPairRDD<String, Integer> userIndex, JavaPairRDD<String, Integer> itemIndex, JavaRDD<ItemScore> itemPopularityScore, Map<String, Item> items) {
this.userFeatures = userFeatures;
this.indexItemFeatures = indexItemFeatures;
this.userIndex = userIndex;
@@ -50,7 +52,7 @@
return itemPopularityScore;
}
- public JavaPairRDD<String, Item> getItems() {
+ public Map<String, Item> getItems() {
return items;
}
@@ -61,7 +63,7 @@
userIndex.saveAsObjectFile("/tmp/" + id + "/userIndex");
itemIndex.saveAsObjectFile("/tmp/" + id + "/itemIndex");
itemPopularityScore.saveAsObjectFile("/tmp/" + id + "/itemPopularityScore");
- items.saveAsObjectFile("/tmp/" + id + "/items");
+ new JavaSparkContext(sc).parallelize(Collections.singletonList(items)).saveAsObjectFile("/tmp/" + id + "/items");
logger.info("Saved model to /tmp/" + id);
return true;
@@ -74,7 +76,7 @@
JavaPairRDD<String, Integer> userIndex = JavaPairRDD.<String, Integer>fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/userIndex"));
JavaPairRDD<String, Integer> itemIndex = JavaPairRDD.<String, Integer>fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/itemIndex"));
JavaRDD<ItemScore> itemPopularityScore = jsc.objectFile("/tmp/" + id + "/itemPopularityScore");
- JavaPairRDD<String, Item> items = JavaPairRDD.<String, Item>fromJavaRDD(jsc.<Tuple2<String, Item>>objectFile("/tmp/" + id + "/items"));
+ Map<String, Item> items = jsc.<Map<String, Item>>objectFile("/tmp/" + id + "/items").collect().get(0);
logger.info("loaded model");
return new Model(userFeatures, indexItemFeatures, userIndex, itemIndex, itemPopularityScore, items);