blob: 518bcef53f0e011772b20f29afaf398a53c2c183 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.selection.cv;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.parallelism.Promise;
import org.apache.ignite.ml.math.functions.IgniteDoubleConsumer;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.pipeline.Pipeline;
import org.apache.ignite.ml.pipeline.PipelineMdl;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.selection.paramgrid.BruteForceStrategy;
import org.apache.ignite.ml.selection.paramgrid.EvolutionOptimizationStrategy;
import org.apache.ignite.ml.selection.paramgrid.HyperParameterTuningStrategy;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
import org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator;
import org.apache.ignite.ml.selection.paramgrid.RandomStrategy;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.Metric;
import org.apache.ignite.ml.selection.scoring.metric.MetricName;
import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
import org.apache.ignite.ml.selection.split.mapper.UniformMapper;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.util.genetic.Chromosome;
import org.apache.ignite.ml.util.genetic.GeneticAlgorithm;
import org.jetbrains.annotations.NotNull;
/**
* Cross validation score calculator. Cross validation is a technique that helps to detect and avoid overfitting. It
* works the following way: the training set is split into k smaller sets, and the following procedure is repeated for
* each of the k “folds”:
* <ul>
* <li>A model is trained using k-1 of the folds as training data;</li>
* <li>the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute
* a performance measure such as accuracy).</li>
* </ul>
*
* @param <M> Type of model.
* @param <K> Type of a key in {@code upstream} data.
* @param <V> Type of a value in {@code upstream} data.
*/
public abstract class AbstractCrossValidation<M extends IgniteModel<Vector, Double>, K, V> implements Serializable {
/** Learning environment builder; replaced through {@code withEnvironmentBuilder}. */
protected LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder();
/** Learning environment built from {@code envBuilder}; its parallelism strategy is used to run the tuning tasks. */
protected LearningEnvironment environment = envBuilder.buildForTrainer();
/** Trainer that is fitted on every training split in {@code score}. */
protected DatasetTrainer<M, Double> trainer;
/** Pipeline that is fitted on every training split in {@code scorePipeline}. */
protected Pipeline<K, V, Integer, Double> pipeline;
/** Metric used to evaluate the fitted model on every test split. */
protected Metric metric;
/** Preprocessor handed to both the trainer and the evaluator in {@code score}. */
protected Preprocessor<K, V> preprocessor;
/**
* Filter for upstream entries; accepts every entry by default.
* NOTE(review): not read anywhere in this class - presumably applied by subclasses when building datasets; confirm.
*/
protected IgniteBiPredicate<K, V> filter = (k, v) -> true;
/** Amount of folds; also the length of the score arrays returned by {@code score} and {@code scorePipeline}. */
protected int amountOfFolds;
/** Parts. NOTE(review): neither read nor written in this class - presumably used by subclasses; confirm. */
protected int parts;
/** Parameter grid: supplies the tuning strategy, the candidate parameter values and the parameter setters. */
protected ParamGrid paramGrid;
/**
* {@code true} for execution over the pipeline, {@code false} for execution over the chain of preprocessors and a
* separate trainer. Set through {@code isRunningOnPipeline(boolean)}; not read in this class.
*/
protected boolean isRunningOnPipeline = true;
/** Mapper that turns an upstream entry into the point used to assign it to a fold; {@code SHA256UniformMapper} by default. */
protected UniformMapper<K, V> mapper = new SHA256UniformMapper<>();
/**
 * Runs hyper-parameter tuning with the search strategy configured in the parameter grid.
 *
 * @return Result of the brute-force, random-search or evolutionary search, depending on the strategy type.
 * @throws UnsupportedOperationException If the configured strategy is none of the three supported types.
 */
public CrossValidationResult tuneHyperParameters() {
    HyperParameterTuningStrategy stgy = paramGrid.getHyperParameterTuningStrategy();

    if (stgy instanceof BruteForceStrategy) {
        return scoreBruteForceHyperparameterOptimization();
    }

    if (stgy instanceof RandomStrategy) {
        return scoreRandomSearchHyperparameterOptimization();
    }

    if (stgy instanceof EvolutionOptimizationStrategy) {
        return scoreEvolutionAlgorithmSearchHyperparameterOptimization();
    }

    // None of the known strategies matched: report the offending strategy by name.
    String stgyName = paramGrid.getHyperParameterTuningStrategy().getName();

    throw new UnsupportedOperationException("This strategy is not supported yet [strategy=" + stgyName + "]");
}
/**
 * Finds the best set of hyper-parameters with a genetic algorithm.
 * <p>
 * The initial population is a seeded random sample of the generated parameter sets. Every fitness evaluation
 * cross-validates one chromosome (one parameter set), records its scores in the result and makes it the best
 * candidate when its average score is not lower than the best average recorded so far.
 *
 * @return Cross-validation result with the scores of every evaluated parameter set and the best one found.
 */
private CrossValidationResult scoreEvolutionAlgorithmSearchHyperparameterOptimization() {
    EvolutionOptimizationStrategy stgy = (EvolutionOptimizationStrategy) paramGrid.getHyperParameterTuningStrategy();

    List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();

    // Initialization: shuffle a copy so that the generated list itself is left untouched.
    List<Double[]> paramSetsCp = new ArrayList<>(paramSets);
    Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));

    // Cap the population by the number of available parameter sets: an unconditional subList(0, 20) throws
    // IndexOutOfBoundsException for grids that produce fewer than 20 combinations.
    int sizeOfPopulation = Math.min(20, paramSetsCp.size());
    List<Double[]> rndParamSets = paramSetsCp.subList(0, sizeOfPopulation);

    CrossValidationResult cvRes = new CrossValidationResult();

    Function<Chromosome, Double> fitnessFunction = (Chromosome chromosome) -> {
        TaskResult tr = calculateScoresForFixedParamSet(chromosome.toDoubleArray());

        // The result object is shared by all fitness evaluations. When the algorithm is started through
        // runParallel the function is presumably invoked from several threads (GeneticAlgorithm is not visible
        // here), so the record-and-compare sequence is performed under a lock.
        synchronized (cvRes) {
            cvRes.addScores(tr.locScores, tr.paramMap);

            // NOTE(review): Double.MIN_VALUE is the smallest positive double, not the most negative one - confirm
            // that this fallback for an empty score array is intended.
            final double locAvgScore = Arrays.stream(tr.locScores).average().orElse(Double.MIN_VALUE);

            if (locAvgScore >= cvRes.getBestAvgScore()) {
                cvRes.setBestScore(tr.locScores);
                cvRes.setBestHyperParams(tr.paramMap);
            }

            return locAvgScore;
        }
    };

    // TODO: a single seeded Random is shared by all mutator calls; check whether a common seed for shared lambdas
    // produces the same value on each function call or a sequence of values.
    Random rnd = new Random(stgy.getSeed());

    // Mutation replaces a gene with a randomly chosen value from the raw value list of the same parameter.
    BiFunction<Integer, Double, Double> mutator = (Integer geneIdx, Double geneValue) -> {
        Double[] possibleGeneValues = paramGrid.getParamRawData().get(geneIdx);

        return possibleGeneValues[rnd.nextInt(possibleGeneValues.length)];
    };

    GeneticAlgorithm ga = new GeneticAlgorithm(rndParamSets);

    ga.withFitnessFunction(fitnessFunction)
        .withMutationOperator(mutator)
        .withAmountOfEliteChromosomes(stgy.getAmountOfEliteChromosomes())
        .withCrossingoverProbability(stgy.getCrossingoverProbability())
        .withCrossoverStgy(stgy.getCrossoverStgy())
        .withAmountOfGenerations(stgy.getAmountOfGenerations())
        .withSelectionStgy(stgy.getSelectionStgy())
        .withMutationProbability(stgy.getMutationProbability());

    // Use the parallel runner only when the environment actually offers more than one worker.
    if (environment.parallelismStrategy().getParallelism() > 1)
        ga.runParallel(environment);
    else
        ga.run();

    return cvRes;
}
/**
 * Finds the best set of hyper-parameters based on Random Search.
 * <p>
 * The generated parameter sets are shuffled with the strategy seed and at most {@code maxTries} of them are
 * cross-validated through the environment's parallelism strategy. A parameter set becomes the best one when its
 * average score is not lower than the best average recorded so far.
 *
 * @return Cross-validation result with the scores of every tried parameter set and the best one found.
 */
