blob: 9eefb74cf04360d704fc2aff6bb35efb1d5e69f7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.composition;
import java.util.Collections;
import java.util.List;
import com.fasterxml.jackson.annotation.JsonIgnore;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.util.ModelTrace;
/**
 * Ensemble model: holds a list of base models together with a {@link PredictionsAggregator} and
 * produces a single prediction by running every base model and aggregating their answers.
 *
 * @param <M> Type of the base models held by the composition.
 */
public class ModelsComposition<M extends IgniteModel<Vector, Double>>
    implements IgniteModel<Vector, Double>, Exportable<ModelsCompositionFormat>, DeployableObject {
    /**
     * Strategy that turns the array of per-model predictions into one resulting value.
     */
    protected PredictionsAggregator predictionsAggregator;

    /**
     * Base models of the composition.
     */
    protected List<M> models;

    /**
     * Creates a composition from the given base models and aggregation strategy.
     * The passed list is wrapped into an unmodifiable view (it is not copied).
     *
     * @param models Basic models.
     * @param predictionsAggregator Predictions aggregator.
     */
    public ModelsComposition(List<M> models, PredictionsAggregator predictionsAggregator) {
        List<M> readOnlyModels = Collections.unmodifiableList(models);

        this.predictionsAggregator = predictionsAggregator;
        this.models = readOnlyModels;
    }

    /**
     * Creates a composition with neither models nor aggregator set.
     * NOTE(review): presumably required by a serialization framework - confirm against usages.
     */
    public ModelsComposition() {
    }

    /**
     * Asks every base model for a prediction on the given features (in list order) and passes
     * the collected values to the predictions aggregator.
     *
     * @param features Features vector.
     * @return Aggregated estimation.
     */
    @Override public Double predict(Vector features) {
        double[] answers = new double[models.size()];

        int pos = 0;

        for (M mdl : models)
            answers[pos++] = mdl.predict(features);

        return predictionsAggregator.apply(answers);
    }

    /**
     * Gets the predictions aggregator.
     *
     * @return Predictions aggregator used by this composition.
     */
    public PredictionsAggregator getPredictionsAggregator() {
        return predictionsAggregator;
    }

    /**
     * Gets the base models.
     *
     * @return List of models held by this composition.
     */
    public List<M> getModels() {
        return models;
    }

    /**
     * Packs the models and the aggregator into a {@link ModelsCompositionFormat} and hands it
     * to the exporter.
     *
     * @param exporter Exporter that performs the actual save.
     * @param path Destination passed through to the exporter.
     * @param <P> Type of the destination descriptor.
     */
    @Override public <P> void saveModel(Exporter<ModelsCompositionFormat, P> exporter, P path) {
        exporter.save(new ModelsCompositionFormat(models, predictionsAggregator), path);
    }

    /**
     * Non-pretty textual representation; same as {@code toString(false)}.
     *
     * @return String describing the composition.
     */
    @Override public String toString() {
        return toString(false);
    }

    /**
     * Builds a textual representation listing the aggregator and the base models.
     *
     * @param pretty Flag forwarded to the trace builder and to the aggregator's own formatting.
     * @return String describing the composition.
     */
    @Override public String toString(boolean pretty) {
        return ModelTrace.builder("Models composition", pretty)
            .addField("aggregator", predictionsAggregator.toString(pretty))
            .addField("models", models)
            .toString();
    }

    /**
     * Reports the objects this composition depends on for deployment: a single-element list
     * holding the predictions aggregator. Excluded from JSON output via {@link JsonIgnore}.
     *
     * @return Singleton list with the predictions aggregator.
     */
    @JsonIgnore
    @Override public List<Object> getDependencies() {
        List<Object> deps = Collections.singletonList(predictionsAggregator);

        return deps;
    }
}