/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.ignite.examples.ml.regression.linear;

import java.util.Arrays;
import java.util.UUID;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerPreprocessor;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.thread.IgniteThread;

/**
* Runs a linear regression model over a cached dataset, rescaling the features with a min-max scaler
* preprocessor before training.
*
* @see LinearRegressionLSQRTrainer
* @see MinMaxScalerTrainer
* @see MinMaxScalerPreprocessor
*/
public class LinearRegressionLSQRTrainerWithMinMaxScalerExample {
/** The dataset: the first column of each row is the target value, the remaining columns are the features. */
private static final double[][] data = {
{8, 78, 284, 9.100000381, 109},
{9.300000191, 68, 433, 8.699999809, 144},
{7.5, 70, 739, 7.199999809, 113},
{8.899999619, 96, 1792, 8.899999619, 97},
{10.19999981, 74, 477, 8.300000191, 206},
{8.300000191, 111, 362, 10.89999962, 124},
{8.800000191, 77, 671, 10, 152},
{8.800000191, 168, 636, 9.100000381, 162},
{10.69999981, 82, 329, 8.699999809, 150},
{11.69999981, 89, 634, 7.599999905, 134},
{8.5, 149, 631, 10.80000019, 292},
{8.300000191, 60, 257, 9.5, 108},
{8.199999809, 96, 284, 8.800000191, 111},
{7.900000095, 83, 603, 9.5, 182},
{10.30000019, 130, 686, 8.699999809, 129},
{7.400000095, 145, 345, 11.19999981, 158},
{9.600000381, 112, 1357, 9.699999809, 186},
{9.300000191, 131, 544, 9.600000381, 177},
{10.60000038, 80, 205, 9.100000381, 127},
{9.699999809, 130, 1264, 9.199999809, 179},
{11.60000038, 140, 688, 8.300000191, 80},
{8.100000381, 154, 354, 8.399999619, 103},
{9.800000191, 118, 1632, 9.399999619, 101},
{7.400000095, 94, 348, 9.800000191, 117},
{9.399999619, 119, 370, 10.39999962, 88},
{11.19999981, 153, 648, 9.899999619, 78},
{9.100000381, 116, 366, 9.199999809, 102},
{10.5, 97, 540, 10.30000019, 95},
{11.89999962, 176, 680, 8.899999619, 80},
{8.399999619, 75, 345, 9.600000381, 92},
{5, 134, 525, 10.30000019, 126},
{9.800000191, 161, 870, 10.39999962, 108},
{9.800000191, 111, 669, 9.699999809, 77},
{10.80000019, 114, 452, 9.600000381, 60},
{10.10000038, 142, 430, 10.69999981, 71},
{10.89999962, 238, 822, 10.30000019, 86},
{9.199999809, 78, 190, 10.69999981, 93},
{8.300000191, 196, 867, 9.600000381, 106},
{7.300000191, 125, 969, 10.5, 162},
{9.399999619, 82, 499, 7.699999809, 95},
{9.399999619, 125, 925, 10.19999981, 91},
{9.800000191, 129, 353, 9.899999619, 52},
{3.599999905, 84, 288, 8.399999619, 110},
{8.399999619, 183, 718, 10.39999962, 69},
{10.80000019, 119, 540, 9.199999809, 57},
{10.10000038, 180, 668, 13, 106},
{9, 82, 347, 8.800000191, 40},
{10, 71, 345, 9.199999809, 50},
{11.30000019, 118, 463, 7.800000191, 35},
{11.30000019, 121, 728, 8.199999809, 86},
{12.80000019, 68, 383, 7.400000095, 57},
{10, 112, 316, 10.39999962, 57},
{6.699999809, 109, 388, 8.899999619, 94}
};

/** Run example. */
public static void main(String[] args) throws InterruptedException {
System.out.println();
System.out.println(">>> Linear regression model over cached dataset usage example started.");
// Start Ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
LinearRegressionLSQRTrainerWithMinMaxScalerExample.class.getSimpleName(), () -> {
IgniteCache<Integer, Vector> dataCache = getTestCache(ignite);
System.out.println(">>> Create new minmaxscaling trainer object.");
MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>();
System.out.println(">>> Perform the training to get the minmaxscaling preprocessor.");
IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit(
ignite,
dataCache,
(k, v) -> {
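// Feature extractor: drop the first column (the target value) and keep the rest as features.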
double[] arr = v.asArray();
return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
}
);
System.out.println(">>> Create new linear regression trainer object.");
LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
System.out.println(">>> Perform the training to get the model.");
LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v.get(0));
System.out.println(">>> Linear regression model: " + mdl);
System.out.println(">>> ---------------------------------");
System.out.println(">>> | Prediction\t| Ground Truth\t|");
System.out.println(">>> ---------------------------------");
try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
for (Cache.Entry<Integer, Vector> observation : observations) {
Integer key = observation.getKey();
Vector val = observation.getValue();
double groundTruth = val.get(0);
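// Apply the same min-max preprocessor to the observation before passing it to the model.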
double prediction = mdl.apply(preprocessor.apply(key, val));
System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
}
System.out.println(">>> ---------------------------------");
});
igniteThread.start();
igniteThread.join();
}
}

/**
* Fills a cache with the test data and returns it.
*
* @param ignite Ignite instance.
* @return Filled Ignite cache.
*/
private static IgniteCache<Integer, Vector> getTestCache(Ignite ignite) {
CacheConfiguration<Integer, Vector> cacheConfiguration = new CacheConfiguration<>();
cacheConfiguration.setName("TEST_" + UUID.randomUUID());
cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
IgniteCache<Integer, Vector> cache = ignite.createCache(cacheConfiguration);
for (int i = 0; i < data.length; i++)
cache.put(i, VectorUtils.of(data[i]));
return cache;
}
}