private CrossValidationResult scoreRandomSearchHyperparameterOptimization() {
    RandomStrategy stgy = (RandomStrategy) paramGrid.getHyperParameterTuningStrategy();

    List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();

    // Shuffle a copy so that the generated list itself is left untouched.
    List<Double[]> paramSetsCp = new ArrayList<>(paramSets);
    Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));

    CrossValidationResult cvRes = new CrossValidationResult();

    // Never ask for more tries than there are parameter sets: subList(0, maxTries) throws
    // IndexOutOfBoundsException when maxTries exceeds the size of the list.
    int amountOfTries = Math.min(stgy.getMaxTries(), paramSetsCp.size());
    List<Double[]> rndParamSets = paramSetsCp.subList(0, amountOfTries);

    List<IgniteSupplier<TaskResult>> tasks = rndParamSets.stream()
        .map(paramSet -> (IgniteSupplier<TaskResult>) (() -> calculateScoresForFixedParamSet(paramSet)))
        .collect(Collectors.toList());

    // All promises are resolved before any result is merged into cvRes.
    List<TaskResult> taskResults = environment.parallelismStrategy().submit(tasks).stream()
        .map(Promise::unsafeGet)
        .collect(Collectors.toList());

    for (TaskResult tr : taskResults) {
        cvRes.addScores(tr.locScores, tr.paramMap);

        // NOTE(review): Double.MIN_VALUE is the smallest positive double, not the most negative one - confirm
        // that this fallback for an empty score array is intended.
        final double locAvgScore = Arrays.stream(tr.locScores).average().orElse(Double.MIN_VALUE);

        if (locAvgScore >= cvRes.getBestAvgScore()) {
            cvRes.setBestScore(tr.locScores);
            cvRes.setBestHyperParams(tr.paramMap);
        }
    }

    return cvRes;
}
/**
 * Finds the best set of hyperparameters based on the brute force approach: every generated parameter set is
 * cross-validated and the one with the strictly highest average score is kept as the best.
 *
 * @return Cross-validation result with the scores of every parameter set and the best one found.
 */
