blob: bb8c54123caa6ff97929c0a1e89525f55c9b84f8 [file] [log] [blame]
package org.template.recommendation;
import io.prediction.controller.Params;
import io.prediction.controller.PersistentModel;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import java.io.Serializable;
public class Model implements Serializable, PersistentModel<AlgorithmParams> {
private static final Logger logger = LoggerFactory.getLogger(Model.class);
private final JavaPairRDD<Integer, double[]> userFeatures;
private final JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures;
private final JavaPairRDD<String, Integer> userIndex;
private final JavaPairRDD<String, Integer> itemIndex;
private final JavaRDD<ItemScore> itemPopularityScore;
private final JavaPairRDD<String, Item> items;
public Model(JavaPairRDD<Integer, double[]> userFeatures, JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures, JavaPairRDD<String, Integer> userIndex, JavaPairRDD<String, Integer> itemIndex, JavaRDD<ItemScore> itemPopularityScore, JavaPairRDD<String, Item> items) {
this.userFeatures = userFeatures;
this.indexItemFeatures = indexItemFeatures;
this.userIndex = userIndex;
this.itemIndex = itemIndex;
this.itemPopularityScore = itemPopularityScore;
this.items = items;
}
public JavaPairRDD<Integer, double[]> getUserFeatures() {
return userFeatures;
}
public JavaPairRDD<Integer, Tuple2<String, double[]>> getIndexItemFeatures() {
return indexItemFeatures;
}
public JavaPairRDD<String, Integer> getUserIndex() {
return userIndex;
}
public JavaPairRDD<String, Integer> getItemIndex() {
return itemIndex;
}
public JavaRDD<ItemScore> getItemPopularityScore() {
return itemPopularityScore;
}
public JavaPairRDD<String, Item> getItems() {
return items;
}
@Override
public boolean save(String id, AlgorithmParams params, SparkContext sc) {
userFeatures.saveAsObjectFile("/tmp/" + id + "/userFeatures");
indexItemFeatures.saveAsObjectFile("/tmp/" + id + "/indexItemFeatures");
userIndex.saveAsObjectFile("/tmp/" + id + "/userIndex");
itemIndex.saveAsObjectFile("/tmp/" + id + "/itemIndex");
itemPopularityScore.saveAsObjectFile("/tmp/" + id + "/itemPopularityScore");
items.saveAsObjectFile("/tmp/" + id + "/items");
logger.info("Saved model to /tmp/" + id);
return true;
}
public static Model load(String id, Params params, SparkContext sc) {
JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);
JavaPairRDD<Integer, double[]> userFeatures = JavaPairRDD.<Integer, double[]>fromJavaRDD(jsc.<Tuple2<Integer, double[]>>objectFile("/tmp/" + id + "/userFeatures"));
JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = JavaPairRDD.<Integer, Tuple2<String, double[]>>fromJavaRDD(jsc.<Tuple2<Integer, Tuple2<String, double[]>>>objectFile("/tmp/" + id + "/indexItemFeatures"));
JavaPairRDD<String, Integer> userIndex = JavaPairRDD.<String, Integer>fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/userIndex"));
JavaPairRDD<String, Integer> itemIndex = JavaPairRDD.<String, Integer>fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/itemIndex"));
JavaRDD<ItemScore> itemPopularityScore = jsc.objectFile("/tmp/" + id + "/itemPopularityScore");
JavaPairRDD<String, Item> items = JavaPairRDD.<String, Item>fromJavaRDD(jsc.<Tuple2<String, Item>>objectFile("/tmp/" + id + "/items"));
logger.info("loaded model");
return new Model(userFeatures, indexItemFeatures, userIndex, itemIndex, itemPopularityScore, items);
}
}