blob: b5c798403c00fbbf44347da44fe28c0173185f21 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.examples.ml.recommendation;
import java.io.IOException;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
import org.apache.ignite.examples.ml.util.SandboxMLCache;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.recommendation.ObjectSubjectRatingTriplet;
import org.apache.ignite.ml.recommendation.RecommendationModel;
import org.apache.ignite.ml.recommendation.RecommendationTrainer;
/**
* Example of recommendation system based on MovieLens dataset (see https://grouplens.org/datasets/movielens/).
* In this example we create a cache with MovieLens rating data. Each entry in this cache represents a rating point
* (rating set by a single user to a single movie). Then we pass this cache to {@link RecommendationTrainer} and
* thereby train a {@link RecommendationModel}. This model predicts the rating that any user is assumed to give to
* any movie. When the model is ready we calculate the R2 score.
*/
public class MovieLensExample {
    /**
     * Runs the example: starts an Ignite node, loads MovieLens rating points into a cache, trains a
     * recommendation model on that cache and prints the R2 score of the model on the same data.
     *
     * @param args Command line arguments (not used).
     * @throws IOException If the MovieLens dataset cannot be loaded.
     */
    public static void main(String[] args) throws IOException {
        System.out.println();
        System.out.println(">>> Recommendation system over cache based dataset usage example started.");

        // Start ignite grid.
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            System.out.println(">>> Ignite grid started.");

            IgniteCache<Integer, RatingPoint> movielensCache = loadMovieLensDataset(ignite, 10_000);

            try {
                // Constant RNG seed: every run of the example uses the same seed.
                LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1);

                // NOTE(review): presumably a negative max-iterations value removes the iteration cap so that
                // training stops on the min-improvement criterion only - confirm against RecommendationTrainer.
                RecommendationTrainer trainer = new RecommendationTrainer()
                    .withMaxIterations(-1)
                    .withMinMdlImprovement(10)
                    .withBatchSize(10)
                    .withLearningRate(10)
                    .withLearningEnvironmentBuilder(envBuilder)
                    .withTrainerEnvironment(envBuilder.buildForTrainer());

                RecommendationModel<Integer, Integer> mdl = trainer.fit(
                    new CacheBasedDatasetBuilder<>(ignite, movielensCache)
                );

                System.out.println("R2 score: " + r2Score(movielensCache, mdl));
            }
            finally {
                movielensCache.destroy();
            }
        }
        finally {
            System.out.flush();
        }
    }

    /**
     * Calculates the R2 score ({@code 1 - RSS / TSS}) of the given model over all rating points in the cache.
     * The result is not a finite number when the cache is empty or when all ratings are equal (TSS is zero).
     *
     * @param cache Cache with rating points.
     * @param mdl Trained recommendation model.
     * @return R2 score.
     */
    private static double r2Score(IgniteCache<Integer, RatingPoint> cache, RecommendationModel<Integer, Integer> mdl) {
        double mean = meanRating(cache);

        double tss = 0;
        double rss = 0;

        try (QueryCursor<Cache.Entry<Integer, RatingPoint>> cursor = cache.query(new ScanQuery<>())) {
            for (Cache.Entry<Integer, RatingPoint> e : cursor) {
                RatingPoint point = e.getValue();

                double rating = point.getRating();
                double meanDiff = rating - mean;
                double predDiff = rating - mdl.predict(point);

                tss += meanDiff * meanDiff;
                rss += predDiff * predDiff;
            }
        }

        return 1.0 - rss / tss;
    }

    /**
     * Calculates the mean rating over all rating points in the cache. The divisor is the number of entries
     * actually visited by the scan, so it always matches the accumulated sum.
     *
     * @param cache Cache with rating points.
     * @return Mean rating, or {@code NaN} if the cache is empty.
     */
    private static double meanRating(IgniteCache<Integer, RatingPoint> cache) {
        double sum = 0;
        long visited = 0;

        try (QueryCursor<Cache.Entry<Integer, RatingPoint>> cursor = cache.query(new ScanQuery<>())) {
            for (Cache.Entry<Integer, RatingPoint> e : cursor) {
                sum += e.getValue().getRating();
                visited++;
            }
        }

        return sum / visited;
    }

    /**
     * Loads MovieLens dataset into cache. Each dataset line is split by comma into user identifier, movie
     * identifier and rating (in this order). If loading fails, the created cache is destroyed before the
     * exception is propagated.
     *
     * @param ignite Ignite instance.
     * @param cnt Maximum number of rating points to be loaded; nothing is loaded if it is not positive.
     * @return Ignite cache with loaded MovieLens dataset.
     * @throws IOException If dataset not found.
     */
    private static IgniteCache<Integer, RatingPoint> loadMovieLensDataset(Ignite ignite, int cnt) throws IOException {
        CacheConfiguration<Integer, RatingPoint> cacheConfiguration = new CacheConfiguration<>();
        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 100));
        cacheConfiguration.setName("MOVIELENS");

        IgniteCache<Integer, RatingPoint> dataCache = ignite.createCache(cacheConfiguration);

        boolean loaded = false;

        try {
            int seq = 0;

            for (String s : new SandboxMLCache(ignite).loadDataset(MLSandboxDatasets.MOVIELENS)) {
                // Check the limit before inserting so that a non-positive limit loads nothing.
                if (seq >= cnt)
                    break;

                String[] line = s.split(",");

                int userId = Integer.parseInt(line[0]);
                int movieId = Integer.parseInt(line[1]);
                double rating = Double.parseDouble(line[2]);

                dataCache.put(seq++, new RatingPoint(movieId, userId, rating));
            }

            loaded = true;

            return dataCache;
        }
        finally {
            // The caller can destroy the cache only if it receives it, so clean up here on any failure.
            if (!loaded)
                dataCache.destroy();
        }
    }

    /**
     * Rating point that represents a result of assessment of a single movie by a single user.
     */
    private static class RatingPoint extends ObjectSubjectRatingTriplet<Integer, Integer> {
        /** */
        private static final long serialVersionUID = -7301471870043910312L;

        /**
         * Constructs a new instance of rating point.
         *
         * @param movieId Movie identifier.
         * @param userId User identifier.
         * @param rating Rating.
         */
        public RatingPoint(Integer movieId, Integer userId, Double rating) {
            super(movieId, userId, rating);
        }
    }
}