blob: f3cdbbefc2a2a62fdac9b4eccc827893a4392564 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.examples.ml.knn;
import java.util.Arrays;
import java.util.UUID;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.thread.IgniteThread;
/**
 * Run kNN multi-class classification trainer over distributed dataset.
 * <p>
 * The example starts an Ignite node, loads the labeled observations from {@link #data} into a temporary cache,
 * fits a {@link KNNClassificationModel} on that cache, then re-scans the same cache and prints the prediction
 * next to the ground truth for every observation together with the resulting error count and accuracy.
 * The temporary cache is destroyed once the evaluation is finished.</p>
 *
 * @see KNNClassificationTrainer
 */
public class KNNClassificationExample {
    /**
     * Run example.
     *
     * @param args Command line arguments, not used.
     * @throws InterruptedException If the calling thread is interrupted while waiting for the worker thread.
     */
    public static void main(String[] args) throws InterruptedException {
        System.out.println();
        System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example started.");

        // Start ignite grid.
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            System.out.println(">>> Ignite grid started.");

            // The whole training/evaluation routine is executed on a dedicated IgniteThread.
            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
                KNNClassificationExample.class.getSimpleName(), () -> trainAndEvaluate(ignite));

            igniteThread.start();

            // Keep the node alive until the worker is done: leaving the try block stops the node.
            igniteThread.join();
        }
    }

    /**
     * Fills a temporary cache with the dataset, trains the kNN model on it and prints per-observation predictions
     * followed by the absolute amount of errors and the accuracy. The temporary cache is always destroyed at the
     * end, even if training or evaluation fails.
     *
     * @param ignite Ignite instance.
     */
    private static void trainAndEvaluate(Ignite ignite) {
        IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);

        try {
            KNNClassificationTrainer trainer = new KNNClassificationTrainer();

            // Row layout is {label, features...}: index 0 is the label, the tail of the array is the feature vector.
            KNNClassificationModel knnMdl = trainer.fit(
                new CacheBasedDatasetBuilder<>(ignite, dataCache),
                (k, v) -> Arrays.copyOfRange(v, 1, v.length),
                (k, v) -> v[0]
            ).withK(3)
                .withDistanceMeasure(new EuclideanDistance())
                .withStrategy(KNNStrategy.WEIGHTED);

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");

            int amountOfErrors = 0;
            int totalAmount = 0;

            // The model is evaluated on the very same observations it was trained on.
            try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, double[]> observation : observations) {
                    double[] val = observation.getValue();
                    double[] inputs = Arrays.copyOfRange(val, 1, val.length);
                    double groundTruth = val[0];

                    double prediction = knnMdl.apply(new DenseLocalOnHeapVector(inputs));

                    totalAmount++;

                    if (groundTruth != prediction)
                        amountOfErrors++;

                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
                }

                System.out.println(">>> ---------------------------------");

                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
            }
        }
        finally {
            // The cache has a unique random name and is only needed by this run; remove it so that
            // repeated runs do not leave orphaned caches behind.
            dataCache.destroy();
        }
    }

    /**
     * Fills cache with data and returns it.
     * <p>
     * A new cache with a random {@code TEST_<uuid>} name is created on every call; the row index in {@link #data}
     * is used as the cache key.</p>
     *
     * @param ignite Ignite instance.
     * @return Filled Ignite Cache.
     */
    private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
        CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
        cacheConfiguration.setName("TEST_" + UUID.randomUUID());
        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));

        IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);

        for (int i = 0; i < data.length; i++)
            cache.put(i, data[i]);

        return cache;
    }

    /**
     * The Iris dataset.
     * <p>
     * Every row holds one observation: the first element is the class label (1, 2 or 3) and the remaining
     * four elements are the feature values passed to the model.</p>
     */
    private static final double[][] data = {
        {1, 5.1, 3.5, 1.4, 0.2},
        {1, 4.9, 3, 1.4, 0.2},
        {1, 4.7, 3.2, 1.3, 0.2},
        {1, 4.6, 3.1, 1.5, 0.2},
        {1, 5, 3.6, 1.4, 0.2},
        {1, 5.4, 3.9, 1.7, 0.4},
        {1, 4.6, 3.4, 1.4, 0.3},
        {1, 5, 3.4, 1.5, 0.2},
        {1, 4.4, 2.9, 1.4, 0.2},
        {1, 4.9, 3.1, 1.5, 0.1},
        {1, 5.4, 3.7, 1.5, 0.2},
        {1, 4.8, 3.4, 1.6, 0.2},
        {1, 4.8, 3, 1.4, 0.1},
        {1, 4.3, 3, 1.1, 0.1},
        {1, 5.8, 4, 1.2, 0.2},
        {1, 5.7, 4.4, 1.5, 0.4},
        {1, 5.4, 3.9, 1.3, 0.4},
        {1, 5.1, 3.5, 1.4, 0.3},
        {1, 5.7, 3.8, 1.7, 0.3},
        {1, 5.1, 3.8, 1.5, 0.3},
        {1, 5.4, 3.4, 1.7, 0.2},
        {1, 5.1, 3.7, 1.5, 0.4},
        {1, 4.6, 3.6, 1, 0.2},
        {1, 5.1, 3.3, 1.7, 0.5},
        {1, 4.8, 3.4, 1.9, 0.2},
        {1, 5, 3, 1.6, 0.2},
        {1, 5, 3.4, 1.6, 0.4},
        {1, 5.2, 3.5, 1.5, 0.2},
        {1, 5.2, 3.4, 1.4, 0.2},
        {1, 4.7, 3.2, 1.6, 0.2},
        {1, 4.8, 3.1, 1.6, 0.2},
        {1, 5.4, 3.4, 1.5, 0.4},
        {1, 5.2, 4.1, 1.5, 0.1},
        {1, 5.5, 4.2, 1.4, 0.2},
        {1, 4.9, 3.1, 1.5, 0.1},
        {1, 5, 3.2, 1.2, 0.2},
        {1, 5.5, 3.5, 1.3, 0.2},
        {1, 4.9, 3.1, 1.5, 0.1},
        {1, 4.4, 3, 1.3, 0.2},
        {1, 5.1, 3.4, 1.5, 0.2},
        {1, 5, 3.5, 1.3, 0.3},
        {1, 4.5, 2.3, 1.3, 0.3},
        {1, 4.4, 3.2, 1.3, 0.2},
        {1, 5, 3.5, 1.6, 0.6},
        {1, 5.1, 3.8, 1.9, 0.4},
        {1, 4.8, 3, 1.4, 0.3},
        {1, 5.1, 3.8, 1.6, 0.2},
        {1, 4.6, 3.2, 1.4, 0.2},
        {1, 5.3, 3.7, 1.5, 0.2},
        {1, 5, 3.3, 1.4, 0.2},
        {2, 7, 3.2, 4.7, 1.4},
        {2, 6.4, 3.2, 4.5, 1.5},
        {2, 6.9, 3.1, 4.9, 1.5},
        {2, 5.5, 2.3, 4, 1.3},
        {2, 6.5, 2.8, 4.6, 1.5},
        {2, 5.7, 2.8, 4.5, 1.3},
        {2, 6.3, 3.3, 4.7, 1.6},
        {2, 4.9, 2.4, 3.3, 1},
        {2, 6.6, 2.9, 4.6, 1.3},
        {2, 5.2, 2.7, 3.9, 1.4},
        {2, 5, 2, 3.5, 1},
        {2, 5.9, 3, 4.2, 1.5},
        {2, 6, 2.2, 4, 1},
        {2, 6.1, 2.9, 4.7, 1.4},
        {2, 5.6, 2.9, 3.6, 1.3},
        {2, 6.7, 3.1, 4.4, 1.4},
        {2, 5.6, 3, 4.5, 1.5},
        {2, 5.8, 2.7, 4.1, 1},
        {2, 6.2, 2.2, 4.5, 1.5},
        {2, 5.6, 2.5, 3.9, 1.1},
        {2, 5.9, 3.2, 4.8, 1.8},
        {2, 6.1, 2.8, 4, 1.3},
        {2, 6.3, 2.5, 4.9, 1.5},
        {2, 6.1, 2.8, 4.7, 1.2},
        {2, 6.4, 2.9, 4.3, 1.3},
        {2, 6.6, 3, 4.4, 1.4},
        {2, 6.8, 2.8, 4.8, 1.4},
        {2, 6.7, 3, 5, 1.7},
        {2, 6, 2.9, 4.5, 1.5},
        {2, 5.7, 2.6, 3.5, 1},
        {2, 5.5, 2.4, 3.8, 1.1},
        {2, 5.5, 2.4, 3.7, 1},
        {2, 5.8, 2.7, 3.9, 1.2},
        {2, 6, 2.7, 5.1, 1.6},
        {2, 5.4, 3, 4.5, 1.5},
        {2, 6, 3.4, 4.5, 1.6},
        {2, 6.7, 3.1, 4.7, 1.5},
        {2, 6.3, 2.3, 4.4, 1.3},
        {2, 5.6, 3, 4.1, 1.3},
        {2, 5.5, 2.5, 4, 1.3},
        {2, 5.5, 2.6, 4.4, 1.2},
        {2, 6.1, 3, 4.6, 1.4},
        {2, 5.8, 2.6, 4, 1.2},
        {2, 5, 2.3, 3.3, 1},
        {2, 5.6, 2.7, 4.2, 1.3},
        {2, 5.7, 3, 4.2, 1.2},
        {2, 5.7, 2.9, 4.2, 1.3},
        {2, 6.2, 2.9, 4.3, 1.3},
        {2, 5.1, 2.5, 3, 1.1},
        {2, 5.7, 2.8, 4.1, 1.3},
        {3, 6.3, 3.3, 6, 2.5},
        {3, 5.8, 2.7, 5.1, 1.9},
        {3, 7.1, 3, 5.9, 2.1},
        {3, 6.3, 2.9, 5.6, 1.8},
        {3, 6.5, 3, 5.8, 2.2},
        {3, 7.6, 3, 6.6, 2.1},
        {3, 4.9, 2.5, 4.5, 1.7},
        {3, 7.3, 2.9, 6.3, 1.8},
        {3, 6.7, 2.5, 5.8, 1.8},
        {3, 7.2, 3.6, 6.1, 2.5},
        {3, 6.5, 3.2, 5.1, 2},
        {3, 6.4, 2.7, 5.3, 1.9},
        {3, 6.8, 3, 5.5, 2.1},
        {3, 5.7, 2.5, 5, 2},
        {3, 5.8, 2.8, 5.1, 2.4},
        {3, 6.4, 3.2, 5.3, 2.3},
        {3, 6.5, 3, 5.5, 1.8},
        {3, 7.7, 3.8, 6.7, 2.2},
        {3, 7.7, 2.6, 6.9, 2.3},
        {3, 6, 2.2, 5, 1.5},
        {3, 6.9, 3.2, 5.7, 2.3},
        {3, 5.6, 2.8, 4.9, 2},
        {3, 7.7, 2.8, 6.7, 2},
        {3, 6.3, 2.7, 4.9, 1.8},
        {3, 6.7, 3.3, 5.7, 2.1},
        {3, 7.2, 3.2, 6, 1.8},
        {3, 6.2, 2.8, 4.8, 1.8},
        {3, 6.1, 3, 4.9, 1.8},
        {3, 6.4, 2.8, 5.6, 2.1},
        {3, 7.2, 3, 5.8, 1.6},
        {3, 7.4, 2.8, 6.1, 1.9},
        {3, 7.9, 3.8, 6.4, 2},
        {3, 6.4, 2.8, 5.6, 2.2},
        {3, 6.3, 2.8, 5.1, 1.5},
        {3, 6.1, 2.6, 5.6, 1.4},
        {3, 7.7, 3, 6.1, 2.3},
        {3, 6.3, 3.4, 5.6, 2.4},
        {3, 6.4, 3.1, 5.5, 1.8},
        {3, 6, 3, 4.8, 1.8},
        {3, 6.9, 3.1, 5.4, 2.1},
        {3, 6.7, 3.1, 5.6, 2.4},
        {3, 6.9, 3.1, 5.1, 2.3},
        {3, 5.8, 2.7, 5.1, 1.9},
        {3, 6.8, 3.2, 5.9, 2.3},
        {3, 6.7, 3.3, 5.7, 2.5},
        {3, 6.7, 3, 5.2, 2.3},
        {3, 6.3, 2.5, 5, 1.9},
        {3, 6.5, 3, 5.2, 2},
        {3, 6.2, 3.4, 5.4, 2.3},
        {3, 5.9, 3, 5.1, 1.8}
    };
}