| // Licensed to the Apache Software Foundation (ASF) under one or more |
| // contributor license agreements. See the NOTICE file distributed with |
| // this work for additional information regarding copyright ownership. |
| // The ASF licenses this file to You under the Apache License, Version 2.0 |
| // (the "License"); you may not use this file except in compliance with |
| // the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| = Cross-Validation |
| |
| Cross-validation in Apache Ignite is provided by the `CrossValidation` class. It is a score calculator parameterized by the model type and by the key and value types of the data. The constructor takes no arguments; after instantiation, configure the instance and call `scoreByFolds()` to perform cross-validation. |
| |
| Suppose we have a trainer and a training set, and we want to run cross-validation with 4 folds, using accuracy as the metric. Apache Ignite lets us do this as shown in the following example: |
| |
| |
| == Cross-Validation (without Pipeline API usage) |
| |
| [source, java] |
| ---- |
| // Create classification trainer |
| DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0); |
| |
| // Create cross-validation instance |
| CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator |
| = new CrossValidation<>(); |
| |
| // Set up the cross-validation process |
| scoreCalculator |
| .withIgnite(ignite) |
| .withUpstreamCache(trainingSet) |
| .withTrainer(trainer) |
| .withMetric(MetricName.ACCURACY) |
| .withPreprocessor(vectorizer) |
| .withAmountOfFolds(4) |
| .isRunningOnPipeline(false); |
| |
| // Calculate accuracy for each fold |
| double[] accuracyScores = scoreCalculator.scoreByFolds(); |
| ---- |
| |
| In this example, we first pass the Ignite instance and the cache that holds the training set, then the trainer, the metric, and the vectorizer, and finally the number of folds. The `scoreByFolds()` method returns an array with one value of the chosen metric per fold, that is, one score for each split of the training set into a training part and a validation part. |
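
To make the per-fold result more concrete, the following plain-Java sketch shows the general k-fold idea without any Ignite classes. Everything in it is hypothetical and for illustration only: the `KFoldSketch` class, the toy threshold "trainer", and the round-robin assignment of rows to folds are assumptions, and they do not describe how `CrossValidation` splits the cache or trains models internally.

```java
import java.util.Arrays;

class KFoldSketch {
    /**
     * Hypothetical stand-in for a trainer: fits a threshold halfway between
     * the mean feature value of class 0 and of class 1, using every row that
     * is NOT in the held-out fold. Assumes both classes occur in the
     * training part of each split.
     */
    static double trainThreshold(double[] xs, int[] ys, int folds, int heldOutFold) {
        double sum0 = 0, sum1 = 0;
        int n0 = 0, n1 = 0;
        for (int i = 0; i < xs.length; i++) {
            if (i % folds == heldOutFold) continue; // skip validation rows
            if (ys[i] == 0) { sum0 += xs[i]; n0++; }
            else { sum1 += xs[i]; n1++; }
        }
        return (sum0 / n0 + sum1 / n1) / 2.0;
    }

    /** Returns one accuracy value per fold (round-robin fold assignment, an assumption). */
    static double[] scoreByFolds(double[] xs, int[] ys, int folds) {
        double[] scores = new double[folds];
        for (int fold = 0; fold < folds; fold++) {
            // Train on all folds except the current one.
            double threshold = trainThreshold(xs, ys, folds, fold);

            // Evaluate accuracy on the held-out fold only.
            int correct = 0, total = 0;
            for (int i = 0; i < xs.length; i++) {
                if (i % folds != fold) continue;
                int predicted = xs[i] > threshold ? 1 : 0;
                if (predicted == ys[i]) correct++;
                total++;
            }
            scores[fold] = (double) correct / total;
        }
        return scores;
    }

    public static void main(String[] args) {
        double[] xs = {1, 2, 3, 4, 6, 7, 8, 9};
        int[] ys = {0, 0, 0, 0, 1, 1, 1, 1};
        System.out.println(Arrays.toString(scoreByFolds(xs, ys, 4)));
    }
}
```

Each element of the returned array is computed on data the corresponding model never saw during training, which is why the spread of the per-fold scores is often as informative as their average.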
| |
| == Cross-Validation (with Pipeline API usage) |
| |
| To run cross-validation on a pipeline, define the pipeline and pass it to the `CrossValidation` instance. |
| |
| CAUTION: The Pipeline API is experimental and may change in future releases. |
| |
| |
| [source, java] |
| ---- |
| // Create classification trainer |
| DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0); |
| |
| Pipeline<Integer, Vector, Integer, Double> pipeline |
| = new Pipeline<Integer, Vector, Integer, Double>() |
| .addVectorizer(vectorizer) |
| .addPreprocessingTrainer(new ImputerTrainer<Integer, Vector>()) |
| .addPreprocessingTrainer(new MinMaxScalerTrainer<Integer, Vector>()) |
| .addTrainer(trainer); |
| |
| |
| // Create cross-validation instance |
| CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator |
| = new CrossValidation<>(); |
| |
| // Set up the cross-validation process |
| scoreCalculator |
| .withIgnite(ignite) |
| .withUpstreamCache(trainingSet) |
| .withPipeline(pipeline) |
| .withMetric(MetricName.ACCURACY) |
| .withPreprocessor(vectorizer) |
| .withAmountOfFolds(4) |
| .isRunningOnPipeline(true); |
| |
| // Calculate accuracy for each fold |
| double[] accuracyScores = scoreCalculator.scoreByFolds(); |
| ---- |
| |
| |
| == Example |
| |
| To see how cross-validation can be used in practice, try https://github.com/apache/ignite/blob/master/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java[this example] and see https://github.com/apache/ignite/blob/master/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_pipeline.java[step 8 of the ML Tutorial]. Both are available on GitHub and are delivered with every Apache Ignite distribution. |