blob: 349e9456bb56ad64de584452654085abf4d94560 [file] [log] [blame]
package org.example.recommendation;
import com.google.common.collect.Sets;
import org.apache.predictionio.controller.java.PJavaAlgorithm;
import org.apache.predictionio.data.storage.Event;
import org.apache.predictionio.data.store.java.LJavaEventStore;
import org.apache.predictionio.data.store.java.OptionHelper;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import org.jblas.DoubleMatrix;
import org.joda.time.DateTime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.Tuple2;
import scala.concurrent.duration.Duration;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
/**
 * Recommendation algorithm built on MLlib's implicit-feedback ALS.
 *
 * <p>Training turns view events into implicit ratings (view count per user/item pair) and buy
 * events into per-item popularity counts. At query time:
 * <ol>
 *   <li>a user with a learned feature vector gets dot-product scores against every item;</li>
 *   <li>otherwise, items are scored by summed cosine similarity to the features of the user's
 *       recent "similar item" events;</li>
 *   <li>otherwise, items are ranked by buy count.</li>
 * </ol>
 * Every candidate list is filtered by the query's whitelist, blacklist and categories, by the
 * user's seen items (when {@code unseenOnly} is enabled) and by the {@code unavailableItems}
 * constraint stored in the event store.
 */
public class Algorithm extends PJavaAlgorithm<PreparedData, Model, Query, PredictedResult> {
    private static final Logger logger = LoggerFactory.getLogger(Algorithm.class);

    /** Timeout applied to every synchronous event store read. */
    private static final Duration EVENT_STORE_TIMEOUT = Duration.apply(10, TimeUnit.SECONDS);

    /** Maximum number of recent events used to seed similar-item scoring. */
    private static final int RECENT_EVENT_LIMIT = 10;

    /** Serializable integer addition, used to count events per key. */
    private static final Function2<Integer, Integer, Integer> SUM = new Function2<Integer, Integer, Integer>() {
        @Override
        public Integer call(Integer left, Integer right) {
            return left + right;
        }
    };

    private final AlgorithmParams ap;

    public Algorithm(AlgorithmParams ap) {
        this.ap = ap;
    }

    /**
     * Trains the implicit ALS model and computes item popularity from buy events.
     *
     * @param sc           the Spark context (unused; RDDs come from {@code preparedData})
     * @param preparedData users, items, view events and buy events
     * @return the trained model
     * @throws AssertionError if no view event references both a known user and a known item
     */
    @Override
    public Model train(SparkContext sc, PreparedData preparedData) {
        TrainingData data = preparedData.getTrainingData();

        // MLlib ALS works on integer ids, so every user and item gets a dense index.
        JavaPairRDD<String, Integer> userIndexRDD = indexIds(data.getUsers().map(
                new Function<Tuple2<String, User>, String>() {
                    @Override
                    public String call(Tuple2<String, User> idUser) throws Exception {
                        return idUser._1();
                    }
                }));
        final Map<String, Integer> userIndexMap = userIndexRDD.collectAsMap();

        JavaPairRDD<String, Integer> itemIndexRDD = indexIds(data.getItems().map(
                new Function<Tuple2<String, Item>, String>() {
                    @Override
                    public String call(Tuple2<String, Item> idItem) throws Exception {
                        return idItem._1();
                    }
                }));
        final Map<String, Integer> itemIndexMap = itemIndexRDD.collectAsMap();
        JavaPairRDD<Integer, String> indexItemRDD = itemIndexRDD.mapToPair(
                new PairFunction<Tuple2<String, Integer>, Integer, String>() {
                    @Override
                    public Tuple2<Integer, String> call(Tuple2<String, Integer> element) throws Exception {
                        return element.swap();
                    }
                });
        final Map<Integer, String> indexItemMap = indexItemRDD.collectAsMap();

        // Implicit rating = number of view events for the (user, item) pair.
        JavaRDD<Rating> ratings = indexedEventPairs(data.getViewEvents(), userIndexMap, itemIndexMap)
                .reduceByKey(SUM)
                .map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Rating>() {
                    @Override
                    public Rating call(Tuple2<Tuple2<Integer, Integer>, Integer> userItemCount) throws Exception {
                        return new Rating(userItemCount._1()._1(), userItemCount._1()._2(),
                                userItemCount._2().doubleValue());
                    }
                });
        if (ratings.isEmpty()) {
            throw new AssertionError("Please check if your events contain valid user and item ID.");
        }