private CrossValidationResult scoreBruteForceHyperparameterOptimization() {
    List<Double[]> allParamSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();

    // One scoring task per parameter set.
    List<IgniteSupplier<TaskResult>> scoringTasks = new ArrayList<>();

    for (Double[] candidate : allParamSets) {
        scoringTasks.add(() -> calculateScoresForFixedParamSet(candidate));
    }

    // All promises are resolved before any result is merged.
    List<TaskResult> finished = environment.parallelismStrategy().submit(scoringTasks).stream()
        .map(Promise::unsafeGet)
        .collect(Collectors.toList());

    CrossValidationResult res = new CrossValidationResult();

    for (TaskResult taskRes : finished) {
        res.addScores(taskRes.locScores, taskRes.paramMap);

        double avgScore = Arrays.stream(taskRes.locScores).average().orElse(Double.MIN_VALUE);

        // Strict comparison: on a tie the earlier parameter set stays the best one.
        if (avgScore > res.getBestAvgScore()) {
            res.setBestScore(taskRes.locScores);
            res.setBestHyperParams(taskRes.paramMap);
        }
    }

    return res;
}
/**
 * Outcome of cross-validating a single parameter set: the parameter values that were applied and the scores they
 * produced.
 */
public static class TaskResult {
    /** Parameter values keyed by parameter name. */
    private Map<String, Double> paramMap;

    /** Scores obtained with these parameter values. */
    private double[] locScores;

    /**
     * Creates a result for one parameter set.
     *
     * @param paramMap Parameter map.
     * @param locScores Local scores.
     */
    public TaskResult(Map<String, Double> paramMap, double[] locScores) {
        this.paramMap = paramMap;
        this.locScores = locScores;
    }

    /**
     * Replaces the stored parameter map.
     *
     * @param paramMap Parameter map.
     */
    public void setParamMap(Map<String, Double> paramMap) {
        this.paramMap = paramMap;
    }

