blob: 1a1977124879b84ef02da70826de65a4f7988bf3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.examples.ml.tree;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
/**
* Example of using distributed {@link DecisionTreeRegressionTrainer}.
* <p>
* Code in this example launches Ignite grid and fills the cache with generated test data points ({@code sin(x)} on
* interval {@code [0, 10)}).</p>
* <p>
* After that it creates a regression trainer and uses it to train the model on the training set.</p>
* <p>
* Finally, this example loops over the test data points, applies the trained model, and compares prediction to expected
* outcome (ground truth).</p>
* <p>
* You can change the test data used in this example and re-run it to explore this algorithm further.</p>
*/
public class DecisionTreeRegressionTrainerExample {
    /**
     * Executes example.
     * <p>
     * Starts an Ignite node, creates the {@code TRAINING_SET} cache, fills it with {@code sin(x)} samples, trains a
     * decision tree regression model on it and prints predictions next to the ground truth for
     * {@code x = 0, 1, ..., 9}. The cache is destroyed before the node is stopped.</p>
     *
     * @param args Command line arguments, none required.
     */
    public static void main(String... args) {
        System.out.println(">>> Decision tree regression trainer example started.");

        // Start ignite grid.
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            System.out.println(">>> Ignite grid started.");

            // Create cache with training data.
            CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
            trainingSetCfg.setName("TRAINING_SET");
            trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

            // Declared outside the inner try so the finally block can clean it up.
            IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;

            try {
                trainingSet = ignite.createCache(trainingSetCfg);

                // Fill training data.
                generatePoints(trainingSet);

                // Create regression trainer.
                // NOTE(review): arguments are presumably max tree depth and min impurity decrease -
                // confirm against the DecisionTreeRegressionTrainer Javadoc.
                DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);

                // Train decision tree model.
                DecisionTreeNode mdl = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());

                System.out.println(">>> Decision tree regression model: " + mdl);

                System.out.println(">>> ---------------------------------");
                System.out.println(">>> | Prediction\t| Ground Truth\t|");
                System.out.println(">>> ---------------------------------");

                // Calculate score: compare the model output with sin(x) at integer points.
                for (int x = 0; x < 10; x++) {
                    double predicted = mdl.predict(VectorUtils.of(x));

                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x));
                }

                System.out.println(">>> ---------------------------------");

                System.out.println(">>> Decision tree regression trainer example completed.");
            }
            finally {
                // The cache reference is still null if createCache(...) itself failed; calling destroy()
                // unconditionally would throw a NullPointerException that hides the original error.
                if (trainingSet != null) {
                    trainingSet.destroy();
                }
            }
        }
        finally {
            // Make sure all example output is written out even if the example failed.
            System.out.flush();
        }
    }

    /**
     * Generates {@code sin(x)} on interval {@code [0, 10)} and loads into the specified cache.
     * <p>
     * Puts 1000 entries keyed {@code 0..999}; entry {@code i} holds the one-dimensional feature vector
     * {@code x = i / 100.0} labeled with {@code Math.sin(x)}.</p>
     *
     * @param trainingSet Cache to fill with the generated labeled vectors.
     */
    private static void generatePoints(IgniteCache<Integer, LabeledVector<Double>> trainingSet) {
        for (int i = 0; i < 1000; i++) {
            double x = i / 100.0;
            double y = Math.sin(x);

            trainingSet.put(i, new LabeledVector<>(VectorUtils.of(x), y));
        }
    }
}