blob: 0144ea73074b85c830ee73fff1145beca3eafb6a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.examples.ml.inference.exchange;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.UUID;
import javax.cache.Cache;
import org.apache.commons.math3.util.Precision;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.distances.ManhattanDistance;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
/**
* Run ANN multi-class classification trainer ({@link ANNClassificationTrainer}) over distributed dataset.
* <p>
* Code in this example launches Ignite grid and fills the cache with test data points (based on the
* <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set"></a>Iris dataset</a>).</p>
* <p>
* After that it trains the model based on the specified data using
* <a href="https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm">kNN</a> algorithm.</p>
* <p>
* Finally, this example loops over the test set of data points, applies the trained model to predict what cluster does
* this point belong to, and compares prediction to expected outcome (ground truth).</p>
* <p>
* You can change the test data used in this example and re-run it to explore this algorithm further.</p>
*/
public class ANNClassificationExportImportExample {
/**
* Run example.
*/
public static void main(String[] args) throws IOException {
System.out.println();
System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
IgniteCache<Integer, double[]> dataCache = null;
Path jsonMdlPath = null;
try {
dataCache = getTestCache(ignite);
ANNClassificationTrainer trainer = new ANNClassificationTrainer()
.withDistance(new ManhattanDistance())
.withK(50)
.withMaxIterations(1000)
.withEpsilon(1e-2);
ANNClassificationModel mdl = (ANNClassificationModel) trainer.fit(
ignite,
dataCache,
new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
).withK(5)
.withDistanceMeasure(new EuclideanDistance())
.withWeighted(true);
System.out.println("\n>>> Exported ANN model: " + mdl.toString(true));
double accuracy = evaluateModel(dataCache, mdl);
System.out.println("\n>>> Accuracy for exported ANN model:" + accuracy);
jsonMdlPath = Files.createTempFile(null, null);
mdl.toJSON(jsonMdlPath);
ANNClassificationModel modelImportedFromJSON = ANNClassificationModel.fromJSON(jsonMdlPath);
System.out.println("\n>>> Imported ANN model: " + modelImportedFromJSON.toString(true));
accuracy = evaluateModel(dataCache, modelImportedFromJSON);
System.out.println("\n>>> Accuracy for imported ANN model:" + accuracy);
System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
}
finally {
if (dataCache != null)
dataCache.destroy();
if (jsonMdlPath != null)
Files.deleteIfExists(jsonMdlPath);
}
}
finally {
System.out.flush();
}
}
/** */
private static double evaluateModel(IgniteCache<Integer, double[]> dataCache, NNClassificationModel knnMdl) {
int amountOfErrors = 0;
int totalAmount = 0;
double accuracy;
try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
System.out.println(">>> ---------------------------------");
System.out.println(">>> | Prediction\t| Ground Truth\t|");
System.out.println(">>> ---------------------------------");
for (Cache.Entry<Integer, double[]> observation : observations) {
double[] val = observation.getValue();
double[] inputs = Arrays.copyOfRange(val, 1, val.length);
double groundTruth = val[0];
double prediction = knnMdl.predict(new DenseVector(inputs));
totalAmount++;
if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
amountOfErrors++;
System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
System.out.println(">>> ---------------------------------");
accuracy = 1 - amountOfErrors / (double) totalAmount;
}
return accuracy;
}
/**
* Fills cache with data and returns it.
*
* @param ignite Ignite instance.
* @return Filled Ignite Cache.
*/
private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
cacheConfiguration.setName("TEST_" + UUID.randomUUID());
cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
for (int k = 0; k < 10; k++) { // multiplies the Iris dataset k times.
for (int i = 0; i < data.length; i++)
cache.put(k * 10000 + i, mutate(data[i], k));
}
return cache;
}
/**
* Tiny changing of data depending on k parameter.
*
* @param datum The vector data.
* @param k The passed parameter.
* @return The changed vector data.
*/
private static double[] mutate(double[] datum, int k) {
for (int i = 0; i < datum.length; i++)
datum[i] += k / 100000;
return datum;
}
/**
* The Iris dataset.
*/
private static final double[][] data = {
{1, 5.1, 3.5, 1.4, 0.2},
{1, 4.9, 3, 1.4, 0.2},
{1, 4.7, 3.2, 1.3, 0.2},
{1, 4.6, 3.1, 1.5, 0.2},
{1, 5, 3.6, 1.4, 0.2},
{1, 5.4, 3.9, 1.7, 0.4},
{1, 4.6, 3.4, 1.4, 0.3},
{1, 5, 3.4, 1.5, 0.2},
{1, 4.4, 2.9, 1.4, 0.2},
{1, 4.9, 3.1, 1.5, 0.1},
{1, 5.4, 3.7, 1.5, 0.2},
{1, 4.8, 3.4, 1.6, 0.2},
{1, 4.8, 3, 1.4, 0.1},
{1, 4.3, 3, 1.1, 0.1},
{1, 5.8, 4, 1.2, 0.2},
{1, 5.7, 4.4, 1.5, 0.4},
{1, 5.4, 3.9, 1.3, 0.4},
{1, 5.1, 3.5, 1.4, 0.3},
{1, 5.7, 3.8, 1.7, 0.3},
{1, 5.1, 3.8, 1.5, 0.3},
{1, 5.4, 3.4, 1.7, 0.2},
{1, 5.1, 3.7, 1.5, 0.4},
{1, 4.6, 3.6, 1, 0.2},
{1, 5.1, 3.3, 1.7, 0.5},
{1, 4.8, 3.4, 1.9, 0.2},
{1, 5, 3, 1.6, 0.2},
{1, 5, 3.4, 1.6, 0.4},
{1, 5.2, 3.5, 1.5, 0.2},
{1, 5.2, 3.4, 1.4, 0.2},
{1, 4.7, 3.2, 1.6, 0.2},
{1, 4.8, 3.1, 1.6, 0.2},
{1, 5.4, 3.4, 1.5, 0.4},
{1, 5.2, 4.1, 1.5, 0.1},
{1, 5.5, 4.2, 1.4, 0.2},
{1, 4.9, 3.1, 1.5, 0.1},
{1, 5, 3.2, 1.2, 0.2},
{1, 5.5, 3.5, 1.3, 0.2},
{1, 4.9, 3.1, 1.5, 0.1},
{1, 4.4, 3, 1.3, 0.2},
{1, 5.1, 3.4, 1.5, 0.2},
{1, 5, 3.5, 1.3, 0.3},
{1, 4.5, 2.3, 1.3, 0.3},
{1, 4.4, 3.2, 1.3, 0.2},
{1, 5, 3.5, 1.6, 0.6},
{1, 5.1, 3.8, 1.9, 0.4},
{1, 4.8, 3, 1.4, 0.3},
{1, 5.1, 3.8, 1.6, 0.2},
{1, 4.6, 3.2, 1.4, 0.2},
{1, 5.3, 3.7, 1.5, 0.2},
{1, 5, 3.3, 1.4, 0.2},
{2, 7, 3.2, 4.7, 1.4},
{2, 6.4, 3.2, 4.5, 1.5},
{2, 6.9, 3.1, 4.9, 1.5},
{2, 5.5, 2.3, 4, 1.3},
{2, 6.5, 2.8, 4.6, 1.5},
{2, 5.7, 2.8, 4.5, 1.3},
{2, 6.3, 3.3, 4.7, 1.6},
{2, 4.9, 2.4, 3.3, 1},
{2, 6.6, 2.9, 4.6, 1.3},
{2, 5.2, 2.7, 3.9, 1.4},
{2, 5, 2, 3.5, 1},
{2, 5.9, 3, 4.2, 1.5},
{2, 6, 2.2, 4, 1},
{2, 6.1, 2.9, 4.7, 1.4},
{2, 5.6, 2.9, 3.6, 1.3},
{2, 6.7, 3.1, 4.4, 1.4},
{2, 5.6, 3, 4.5, 1.5},
{2, 5.8, 2.7, 4.1, 1},
{2, 6.2, 2.2, 4.5, 1.5},
{2, 5.6, 2.5, 3.9, 1.1},
{2, 5.9, 3.2, 4.8, 1.8},
{2, 6.1, 2.8, 4, 1.3},
{2, 6.3, 2.5, 4.9, 1.5},
{2, 6.1, 2.8, 4.7, 1.2},
{2, 6.4, 2.9, 4.3, 1.3},
{2, 6.6, 3, 4.4, 1.4},
{2, 6.8, 2.8, 4.8, 1.4},
{2, 6.7, 3, 5, 1.7},
{2, 6, 2.9, 4.5, 1.5},
{2, 5.7, 2.6, 3.5, 1},
{2, 5.5, 2.4, 3.8, 1.1},
{2, 5.5, 2.4, 3.7, 1},
{2, 5.8, 2.7, 3.9, 1.2},
{2, 6, 2.7, 5.1, 1.6},
{2, 5.4, 3, 4.5, 1.5},
{2, 6, 3.4, 4.5, 1.6},
{2, 6.7, 3.1, 4.7, 1.5},
{2, 6.3, 2.3, 4.4, 1.3},
{2, 5.6, 3, 4.1, 1.3},
{2, 5.5, 2.5, 4, 1.3},
{2, 5.5, 2.6, 4.4, 1.2},
{2, 6.1, 3, 4.6, 1.4},
{2, 5.8, 2.6, 4, 1.2},
{2, 5, 2.3, 3.3, 1},
{2, 5.6, 2.7, 4.2, 1.3},
{2, 5.7, 3, 4.2, 1.2},
{2, 5.7, 2.9, 4.2, 1.3},
{2, 6.2, 2.9, 4.3, 1.3},
{2, 5.1, 2.5, 3, 1.1},
{2, 5.7, 2.8, 4.1, 1.3},
{3, 6.3, 3.3, 6, 2.5},
{3, 5.8, 2.7, 5.1, 1.9},
{3, 7.1, 3, 5.9, 2.1},
{3, 6.3, 2.9, 5.6, 1.8},
{3, 6.5, 3, 5.8, 2.2},
{3, 7.6, 3, 6.6, 2.1},
{3, 4.9, 2.5, 4.5, 1.7},
{3, 7.3, 2.9, 6.3, 1.8},
{3, 6.7, 2.5, 5.8, 1.8},
{3, 7.2, 3.6, 6.1, 2.5},
{3, 6.5, 3.2, 5.1, 2},
{3, 6.4, 2.7, 5.3, 1.9},
{3, 6.8, 3, 5.5, 2.1},
{3, 5.7, 2.5, 5, 2},
{3, 5.8, 2.8, 5.1, 2.4},
{3, 6.4, 3.2, 5.3, 2.3},
{3, 6.5, 3, 5.5, 1.8},
{3, 7.7, 3.8, 6.7, 2.2},
{3, 7.7, 2.6, 6.9, 2.3},
{3, 6, 2.2, 5, 1.5},
{3, 6.9, 3.2, 5.7, 2.3},
{3, 5.6, 2.8, 4.9, 2},
{3, 7.7, 2.8, 6.7, 2},
{3, 6.3, 2.7, 4.9, 1.8},
{3, 6.7, 3.3, 5.7, 2.1},
{3, 7.2, 3.2, 6, 1.8},
{3, 6.2, 2.8, 4.8, 1.8},
{3, 6.1, 3, 4.9, 1.8},
{3, 6.4, 2.8, 5.6, 2.1},
{3, 7.2, 3, 5.8, 1.6},
{3, 7.4, 2.8, 6.1, 1.9},
{3, 7.9, 3.8, 6.4, 2},
{3, 6.4, 2.8, 5.6, 2.2},
{3, 6.3, 2.8, 5.1, 1.5},
{3, 6.1, 2.6, 5.6, 1.4},
{3, 7.7, 3, 6.1, 2.3},
{3, 6.3, 3.4, 5.6, 2.4},
{3, 6.4, 3.1, 5.5, 1.8},
{3, 6, 3, 4.8, 1.8},
{3, 6.9, 3.1, 5.4, 2.1},
{3, 6.7, 3.1, 5.6, 2.4},
{3, 6.9, 3.1, 5.1, 2.3},
{3, 5.8, 2.7, 5.1, 1.9},
{3, 6.8, 3.2, 5.9, 2.3},
{3, 6.7, 3.3, 5.7, 2.5},
{3, 6.7, 3, 5.2, 2.3},
{3, 6.3, 2.5, 5, 1.9},
{3, 6.5, 3, 5.2, 2},
{3, 6.2, 3.4, 5.4, 2.3},
{3, 5.9, 3, 5.1, 1.8}
};
}