blob: 1f1c36b101de1452d010790f2f285878b816afda [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.naivebayes.compound;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.fasterxml.jackson.databind.SerializationFeature;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.inference.json.JSONModel;
import org.apache.ignite.ml.inference.json.JSONModelMixIn;
import org.apache.ignite.ml.inference.json.JSONWritable;
import org.apache.ignite.ml.inference.json.JacksonHelper;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
/**
* A compound Naive Bayes model which uses a composition of{@code GaussianNaiveBayesModel} and {@code
* DiscreteNaiveBayesModel}.
*/
public class CompoundNaiveBayesModel implements IgniteModel<Vector, Double>, Exportable<CompoundNaiveBayesModel>,
JSONWritable, DeployableObject {
/** Serial version uid. */
private static final long serialVersionUID = -5045925321135798960L;
/** Prior probabilities of each class. */
private double[] priorProbabilities;
/** Labels. */
private double[] labels;
/** Gaussian Bayes model. */
private GaussianNaiveBayesModel gaussianModel;
/** Feature ids which should be skipped in Gaussian model. */
private Collection<Integer> gaussianFeatureIdsToSkip = Collections.emptyList();
/** Discrete Bayes model. */
private DiscreteNaiveBayesModel discreteModel;
/** Feature ids which should be skipped in Discrete model. */
private Collection<Integer> discreteFeatureIdsToSkip = Collections.emptyList();
/** */
public CompoundNaiveBayesModel() {
}
/** {@inheritDoc} */
@Override public <P> void saveModel(Exporter<CompoundNaiveBayesModel, P> exporter, P path) {
exporter.save(this, path);
}
/** {@inheritDoc} */
@Override public Double predict(Vector vector) {
double[] probabilityPowers = new double[priorProbabilities.length];
for (int i = 0; i < priorProbabilities.length; i++)
probabilityPowers[i] = Math.log(priorProbabilities[i]);
if (discreteModel != null)
probabilityPowers = sum(probabilityPowers, discreteModel.probabilityPowers(skipFeatures(vector, discreteFeatureIdsToSkip)));
if (gaussianModel != null)
probabilityPowers = sum(probabilityPowers, gaussianModel.probabilityPowers(skipFeatures(vector, gaussianFeatureIdsToSkip)));
int maxLbIdx = 0;
for (int i = 0; i < probabilityPowers.length; i++) {
if (probabilityPowers[i] > probabilityPowers[maxLbIdx])
maxLbIdx = i;
}
return labels[maxLbIdx];
}
/** Returns a gaussian model. */
public GaussianNaiveBayesModel getGaussianModel() {
return gaussianModel;
}
/** Returns a discrete model. */
public DiscreteNaiveBayesModel getDiscreteModel() {
return discreteModel;
}
/** */
public double[] getPriorProbabilities() {
return priorProbabilities;
}
/** */
public double[] getLabels() {
return labels;
}
/** */
public Collection<Integer> getGaussianFeatureIdsToSkip() {
return gaussianFeatureIdsToSkip;
}
/** */
public Collection<Integer> getDiscreteFeatureIdsToSkip() {
return discreteFeatureIdsToSkip;
}
/** Sets prior probabilities. */
public CompoundNaiveBayesModel withPriorProbabilities(double[] priorProbabilities) {
this.priorProbabilities = priorProbabilities.clone();
return this;
}
/** Sets labels. */
public CompoundNaiveBayesModel withLabels(double[] labels) {
this.labels = labels.clone();
return this;
}
/** Sets a gaussian model. */
public CompoundNaiveBayesModel withGaussianModel(GaussianNaiveBayesModel gaussianModel) {
this.gaussianModel = gaussianModel;
return this;
}
/** Sets a discrete model. */
public CompoundNaiveBayesModel withDiscreteModel(DiscreteNaiveBayesModel discreteModel) {
this.discreteModel = discreteModel;
return this;
}
/** Sets feature ids to skip in Gaussian Bayes. */
public CompoundNaiveBayesModel withGaussianFeatureIdsToSkip(Collection<Integer> gaussianFeatureIdsToSkip) {
this.gaussianFeatureIdsToSkip = gaussianFeatureIdsToSkip;
return this;
}
/** Sets feature ids to skip in discrete Bayes. */
public CompoundNaiveBayesModel withDiscreteFeatureIdsToSkip(Collection<Integer> discreteFeatureIdsToSkip) {
this.discreteFeatureIdsToSkip = discreteFeatureIdsToSkip;
return this;
}
/** Returns index by index sum of two arrays. */
private static double[] sum(double[] arr1, double[] arr2) {
assert arr1.length == arr2.length;
double[] result = new double[arr1.length];
for (int i = 0; i < arr1.length; i++)
result[i] = arr1[i] + arr2[i];
return result;
}
/** Returns a new (shorter) vector without features provided in {@param featureIdsToSkip}. */
private static Vector skipFeatures(Vector vector, Collection<Integer> featureIdsToSkip) {
int newSize = vector.size() - featureIdsToSkip.size();
double[] newFeaturesValues = new double[newSize];
int index = 0;
for (int j = 0; j < vector.size(); j++) {
if (featureIdsToSkip.contains(j)) continue;
newFeaturesValues[index] = vector.get(j);
++index;
}
return VectorUtils.of(newFeaturesValues);
}
/** {@inheritDoc} */
@JsonIgnore
@Override public List<Object> getDependencies() {
return Arrays.asList(discreteModel, gaussianModel);
}
/** {@inheritDoc} */
@Override public void toJSON(Path path) {
ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
mapper.addMixIn(CompoundNaiveBayesModel.class, JSONModelMixIn.class);
ObjectWriter writer = mapper
.writerFor(CompoundNaiveBayesModel.class)
.withAttribute("formatVersion", JSONModel.JSON_MODEL_FORMAT_VERSION)
.withAttribute("timestamp", System.currentTimeMillis())
.withAttribute("uid", "dt_" + UUID.randomUUID().toString())
.withAttribute("modelClass", CompoundNaiveBayesModel.class.getSimpleName());
try {
File file = new File(path.toAbsolutePath().toString());
writer.writeValue(file, this);
} catch (IOException e) {
e.printStackTrace();
}
}
/** Loads CompoundNaiveBayesModel from JSON file. */
public static CompoundNaiveBayesModel fromJSON(Path path) {
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
CompoundNaiveBayesModel mdl;
try {
JacksonHelper.readAndValidateBasicJsonModelProperties(path, mapper, CompoundNaiveBayesModel.class.getSimpleName());
mdl = mapper.readValue(new File(path.toAbsolutePath().toString()), CompoundNaiveBayesModel.class);
return mdl;
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
}