    /**
     * Replaces the stored scores.
     *
     * @param locScores Local scores.
     */
    public void setLocScores(double[] locScores) {
        this.locScores = locScores;
    }
}
/**
 * Applies the given parameter set through the parameter grid setters, scores the folds and wraps both the applied
 * parameters and the scores into a {@link TaskResult}.
 *
 * @param paramSet Parameter set.
 * @return Applied parameters together with the scores calculated by folds.
 */
private TaskResult calculateScoresForFixedParamSet(Double[] paramSet) {
    // Order matters: the parameters are injected before the folds are scored.
    Map<String, Double> appliedParams = injectAndGetParametersFromPipeline(paramGrid, paramSet);

    return new TaskResult(appliedParams, scoreByFolds());
}
/**
* Calculates score by folds.
* <p>
* Invoked once per candidate parameter set, right after the candidate's values have been passed to the parameter
* grid setters. The returned scores are averaged by the tuning routines to compare parameter sets.
*
* @return Array of scores. NOTE(review): presumably one score per fold, e.g. obtained from {@code score} or
* {@code scorePipeline} - confirm in the implementations.
*/
public abstract double[] scoreByFolds();
/**
 * Pushes every value of the parameter set into the setter registered in the grid under the same index and
 * collects the values into a map keyed by parameter name.
 *
 * @param paramGrid Parameter grid.
 * @param paramSet Parameter set.
 * @return Map from parameter name to the value that has just been applied.
 */
@NotNull private Map<String, Double> injectAndGetParametersFromPipeline(ParamGrid paramGrid, Double[] paramSet) {
    Map<String, Double> valsByName = new HashMap<>();

    int idx = 0;

    for (Double val : paramSet) {
        // Side effect: the setter applies the value to whatever object it was registered for.
        paramGrid.getSetterByIndex(idx).accept(val);

        valsByName.put(paramGrid.getParamNameByIndex(idx), val);

        idx++;
    }

    return valsByName;
}
/**
 * Computes cross-validated metrics with the separate trainer and preprocessor.
 * <p>
 * The mapper places every upstream entry at a point; fold {@code i} uses the entries whose point lies within
 * {@code [i / amountOfFolds, (i + 1) / amountOfFolds]} as the test set and all remaining entries as the training set.
 *
 * @param datasetBuilderSupplier Dataset builder supplier.
 * @return Array of scores of the estimator for each run of the cross validation.
 */
protected double[] score(Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> datasetBuilderSupplier) {
    final double[] foldScores = new double[amountOfFolds];
    final double foldWidth = 1.0 / amountOfFolds;

    for (int fold = 0; fold < amountOfFolds; fold++) {
        final double lowerBound = foldWidth * fold;
        final double upperBound = foldWidth * (fold + 1);

        // Training entries are the ones that fall outside of the current fold.
        IgniteBiPredicate<K, V> trainSetFilter = (key, val) -> {
            double point = mapper.map(key, val);

            if (point < lowerBound)
                return true;

            return point > upperBound;
        };

        // The test set is the exact complement of the training set.
        IgniteBiPredicate<K, V> testSetFilter = (key, val) -> !trainSetFilter.apply(key, val);

        M fittedMdl = trainer.fit(datasetBuilderSupplier.apply(trainSetFilter), preprocessor);

        DatasetBuilder<K, V> testSet = datasetBuilderSupplier.apply(testSetFilter);

        foldScores[fold] = Evaluator.evaluate(testSet, fittedMdl, preprocessor, metric).getSingle();
    }

    return foldScores;
}
/**
 * Computes cross-validated metrics with the pipeline.
 * <p>
 * The mapper places every upstream entry at a point; fold {@code i} uses the entries whose point lies within
 * {@code [i / amountOfFolds, (i + 1) / amountOfFolds]} as the test set and all remaining entries as the training set.
 * The pipeline is fitted on the training set and evaluated with its final preprocessor.
 *
 * @param datasetBuilderSupplier Dataset builder supplier.
 * @return Array of scores of the estimator for each run of the cross validation.
 */