        // blocks = -1 lets MLlib pick the parallelism; 1.0 is the implicit-feedback alpha.
        MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(JavaRDD.toRDD(ratings),
                ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed());
        JavaPairRDD<Integer, double[]> userFeatures = toIndexedFeatures(matrixFactorizationModel.userFeatures());
        JavaPairRDD<Integer, double[]> productFeaturesRDD =
                toIndexedFeatures(matrixFactorizationModel.productFeatures());

        // Popularity = number of buy events per item.
        JavaRDD<ItemScore> itemPopularityScore = indexedEventPairs(data.getBuyEvents(), userIndexMap, itemIndexMap)
                .mapToPair(new PairFunction<Tuple2<Tuple2<Integer, Integer>, Integer>, Integer, Integer>() {
                    @Override
                    public Tuple2<Integer, Integer> call(Tuple2<Tuple2<Integer, Integer>, Integer> element)
                            throws Exception {
                        return new Tuple2<>(element._1()._2(), element._2());
                    }
                })
                .reduceByKey(SUM)
                .map(new Function<Tuple2<Integer, Integer>, ItemScore>() {
                    @Override
                    public ItemScore call(Tuple2<Integer, Integer> element) throws Exception {
                        return new ItemScore(indexItemMap.get(element._1()), element._2().doubleValue());
                    }
                });

        JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD);
        return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore,
                data.getItems().collectAsMap());
    }

    /** Pairs every id with the dense index assigned by {@code zipWithIndex}. */
    private static JavaPairRDD<String, Integer> indexIds(JavaRDD<String> ids) {
        return ids.zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
            @Override
            public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
                return new Tuple2<>(element._1(), element._2().intValue());
            }
        });
    }

    /**
     * Maps each event to {@code ((userIndex, itemIndex), 1)}, dropping events whose user or item
     * has no index.
     */
    private static JavaPairRDD<Tuple2<Integer, Integer>, Integer> indexedEventPairs(
            JavaRDD<UserItemEvent> events,
            final Map<String, Integer> userIndexMap,
            final Map<String, Integer> itemIndexMap) {
        return events.mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
            @Override
            public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent event) throws Exception {
                Integer userIndex = userIndexMap.get(event.getUser());
                Integer itemIndex = itemIndexMap.get(event.getItem());
                return (userIndex == null || itemIndex == null)
                        ? null
                        : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
            }
        }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
            @Override
            public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
                return element != null;
            }
        });
    }

    /**
     * Converts an MLlib feature RDD into a pair RDD with Integer keys. On the Java side, MLlib
     * exposes the ids boxed as {@code Object}.
     */
    private static JavaPairRDD<Integer, double[]> toIndexedFeatures(RDD<Tuple2<Object, double[]>> features) {
        return features.toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
            @Override
            public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
                return new Tuple2<>((Integer) element._1(), element._2());
            }
        });
    }

    /**
     * Recommends items for one query.
     *
     * <p>Uses the user's ALS feature vector when one exists. Otherwise it falls back to similarity
     * with recently interacted items, then to buy popularity.
     */
    @Override
    public PredictedResult predict(Model model, final Query query) {
        double[] userFeature = findUserFeature(model, query.getUserEntityId());
        if (userFeature != null) {
            return new PredictedResult(topItemsForUser(userFeature, model, query));
        }
        List<double[]> recentProductFeatures = getRecentProductFeatures(query, model);
        if (recentProductFeatures.isEmpty()) {
            return new PredictedResult(mostPopularItems(model, query));
        }
        return new PredictedResult(similarItems(recentProductFeatures, model, query));
    }

    /**
     * Returns the ALS feature vector for a user, or {@code null} in two cases: the user has no
     * index, or ALS learned no features for that index (the user had no ratings).
     */
    private static double[] findUserFeature(Model model, String userEntityId) {
        List<Integer> userIndices = model.getUserIndex().lookup(userEntityId);
        if (userIndices.isEmpty()) {
            return null;
        }
        List<double[]> features = model.getUserFeatures().lookup(userIndices.get(0));
        return features.isEmpty() ? null : features.get(0);
    }

    /**
     * Runs {@link #predict} for each query sequentially on the driver and parallelizes the
     * collected results.
     */
    @Override
    public RDD<Tuple2<Object, PredictedResult>> batchPredict(Model model, RDD<Tuple2<Object, Query>> qs) {
        List<Tuple2<Object, Query>> indexQueries = qs.toJavaRDD().collect();
        List<Tuple2<Object, PredictedResult>> results = new ArrayList<>(indexQueries.size());
        for (Tuple2<Object, Query> indexQuery : indexQueries) {
            results.add(new Tuple2<>(indexQuery._1(), predict(model, indexQuery._2())));
        }
        return new JavaSparkContext(qs.sparkContext()).parallelize(results).rdd();
    }

    /**
     * Looks up the ALS item features of the items targeted by the user's most recent
     * {@code similarItemEvents}, using at most {@value #RECENT_EVENT_LIMIT} events.
     * Target items without an index or features in the model are skipped.
     *
     * @throws RuntimeException wrapping any failure while reading events or looking up features
     */
    private List<double[]> getRecentProductFeatures(Query query, Model model) {
        try {
            List<Event> events = LJavaEventStore.findByEntity(
                    ap.getAppName(),
                    "user",
                    query.getUserEntityId(),
                    OptionHelper.<String>none(),
                    OptionHelper.some(ap.getSimilarItemEvents()),
                    OptionHelper.some(OptionHelper.some("item")),
                    OptionHelper.<Option<String>>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.some(RECENT_EVENT_LIMIT),
                    true,
                    EVENT_STORE_TIMEOUT);
            List<double[]> result = new ArrayList<>();
            for (Event event : events) {
                if (!event.targetEntityId().isDefined()) {
                    continue;
                }
                // Check for a known index before reading it; unknown items are simply skipped.
                List<Integer> itemIndices = model.getItemIndex().lookup(event.targetEntityId().get());
                if (itemIndices.isEmpty()) {
                    continue;
                }
                List<Tuple2<String, double[]>> itemFeatures = model.getIndexItemFeatures().lookup(itemIndices.get(0));
                if (!itemFeatures.isEmpty()) {
                    result.add(itemFeatures.get(0)._2());
                }
            }
            return result;
        } catch (Exception e) {
            logger.error("Error reading recent events for user {}", query.getUserEntityId());
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /** Scores every item by the dot product of its features with the user's feature vector. */
    private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) {
        final DoubleMatrix userMatrix = new DoubleMatrix(userFeature);
        JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(
                new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
                    @Override
                    public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
                        return new ItemScore(element._2()._1(), userMatrix.dot(new DoubleMatrix(element._2()._2())));
                    }
                });
        itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(),
                model.getItems(), query.getUserEntityId());
        return sortAndTake(itemScores, query.getNumber());
    }

    /** Scores every item by its summed cosine similarity to the given recent item features. */
    private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) {
        JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(
                new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
                    @Override
                    public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
                        double similarity = 0.0;
                        for (double[] recentFeature : recentProductFeatures) {
                            similarity += cosineSimilarity(element._2()._2(), recentFeature);
                        }
                        return new ItemScore(element._2()._1(), similarity);
                    }
                });
        itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(),
                model.getItems(), query.getUserEntityId());
        return sortAndTake(itemScores, query.getNumber());
    }

    /** Ranks items by the buy-count popularity computed at training time. */
    private List<ItemScore> mostPopularItems(Model model, Query query) {
        JavaRDD<ItemScore> itemScores = validScores(model.getItemPopularityScore(), query.getWhitelist(),
                query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
        return sortAndTake(itemScores, query.getNumber());
    }

    /**
     * Cosine similarity of two vectors. Returns 0.0 when either vector has zero norm, so no NaN
     * reaches the ranking.
     */
    private static double cosineSimilarity(double[] a, double[] b) {
        DoubleMatrix matrixA = new DoubleMatrix(a);
        DoubleMatrix matrixB = new DoubleMatrix(b);
        double normProduct = matrixA.norm2() * matrixB.norm2();
        return normProduct == 0.0 ? 0.0 : matrixA.dot(matrixB) / normProduct;
    }

    /**
     * Returns the {@code number} highest-scoring items in descending score order.
     * Non-positive counts yield an empty list.
     */
    private List<ItemScore> sortAndTake(JavaRDD<ItemScore> all, int number) {
        if (number <= 0) {
            return new ArrayList<>();
        }
        // top() keeps a bounded queue per partition instead of shuffling for a full sort.
        return all.top(number, new ScoreComparator());
    }

    /** Orders item scores by ascending score; serializable so Spark can ship it to executors. */
    private static final class ScoreComparator implements Comparator<ItemScore>, Serializable {
        private static final long serialVersionUID = 1L;

        @Override
        public int compare(ItemScore left, ItemScore right) {
            return Double.compare(left.getScore(), right.getScore());
        }
    }

    /**
     * Keeps only scores for items that exist in {@code items} and pass the whitelist, blacklist,
     * category, unseen and availability checks. The seen and unavailable sets are read from the
     * event store once, before the filter is built.
     */
    private JavaRDD<ItemScore> validScores(JavaRDD<ItemScore> all, final Set<String> whitelist,
                                           final Set<String> blacklist, final Set<String> categories,
                                           final Map<String, Item> items, String userEntityId) {
        final Set<String> seenItemEntityIds = seenItemEntityIds(userEntityId);
        final Set<String> unavailableItemEntityIds = unavailableItemEntityIds();
        return all.filter(new Function<ItemScore, Boolean>() {
            @Override
            public Boolean call(ItemScore itemScore) throws Exception {
                Item item = items.get(itemScore.getItemEntityId());
                return item != null
                        && passWhitelistCriteria(whitelist, item.getEntityId())
                        && passBlacklistCriteria(blacklist, item.getEntityId())
                        && passCategoryCriteria(categories, item)
                        && passUnseenCriteria(seenItemEntityIds, item.getEntityId())
                        && passAvailabilityCriteria(unavailableItemEntityIds, item.getEntityId());
            }
        });
    }

    /** An empty whitelist admits every item. */
    private static boolean passWhitelistCriteria(Set<String> whitelist, String itemEntityId) {
        return whitelist.isEmpty() || whitelist.contains(itemEntityId);
    }

    private static boolean passBlacklistCriteria(Set<String> blacklist, String itemEntityId) {
        return !blacklist.contains(itemEntityId);
    }

    /** An empty category filter admits every item; otherwise at least one category must match. */
    private static boolean passCategoryCriteria(Set<String> categories, Item item) {
        return categories.isEmpty() || !Sets.intersection(categories, item.getCategories()).isEmpty();
    }

    private static boolean passUnseenCriteria(Set<String> seen, String itemEntityId) {
        return !seen.contains(itemEntityId);
    }

    private static boolean passAvailabilityCriteria(Set<String> unavailableItemEntityIds, String entityId) {
        return !unavailableItemEntityIds.contains(entityId);
    }

    /**
     * Reads the {@code items} list from the latest {@code $set} event on the
     * {@code constraint/unavailableItems} entity.
     *
     * @return the unavailable item ids, or an empty set if no such event exists
     * @throws RuntimeException wrapping any event store or property-parsing failure
     */
    private Set<String> unavailableItemEntityIds() {
        try {
            List<Event> unavailableConstraintEvents = LJavaEventStore.findByEntity(
                    ap.getAppName(),
                    "constraint",
                    "unavailableItems",
                    OptionHelper.<String>none(),
                    OptionHelper.some(Collections.singletonList("$set")),
                    OptionHelper.<Option<String>>none(),
                    OptionHelper.<Option<String>>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.some(1),
                    true,
                    EVENT_STORE_TIMEOUT);
            if (unavailableConstraintEvents.isEmpty()) {
                return Collections.emptySet();
            }
            Event unavailableConstraint = unavailableConstraintEvents.get(0);
            List<String> unavailableItems = unavailableConstraint.properties().getStringList("items");
            return new HashSet<>(unavailableItems);
        } catch (Exception e) {
            logger.error("Error reading constraint events");
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Returns ids of items targeted by the user's {@code seenItemEvents}, or an empty set when
     * {@code unseenOnly} is disabled.
     *
     * @throws RuntimeException wrapping any event store failure
     */
    private Set<String> seenItemEntityIds(String userEntityId) {
        if (!ap.isUnseenOnly()) {
            return Collections.emptySet();
        }
        try {
            Set<String> result = new HashSet<>();
            List<Event> seenEvents = LJavaEventStore.findByEntity(
                    ap.getAppName(),
                    "user",
                    userEntityId,
                    OptionHelper.<String>none(),
                    OptionHelper.some(ap.getSeenItemEvents()),
                    OptionHelper.some(OptionHelper.some("item")),
                    OptionHelper.<Option<String>>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.<Integer>none(),
                    true,
                    EVENT_STORE_TIMEOUT);
            for (Event event : seenEvents) {
                if (event.targetEntityId().isDefined()) {
                    result.add(event.targetEntityId().get());
                }
            }
            return result;
        } catch (Exception e) {
            logger.error("Error reading seen events for user {}", userEntityId);
            throw new RuntimeException(e.getMessage(), e);
        }
    }
}