// blob: bf99690088f254ccc261b7573c0166fdc7fdfff2 [file] [log] [blame]
package io.prediction.examples.java.recommendations.tutorial5;
import io.prediction.controller.java.LJavaAlgorithm;
import io.prediction.controller.java.EmptyParams;
import io.prediction.examples.java.recommendations.tutorial1.TrainingData;
import io.prediction.examples.java.recommendations.tutorial1.Query;
import io.prediction.engines.util.MahoutUtil;
import org.apache.mahout.cf.taste.recommender.Recommender;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.common.TasteException;
import scala.Tuple4;
import java.util.List;
import java.util.ArrayList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Simple Mahout item-based algorithm integration, for demonstration purposes.
 *
 * <p>Training converts the ratings in {@link TrainingData} into a Mahout
 * {@link DataModel}; prediction asks the recommender exposed by
 * {@link MahoutAlgoModel} to estimate a user's preference for an item.
 */
public class MahoutAlgorithm extends
    LJavaAlgorithm<EmptyParams, TrainingData, MahoutAlgoModel, Query, Float> {

  static final Logger logger = LoggerFactory.getLogger(MahoutAlgorithm.class);

  /** Algorithm parameters; handed unchanged to every model built by {@link #train}. */
  MahoutAlgoParams params;

  /**
   * Creates the algorithm.
   *
   * @param params parameters passed on to each {@link MahoutAlgoModel} created by
   *     {@link #train}
   */
  public MahoutAlgorithm(MahoutAlgoParams params) {
    this.params = params;
  }

  /**
   * Builds a Mahout data model from the training ratings.
   *
   * @param data training data whose {@code ratings} are converted one-to-one into
   *     (uid, iid, rating, timestamp) tuples
   * @return a model wrapping the resulting {@link DataModel} and this algorithm's
   *     parameters
   */
  @Override
  public MahoutAlgoModel train(TrainingData data) {
    List<Tuple4<Integer, Integer, Float, Long>> ratings =
        new ArrayList<Tuple4<Integer, Integer, Float, Long>>();
    for (TrainingData.Rating r : data.ratings) {
      // The training ratings carry no timestamp, so 0 is used as a placeholder.
      ratings.add(new Tuple4<Integer, Integer, Float, Long>(r.uid, r.iid, r.rating, 0L));
    }
    DataModel dataModel = MahoutUtil.jBuildDataModel(ratings);
    return new MahoutAlgoModel(dataModel, params);
  }

  /**
   * Estimates the preference of the query's user for the query's item.
   *
   * @param model trained model providing the Mahout recommender
   * @param query user id / item id pair to score
   * @return the recommender's estimate, or {@code Float.NaN} if the recommender
   *     throws a {@link TasteException}
   */
  @Override
  public Float predict(MahoutAlgoModel model, Query query) {
    try {
      return model.getRecommender().estimatePreference((long) query.uid, (long) query.iid);
    } catch (TasteException e) {
      // Best effort: keep answering with NaN, but record the failure instead of
      // silently dropping it so failed estimates can be diagnosed.
      logger.warn("Unable to estimate preference for uid " + query.uid
          + ", iid " + query.iid + "; returning NaN", e);
      return Float.NaN;
    }
  }
}