protected double[] scorePipeline(Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> datasetBuilderSupplier) {
    final double[] foldScores = new double[amountOfFolds];
    final double foldWidth = 1.0 / amountOfFolds;

    for (int fold = 0; fold < amountOfFolds; fold++) {
        final double lowerBound = foldWidth * fold;
        final double upperBound = foldWidth * (fold + 1);

        // Training entries are the ones that fall outside of the current fold.
        IgniteBiPredicate<K, V> trainSetFilter = (key, val) -> {
            double point = mapper.map(key, val);

            if (point < lowerBound)
                return true;

            return point > upperBound;
        };

        // The test set is the exact complement of the training set.
        IgniteBiPredicate<K, V> testSetFilter = (key, val) -> !trainSetFilter.apply(key, val);

        PipelineMdl<K, V> fittedMdl = pipeline.fit(datasetBuilderSupplier.apply(trainSetFilter));

        DatasetBuilder<K, V> testSet = datasetBuilderSupplier.apply(testSetFilter);

        foldScores[fold] = Evaluator.evaluate(testSet, fittedMdl, pipeline.getFinalPreprocessor(), metric).getSingle();
    }

    return foldScores;
}
/**
 * Sets the trainer.
 *
 * @param trainer Trainer.
 * @return This instance, to allow call chaining.
 */
public AbstractCrossValidation<M, K, V> withTrainer(DatasetTrainer<M, Double> trainer) {
    this.trainer = trainer;

    return this;
}
/**
 * Sets the metric by creating a metric instance from the given metric name.
 *
 * @param metric Metric.
 * @return This instance, to allow call chaining.
 */
public AbstractCrossValidation<M, K, V> withMetric(MetricName metric) {
    Metric createdMetric = metric.create();

    this.metric = createdMetric;

    return this;
}
/**
 * Sets the preprocessor.
 *
 * @param preprocessor Preprocessor.
 * @return This instance, to allow call chaining.
 */
public AbstractCrossValidation<M, K, V> withPreprocessor(Preprocessor<K, V> preprocessor) {
    this.preprocessor = preprocessor;

    return this;
}
/**
 * Sets the filter for upstream entries.
 *
 * @param filter Filter.
 * @return This instance, to allow call chaining.
 */
public AbstractCrossValidation<M, K, V> withFilter(IgniteBiPredicate<K, V> filter) {
    this.filter = filter;

    return this;
}
/**
 * Sets the amount of folds the data is split into.
 *
 * @param amountOfFolds Amount of folds.
 * @return This instance, to allow call chaining.
 */
public AbstractCrossValidation<M, K, V> withAmountOfFolds(int amountOfFolds) {
    this.amountOfFolds = amountOfFolds;

    return this;
}
/**
 * Sets the parameter grid used for hyper-parameter tuning.
 *
 * @param paramGrid Parameter grid.
 * @return This instance, to allow call chaining.
 */
public AbstractCrossValidation<M, K, V> withParamGrid(ParamGrid paramGrid) {
    this.paramGrid = paramGrid;

    return this;
}
/**
 * Sets the flag that tells whether the execution runs over the pipeline.
 *
 * @param runningOnPipeline Running on pipeline.
 * @return This instance, to allow call chaining.
 */
public AbstractCrossValidation<M, K, V> isRunningOnPipeline(boolean runningOnPipeline) {
    this.isRunningOnPipeline = runningOnPipeline;

    return this;
}
/**
 * Changes the learning environment: stores the builder and rebuilds the environment from it.
 *
 * @param envBuilder Learning environment builder.
 * @return This instance, to allow call chaining.
 */
// TODO: IGNITE-10441 Think about more elegant ways to perform fluent API.
public AbstractCrossValidation<M, K, V> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
    this.envBuilder = envBuilder;

    // The environment always has to match the most recently supplied builder.
    this.environment = envBuilder.buildForTrainer();

    return this;
}
/**
 * Sets the pipeline.
 *
 * @param pipeline Pipeline.
 * @return This instance, to allow call chaining.
 */
public AbstractCrossValidation<M, K, V> withPipeline(Pipeline<K, V, Integer, Double> pipeline) {
    this.pipeline = pipeline;

    return this;
}
/**
 * Sets the mapper that is used to assign upstream entries to folds.
 *
 * @param mapper Mapper.
 * @return This instance, to allow call chaining.
 */
public AbstractCrossValidation<M, K, V> withMapper(UniformMapper<K, V> mapper) {
    this.mapper = mapper;

    return this;
}
}