blob: cea87c56d242a19d7a8af3f4bcdeb6ceaf580144 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.pipeline;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.DatasetTrainer;
/**
* A simple pipeline, which acts as a global trainer which produce a Pipeline Model.
* A Pipeline consists of a sequence of stages, each of which is either a Preprocessing Stage or a Trainer.
* When {@code fit()} method is called, the stages are executed in order.
*
* @param <K> Type of a key in {@code upstream} data.
* @param <V> Type of a value in {@code upstream} data.
*/
public class Pipeline<K, V, C extends Serializable, L> implements Serializable {
/** Final Feature extractor. */
private Preprocessor<K, V> finalPreprocessor;
/** Vectorizer. */
private Vectorizer<K, V, C, L> vectorizer;
/** Preprocessor stages. */
private List<PreprocessingTrainer> preprocessingTrainers = new ArrayList<>();
/** Final trainer stage. */
private DatasetTrainer finalStage;
/** Learning environment builder. */
private LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder();
/**
* @param vectorizer Vectorizer.
*/
public Pipeline<K, V, C, L> addVectorizer(Vectorizer<K, V, C, L> vectorizer) {
this.vectorizer = vectorizer;
return this;
}
/**
* Adds a preprocessor.
*
* @param preprocessingTrainer The parameter value.
* @return The updated Pipeline.
*/
public Pipeline<K, V, C, L> addPreprocessingTrainer(PreprocessingTrainer preprocessingTrainer) {
preprocessingTrainers.add(preprocessingTrainer);
return this;
}
/**
* Adds a trainer.
*
* @param trainer The parameter value.
* @return The updated Pipeline.
*/
public Pipeline<K, V, C, L> addTrainer(DatasetTrainer trainer) {
this.finalStage = trainer;
return this;
}
/**
* Returns trainer.
*/
public DatasetTrainer getTrainer() {
return finalStage;
}
/**
* Fits the pipeline to the input cache.
*
* @param ignite Ignite instance.
* @param cache Ignite cache with {@code upstream} data.
* @return The fitted model based on chain of preprocessors and final trainer.
*/
public PipelineMdl<K, V> fit(Ignite ignite, IgniteCache<K, V> cache) {
DatasetBuilder datasetBuilder = new CacheBasedDatasetBuilder<>(ignite, cache);
return fit(datasetBuilder);
}
/**
* Set learning environment builder.
*
* @param envBuilder Learning environment builder.
*/
public void setEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
this.envBuilder = envBuilder;
}
/**
* Fits the pipeline to the input mock data.
*
* @param data Data.
* @param parts Number of partitions.
* @return The fitted model based on chain of preprocessors and final trainer.
*/
public PipelineMdl<K, V> fit(Map<K, V> data, int parts) {
DatasetBuilder datasetBuilder = new LocalDatasetBuilder<>(data, parts);
return fit(datasetBuilder);
}
/** Fits the pipeline to the input dataset builder. */
public PipelineMdl<K, V> fit(DatasetBuilder datasetBuilder) {
if (finalStage == null)
throw new IllegalStateException("The Pipeline should be finished with the Training Stage.");
// Reload for new fit
finalPreprocessor = vectorizer;
preprocessingTrainers.forEach(e -> {
finalPreprocessor = e.fit(
envBuilder,
datasetBuilder,
finalPreprocessor
);
});
LearningEnvironment env = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
env.initDeployingContext(finalPreprocessor);
IgniteModel<Vector, Double> internalMdl = finalStage
.fit(
datasetBuilder,
finalPreprocessor,
env
);
return new PipelineMdl<K, V>()
.withPreprocessor(finalPreprocessor)
.withInternalMdl(internalMdl);
}
/**
* Returns the final preprocessor for evaluation needs.
*/
public Preprocessor<K, V> getFinalPreprocessor() {
return finalPreprocessor;
}
}