// blob: cbf74371fd6d3907ebb610b9baa388e2ce7447be [file] [log] [blame]
package io.prediction.engines.java.olditemrec.algos;
import io.prediction.engines.java.olditemrec.data.PreparedData;
import io.prediction.engines.java.olditemrec.data.Query;
import io.prediction.engines.java.olditemrec.data.Prediction;
import io.prediction.engines.java.olditemrec.data.Model;
import io.prediction.controller.java.LJavaAlgorithm;
import io.prediction.controller.java.JavaParams;
import io.prediction.controller.Params;
import org.apache.mahout.cf.taste.recommender.Recommender;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.cf.taste.common.TasteException;
import java.util.HashMap;
import java.util.Map;
import java.util.ArrayList;
import java.util.List;
import java.util.Iterator;
import java.util.concurrent.Executors;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
/** Mahout Legacy single machine algorithm */
public abstract class AbstractMahoutAlgorithm<AP extends Params>
extends LJavaAlgorithm<AP, PreparedData, Model, Query, Prediction> {
MahoutParams params;
Logger logger;
/**
 * Stores the collaborators shared by all Mahout-backed algorithms.
 *
 * @param params algorithm parameters; this class reads {@code numRecommendations} from them
 * @param logger logger that receives the messages emitted during training and prediction
 */
public AbstractMahoutAlgorithm(MahoutParams params, Logger logger) {
  this.logger = logger;
  this.params = params;
}
/**
 * Builds the Mahout recommender that {@code train} uses to compute per-user recommendations.
 *
 * @param data prepared data handed to {@code train}
 * @return the recommender to query for every user ID
 * @throws TasteException if the recommender cannot be built
 */
public abstract Recommender buildRecommender(PreparedData data) throws TasteException;
/**
 * Trains the model by pre-computing the recommendations of every user in the data model.
 *
 * <p>Training is best-effort: failures are logged, not propagated. If no recommender can be
 * built the returned model is empty; if reading the user IDs fails part-way, only the IDs
 * read so far are processed.
 *
 * @param data prepared data; its {@code dataModel} supplies the user IDs
 * @return model wrapping the map of user ID to recommended items
 */
@Override
public Model train(PreparedData data) {
  Recommender recommender = null;
  try {
    recommender = buildRecommender(data);
  } catch (TasteException e) {
    // Pass the exception itself so the stack trace and cause are not lost.
    logger.error("Caught TasteException " + e.getMessage(), e);
  }

  Map<Long, List<RecommendedItem>> itemRecScores;
  if (recommender == null) {
    // Without a recommender nothing can be computed. Skip the batch stage instead of
    // letting every worker thread fail on a null reference; the result is an empty model.
    itemRecScores = new HashMap<Long, List<RecommendedItem>>();
  } else {
    List<Long> uids = new ArrayList<Long>();
    try {
      Iterator<Long> it = data.dataModel.getUserIDs();
      while (it.hasNext()) {
        uids.add(it.next());
      }
    } catch (TasteException e) {
      // Keep whatever IDs were collected before the failure.
      logger.error("Caught TasteException " + e.getMessage(), e);
    }
    // batchTrain(uids, recommender) is the single-threaded equivalent of this call.
    itemRecScores = batchTrainMulti(uids, recommender);
  }

  Model m = new Model(itemRecScores);
  logger.info(m.toString());
  return m;
}
/**
 * Looks up the pre-computed recommendations of {@code query.uid} and returns at most
 * {@code query.n} of them, in the order stored in the model.
 *
 * @param model trained model holding the map of user ID to recommended items
 * @param query carries the user ID ({@code uid}) and the maximum number of items ({@code n})
 * @return prediction with parallel lists of item IDs and scores; both lists are empty when
 *     the user has no stored recommendations or {@code query.n} is not positive
 */
@Override
public Prediction predict(Model model, Query query) {
  List<RecommendedItem> items = model.itemRecScores.get((long) query.uid);
  List<Integer> iids = new ArrayList<Integer>();
  List<Float> scores = new ArrayList<Float>();
  if (items == null) {
    // No entry for this user in the model.
    logger.info("null");
  } else {
    for (RecommendedItem item : items) {
      // Check the limit BEFORE adding so that a limit of zero (or less) yields no items.
      if (iids.size() >= query.n) {
        break;
      }
      logger.info(item.toString());
      // NOTE(review): narrowing cast; item IDs outside the int range would be truncated.
      iids.add((int) item.getItemID());
      scores.add(item.getValue());
    }
  }
  return new Prediction(iids, scores);
}
/**
 * Computes recommendations for each given user on the calling thread.
 *
 * <p>A user whose recommendation fails with a {@link TasteException}, or who gets no
 * recommended items, is left out of the result; the failure is logged.
 *
 * @param uids user IDs to compute recommendations for
 * @param recommender recommender queried with {@code params.numRecommendations} per user
 * @return map of user ID to its non-empty list of recommended items
 */
private Map<Long, List<RecommendedItem>> batchTrain(List<Long> uids, Recommender recommender) {
  Map<Long, List<RecommendedItem>> itemRecScores = new HashMap<Long, List<RecommendedItem>>();
  for (Long uid : uids) {
    try {
      List<RecommendedItem> items = recommender.recommend(uid, params.numRecommendations);
      if (!items.isEmpty()) {
        itemRecScores.put(uid, items);
      }
    } catch (TasteException e) {
      // Skip this user but keep going; pass the exception so the stack trace is logged.
      logger.error("Caught TasteException " + e.getMessage(), e);
    }
  }
  return itemRecScores;
}
/**
 * Multi-threaded batch training: splits the user IDs into batches, runs one
 * {@code RecommendTask} per batch on a fixed thread pool sized to the number of available
 * processors, and merges the per-batch maps.
 *
 * <p>A batch that fails is logged and skipped. If the calling thread is interrupted while
 * waiting, the interrupt flag is restored, outstanding work is cancelled and the results
 * merged so far are returned.
 *
 * @param uids user IDs to compute recommendations for
 * @param recommender recommender shared by all worker tasks
 * @return merged map of user ID to recommended items
 */
private Map<Long, List<RecommendedItem>> batchTrainMulti(
    List<Long> uids, Recommender recommender) {
  Map<Long, List<RecommendedItem>> itemRecScores = new HashMap<Long, List<RecommendedItem>>();
  if (uids.isEmpty()) {
    // Nothing to compute; do not spin up a thread pool.
    return itemRecScores;
  }

  int numProcessors = Runtime.getRuntime().availableProcessors();
  List<List<Long>> batches = createBatchesByNumber(uids, numProcessors);
  ExecutorService executor = Executors.newFixedThreadPool(numProcessors);
  try {
    ExecutorCompletionService<Map<Long, List<RecommendedItem>>> pool =
        new ExecutorCompletionService<Map<Long, List<RecommendedItem>>>(executor);
    for (List<Long> batch : batches) {
      pool.submit(new RecommendTask(batch, recommender));
    }

    // reduce stage: one completed future is taken per submitted batch
    for (int i = 0; i < batches.size(); i++) {
      try {
        itemRecScores.putAll(pool.take().get());
      } catch (InterruptedException e) {
        logger.error("Caught Exception: " + e.getMessage(), e);
        // Restore the flag for callers and stop waiting for the remaining batches.
        Thread.currentThread().interrupt();
        executor.shutdownNow();
        break;
      } catch (Exception e) {
        // Deliberately broad: a failed batch must not abort the remaining ones.
        logger.error("Caught Exception: " + e.getMessage(), e);
      }
    }
  } finally {
    // Always release the (non-daemon) worker threads, even on an unexpected failure.
    executor.shutdown();
  }
  return itemRecScores;
}
/**
 * Creates batches by specifying the number of batches.
 *
 * <p>The data is cut into at most {@code num} contiguous, non-empty batches whose sizes
 * differ by at most one, so no single batch carries the whole remainder. Each batch is a
 * {@link List#subList(int, int)} view of {@code data}, not a copy.
 *
 * @param data List of Data
 * @param num Number of batches; must be at least 1
 * @return the batches in the original order; empty when {@code data} is empty
 * @throws IllegalArgumentException if {@code num} is less than 1
 */
private List<List<Long>> createBatchesByNumber(List<Long> data, int num) {
  if (num < 1) {
    throw new IllegalArgumentException("Number of batches must be at least 1: " + num);
  }
  // example cases data.size()/num:
  // 21/4 => 6, 5, 5, 5
  // 8/4 => 2, 2, 2, 2
  // 6/4 => 2, 2, 1, 1
  // 4/4 => 1, 1, 1, 1
  // 1/4 => 1
  // 0/4 => empty
  int total = data.size();
  int batchCount = Math.min(num, total); // never create empty batches
  List<List<Long>> batches = new ArrayList<List<Long>>(batchCount);
  if (total == 0) {
    return batches;
  }
  int baseSize = total / batchCount;
  int remainder = total % batchCount; // the first `remainder` batches get one extra element
  int from = 0;
  for (int batchNum = 0; batchNum < batchCount; batchNum++) {
    int until = from + baseSize + (batchNum < remainder ? 1 : 0);
    logger.info("Create Batch: [" + from + ", " + until + ")");
    batches.add(data.subList(from, until));
    from = until;
  }
  return batches;
}
/*
private List<List<Long>> createBatchesBySize(List<Long> data, int batchSize) {
List<List<Long>> batches = new ArrayList<List<Long>>();
int i = 0;
boolean newBatch = true;
int remain = data.size();
List<Long> batchData = new ArrayList<Long>();
int actual = 0;
for (long d : data) {
if (newBatch) {
actual = Math.min(remain, batchSize);
newBatch = false;
}
batchData.add(d);
remain--;
i++;
if (i == actual) { // already fill up one batch
batches.add(batchData);
i = 0;
newBatch = true;
batchData = new ArrayList<Long>();
}
}
logger.info("batches size: " + batches.size());
return batches;
}*/
/**
 * Unit of work for the thread pool: computes the recommendations of one batch of user IDs
 * by delegating to the enclosing instance's {@code batchTrain}.
 */
private class RecommendTask implements Callable<Map<Long, List<RecommendedItem>>> {
  private final Recommender recommender;
  private final List<Long> uids;

  /**
   * @param uids user IDs of this batch
   * @param recommender recommender passed on to {@code batchTrain}
   */
  public RecommendTask(List<Long> uids, Recommender recommender) {
    this.recommender = recommender;
    this.uids = uids;
  }

  /** Runs {@code batchTrain} for this batch and returns its result map. */
  @Override
  public Map<Long, List<RecommendedItem>> call() {
    Map<Long, List<RecommendedItem>> batchResult = batchTrain(uids, recommender);
    return batchResult;
  }
}
}