// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
= Cross-Validation
Cross-validation functionality in Apache Ignite is provided by the `CrossValidation` class. This is a calculator parameterized by the type of model, the type of label, and the key-value types of the data. After instantiation (the constructor does not accept any parameters), we configure the instance and call a scoring method such as `scoreByFolds()` to perform cross-validation.

Let's imagine that we have a trainer and a training set, and we want to run cross-validation with accuracy as the metric and 4 folds. Apache Ignite allows us to do this as shown in the following example:
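
To make the idea of "4 folds" concrete, here is a minimal, library-independent sketch of how a data set can be partitioned into folds. It is an illustration only: the `KFoldSketch` class and its round-robin assignment are assumptions made for this example and do not describe how Apache Ignite distributes cache entries across folds internally.

[source, java]
----
import java.util.ArrayList;
import java.util.List;

public class KFoldSketch {
    /** Assigns each of n sample indices to one of k folds in round-robin order (hypothetical helper). */
    public static List<List<Integer>> folds(int n, int k) {
        if (k < 2 || k > n)
            throw new IllegalArgumentException("k must be between 2 and n");

        List<List<Integer>> res = new ArrayList<>();
        for (int f = 0; f < k; f++)
            res.add(new ArrayList<>());

        // Sample i goes to fold (i mod k).
        for (int i = 0; i < n; i++)
            res.get(i % k).add(i);

        return res;
    }

    public static void main(String[] args) {
        List<List<Integer>> folds = folds(10, 4);

        // In each round, one fold is held out for testing and the remaining folds are used for training.
        for (int f = 0; f < folds.size(); f++)
            System.out.println("Round " + f + ": test on samples " + folds.get(f));
    }
}
----

With 4 folds, the model is trained and evaluated 4 times, and each sample is used for testing exactly once.
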
== Cross-Validation (without Pipeline API usage)
[source, java]
----
// Create classification trainer
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

// Create cross-validation instance
CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
    = new CrossValidation<>();

// Set up the cross-validation process
scoreCalculator
    .withIgnite(ignite)
    .withUpstreamCache(trainingSet)
    .withTrainer(trainer)
    .withMetric(MetricName.ACCURACY)
    .withPreprocessor(vectorizer)
    .withAmountOfFolds(4)
    .isRunningOnPipeline(false);

// Calculate accuracy for each fold
double[] accuracyScores = scoreCalculator.scoreByFolds();
----
In this example, we specify the trainer and the metric, pass the common training arguments (the Ignite instance, the upstream cache, and the vectorizer), and finally set the number of folds. The `scoreByFolds()` method returns an array containing one value of the chosen metric per fold, so here `accuracyScores` holds 4 accuracy values.
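
The per-fold scores are usually condensed into a single estimate and a measure of spread. The following sketch shows one way to do that in plain Java. The `FoldScoreSummary` class and the sample scores are hypothetical and are not part of the Apache Ignite API; in practice the array would come from `scoreByFolds()`.

[source, java]
----
import java.util.Arrays;

public class FoldScoreSummary {
    /** Arithmetic mean of the per-fold scores. */
    public static double mean(double[] scores) {
        return Arrays.stream(scores).average().orElse(Double.NaN);
    }

    /** Population standard deviation of the per-fold scores. */
    public static double stdDev(double[] scores) {
        double m = mean(scores);
        double variance = Arrays.stream(scores)
            .map(s -> (s - m) * (s - m))
            .average()
            .orElse(Double.NaN);

        return Math.sqrt(variance);
    }

    public static void main(String[] args) {
        // Hypothetical values standing in for the result of scoreByFolds().
        double[] accuracyScores = {0.90, 0.85, 0.95, 0.90};

        System.out.printf("Mean accuracy: %.3f, std. dev.: %.3f%n",
            mean(accuracyScores), stdDev(accuracyScores));
    }
}
----

A large spread between folds suggests that the score depends strongly on how the data was split, so the mean alone may be a poor estimate of model quality.
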
== Cross-Validation (with Pipeline API usage)
Define the pipeline and pass it to the `CrossValidation` instance to run cross-validation on the pipeline.

CAUTION: The Pipeline API is experimental and may change in future releases.
[source, java]
----
// Create classification trainer
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

// Build the pipeline: vectorizer, preprocessing stages, and the final trainer
Pipeline<Integer, Vector, Integer, Double> pipeline
    = new Pipeline<Integer, Vector, Integer, Double>()
        .addVectorizer(vectorizer)
        .addPreprocessingTrainer(new ImputerTrainer<Integer, Vector>())
        .addPreprocessingTrainer(new MinMaxScalerTrainer<Integer, Vector>())
        .addTrainer(trainer);

// Create cross-validation instance
CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
    = new CrossValidation<>();

// Set up the cross-validation process
scoreCalculator
    .withIgnite(ignite)
    .withUpstreamCache(trainingSet)
    .withPipeline(pipeline)
    .withMetric(MetricName.ACCURACY)
    .withPreprocessor(vectorizer)
    .withAmountOfFolds(4)
    .isRunningOnPipeline(true);

// Calculate accuracy for each fold
double[] accuracyScores = scoreCalculator.scoreByFolds();
----
== Example
To see how cross-validation can be used in practice, try https://github.com/apache/ignite/blob/master/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java[this example] and see https://github.com/apache/ignite/blob/master/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_pipeline.java[step 8 of the ML Tutorial]. Both are available on GitHub and are delivered with every Apache Ignite distribution.