blob: bddd42512294aaae15276eb8f08076a7a0d4035e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.examples.ml.inference.exchange;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.composition.boosting.GDBModel;
import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
import org.jetbrains.annotations.NotNull;
/**
 * Example represents a solution for the task of regression learning based on Gradient Boosting on trees
 * implementation, including an export/import round trip of the trained model through JSON.
 * <p>
 * It shows an initialization of {@link GDBRegressionOnTreesTrainer}, initialization of an Ignite cache with
 * training data, the learning step and a comparison of predicted and real values. The trained model is then
 * written to a temporary JSON file, read back with {@code GDBModel.fromJSON(Path)}, and evaluated again so the
 * two prediction tables can be compared.</p>
 * <p>
 * In this example the dataset is created automatically from the parabolic function {@code f(x) = x^2}.</p>
 */
public class GDBOnTreesRegressionExportImportExample {
    /**
     * Run example.
     *
     * @param args Command line arguments, none required.
     * @throws IOException If the temporary JSON model file cannot be created, written, read or deleted.
     */
    public static void main(String[] args) throws IOException {
        System.out.println();
        System.out.println(">>> GDB regression trainer example started.");

        // Start ignite grid.
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            System.out.println(">>> Ignite grid started.");

            // Create cache with training data.
            CacheConfiguration<Integer, double[]> trainingSetCfg = createCacheConfiguration();

            IgniteCache<Integer, double[]> trainingSet = null;
            Path jsonMdlPath = null;

            try {
                // Assign the cache before populating it, so that a failure while filling it
                // still lets the finally block below destroy it.
                trainingSet = ignite.getOrCreateCache(trainingSetCfg);
                fillTrainingData(trainingSet);

                // Create regression trainer.
                GDBTrainer trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.)
                    .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.001));

                // Train gradient boosting model; the last array element (y) is the label.
                GDBModel mdl = trainer.fit(
                    ignite,
                    trainingSet,
                    new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
                );

                System.out.println("\n>>> Exported GDB regression model: " + mdl.toString(true));

                predictOnGeneratedData(mdl);

                // Export the trained model to a temporary JSON file.
                jsonMdlPath = Files.createTempFile(null, null);
                mdl.toJSON(jsonMdlPath);

                // Import it back; the label mapping is the identity.
                IgniteFunction<Double, Double> lbMapper = lb -> lb;
                GDBModel modelImportedFromJSON = GDBModel.fromJSON(jsonMdlPath).withLblMapping(lbMapper);

                System.out.println("\n>>> Imported GDB regression model: " + modelImportedFromJSON.toString(true));

                predictOnGeneratedData(modelImportedFromJSON);

                System.out.println(">>> GDB regression trainer example completed.");
            }
            finally {
                // Nested finally: the temporary JSON file is removed even if destroying the cache throws.
                try {
                    if (trainingSet != null)
                        trainingSet.destroy();
                }
                finally {
                    if (jsonMdlPath != null)
                        Files.deleteIfExists(jsonMdlPath);
                }
            }
        }
        finally {
            System.out.flush();
        }
    }

    /**
     * Prints a table comparing the model's prediction with the true value {@code x^2}
     * for every integer {@code x} in {@code [-5, 5)}.
     *
     * @param mdl Model to evaluate.
     */
    private static void predictOnGeneratedData(GDBModel mdl) {
        System.out.println(">>> ---------------------------------");
        System.out.println(">>> | Prediction\t| Valid answer \t|");
        System.out.println(">>> ---------------------------------");

        // Calculate score.
        for (int x = -5; x < 5; x++) {
            double predicted = mdl.predict(VectorUtils.of(x));

            // %n emits the platform line separator instead of a hard-coded '\n'.
            System.out.printf(">>> | %.4f\t\t| %.4f\t\t|%n", predicted, Math.pow(x, 2));
        }

        System.out.println(">>> ---------------------------------");
    }

    /**
     * Create cache configuration for the training set: cache name {@code TRAINING_SET},
     * rendezvous affinity with 10 partitions.
     *
     * @return Training set cache configuration.
     */
    @NotNull private static CacheConfiguration<Integer, double[]> createCacheConfiguration() {
        CacheConfiguration<Integer, double[]> trainingSetCfg = new CacheConfiguration<>();
        trainingSetCfg.setName("TRAINING_SET");
        trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

        return trainingSetCfg;
    }

    /**
     * Fill parabolic training data: 101 points {@code x = i / 10} for {@code i} in {@code [-50, 50]},
     * stored under key {@code i} as {@code {x, x^2}}.
     *
     * @param trainingSet Cache to populate.
     */
    private static void fillTrainingData(IgniteCache<Integer, double[]> trainingSet) {
        for (int i = -50; i <= 50; i++) {
            double x = ((double)i) / 10.0;
            double y = Math.pow(x, 2);

            trainingSet.put(i, new double[] {x, y});
        }
    }
}