package org.template.recommendation;

import io.prediction.controller.Params;
import io.prediction.controller.PersistentModel;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;

import java.io.Serializable;
import java.util.Collections;
import java.util.Map;
public class Model implements Serializable, PersistentModel<AlgorithmParams> {

  private static final Logger logger = LoggerFactory.getLogger(Model.class);

  // Latent user feature vectors, keyed by user index.
  private final JavaPairRDD<Integer, double[]> userFeatures;
  // Latent item feature vectors, keyed by item index and paired with the item's string ID.
  private final JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures;
  // Mappings from user/item string IDs to the integer indices used during training.
  private final JavaPairRDD<String, Integer> userIndex;
  private final JavaPairRDD<String, Integer> itemIndex;
  // Per-item popularity scores.
  private final JavaRDD<ItemScore> itemPopularityScore;
  // Item metadata, keyed by item string ID.
  private final Map<String, Item> items;

  public Model(JavaPairRDD<Integer, double[]> userFeatures,
               JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures,
               JavaPairRDD<String, Integer> userIndex,
               JavaPairRDD<String, Integer> itemIndex,
               JavaRDD<ItemScore> itemPopularityScore,
               Map<String, Item> items) {
    this.userFeatures = userFeatures;
    this.indexItemFeatures = indexItemFeatures;
    this.userIndex = userIndex;
    this.itemIndex = itemIndex;
    this.itemPopularityScore = itemPopularityScore;
    this.items = items;
  }
  public JavaPairRDD<Integer, double[]> getUserFeatures() {
    return userFeatures;
  }

  public JavaPairRDD<Integer, Tuple2<String, double[]>> getIndexItemFeatures() {
    return indexItemFeatures;
  }

  public JavaPairRDD<String, Integer> getUserIndex() {
    return userIndex;
  }

  public JavaPairRDD<String, Integer> getItemIndex() {
    return itemIndex;
  }

  public JavaRDD<ItemScore> getItemPopularityScore() {
    return itemPopularityScore;
  }

  public Map<String, Item> getItems() {
    return items;
  }
  // Persists each component as a Spark object file under /tmp/<id>. Note that
  // saveAsObjectFile fails if the target directory already exists.
  @Override
  public boolean save(String id, AlgorithmParams params, SparkContext sc) {
    userFeatures.saveAsObjectFile("/tmp/" + id + "/userFeatures");
    indexItemFeatures.saveAsObjectFile("/tmp/" + id + "/indexItemFeatures");
    userIndex.saveAsObjectFile("/tmp/" + id + "/userIndex");
    itemIndex.saveAsObjectFile("/tmp/" + id + "/itemIndex");
    itemPopularityScore.saveAsObjectFile("/tmp/" + id + "/itemPopularityScore");
    // The items map is not an RDD, so wrap it in a single-element RDD to reuse
    // the same object-file persistence mechanism.
    new JavaSparkContext(sc).parallelize(Collections.singletonList(items))
        .saveAsObjectFile("/tmp/" + id + "/items");
    logger.info("Saved model to /tmp/" + id);
    return true;
  }
  // Mirror of save(): rebuilds each component from its object file.
  public static Model load(String id, Params params, SparkContext sc) {
    JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);
    JavaPairRDD<Integer, double[]> userFeatures =
        JavaPairRDD.<Integer, double[]>fromJavaRDD(
            jsc.<Tuple2<Integer, double[]>>objectFile("/tmp/" + id + "/userFeatures"));
    JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures =
        JavaPairRDD.<Integer, Tuple2<String, double[]>>fromJavaRDD(
            jsc.<Tuple2<Integer, Tuple2<String, double[]>>>objectFile(
                "/tmp/" + id + "/indexItemFeatures"));
    JavaPairRDD<String, Integer> userIndex =
        JavaPairRDD.<String, Integer>fromJavaRDD(
            jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/userIndex"));
    JavaPairRDD<String, Integer> itemIndex =
        JavaPairRDD.<String, Integer>fromJavaRDD(
            jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/itemIndex"));
    JavaRDD<ItemScore> itemPopularityScore =
        jsc.objectFile("/tmp/" + id + "/itemPopularityScore");
    // The items map was saved as a single-element RDD; collect it back out.
    Map<String, Item> items =
        jsc.<Map<String, Item>>objectFile("/tmp/" + id + "/items").collect().get(0);
    logger.info("Loaded model from /tmp/" + id);
    return new Model(userFeatures, indexItemFeatures, userIndex, itemIndex,
        itemPopularityScore, items);
  }
}
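
// ---------------------------------------------------------------------------
// Usage sketch (illustration only, not part of the template). Assumes a
// trained `model` produced by the engine's Algorithm, a live SparkContext
// `sc`, the algorithm's `params`, and an engine instance id `id` supplied by
// the framework:
//
//   model.save(id, params, sc);                   // object files under /tmp/<id>
//   Model restored = Model.load(id, params, sc);  // round-trips the model
//   Map<String, Item> items = restored.getItems();
// ---------------------------------------------------------------------------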