result of similar items
diff --git a/src/main/java/org/template/recommendation/Algorithm.java b/src/main/java/org/template/recommendation/Algorithm.java
index 42d20df..d348b39 100644
--- a/src/main/java/org/template/recommendation/Algorithm.java
+++ b/src/main/java/org/template/recommendation/Algorithm.java
@@ -121,7 +121,7 @@
if (recentProductFeatures.isEmpty()) {
return new PredictedResult(mostPopularItems());
} else {
- return new PredictedResult(similarItems(recentProductFeatures));
+ return new PredictedResult(similarItems(recentProductFeatures, model.getProductFeatures(), model.getIndexItemMap(), query.getNumber()));
}
}
}
@@ -177,20 +177,38 @@
itemScores.add(new ItemScore(indexItemMap.get(longKey), score));
}
- Collections.sort(itemScores);
+ Collections.sort(itemScores, Collections.reverseOrder());
return itemScores.subList(0, Math.min(number, itemScores.size()));
}
- private List<ItemScore> similarItems(List<double[]> recentProductFeatures) {
+ private List<ItemScore> similarItems(List<double[]> recentProductFeatures, Map<Object, double[]> productFeatures, Map<Long, String> indexItemMap, int number) {
+ Set<Object> productKeys = productFeatures.keySet();
+ List<ItemScore> itemScores = new ArrayList<>(productFeatures.size());
+ for (Object key : productKeys) {
+ long longKey = ((Integer)key).longValue();
+ double[] feature = productFeatures.get(key);
+ double similarity = 0.0;
+ for (double[] recentFeature : recentProductFeatures) {
+ similarity += cosineSimilarity(feature, recentFeature);
+ }
+ itemScores.add(new ItemScore(indexItemMap.get(longKey), similarity));
+ }
+ Collections.sort(itemScores, Collections.reverseOrder());
-
- return Collections.emptyList();
+ return itemScores.subList(0, Math.min(number, itemScores.size()));
}
private List<ItemScore> mostPopularItems() {
return Collections.emptyList();
}
+ private double cosineSimilarity(double[] a, double[] b) {
+ DoubleMatrix matrixA = new DoubleMatrix(a);
+ DoubleMatrix matrixB = new DoubleMatrix(b);
+
+ return matrixA.dot(matrixB) / (matrixA.norm2() * matrixB.norm2());
+ }
+
}