blob: d0a647093f0e0685c6d2cd7a2689aa7371c047dc [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.naivebayes.gaussian;
import java.util.Collections;
import java.util.List;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.BayesModel;
/**
 * Gaussian naive Bayes model. For an input vector {@code y} every class {@code C_k, k in [0..K]} is scored with
 * {@code log p(C_k) + sum_j log p(y_j | C_k)}, where {@code p(y_j | C_k)} is a normal density defined by the
 * per-class mean and variance of the j-th feature. The prediction is the label of the class with the highest score.
 */
public class GaussianNaiveBayesModel implements BayesModel<GaussianNaiveBayesModel, Vector, Double>, DeployableObject {
    /** Serial version uid. */
    private static final long serialVersionUID = -127386523291350345L;

    /** Means of features for all classes. kth row contains means for labels[k] class. */
    private final double[][] means;

    /** Variances of features for all classes. kth row contains variances for labels[k] class. */
    private final double[][] variances;

    /** Prior probabilities of each class. kth element is the prior of labels[k] class. */
    private final double[] classProbabilities;

    /** Labels. */
    private final double[] labels;

    /** Feature sum, squared sum and count per label. */
    private final GaussianNaiveBayesSumsHolder sumsHolder;

    /**
     * Creates a model from already computed statistics. The arrays are stored by reference, they are not copied.
     *
     * @param means Means of features for all classes.
     * @param variances Variances of features for all classes.
     * @param classProbabilities Probabilities for all classes.
     * @param labels Labels.
     * @param sumsHolder Feature sum, squared sum and count sum per label. This data is used for future model updating.
     */
    public GaussianNaiveBayesModel(double[][] means, double[][] variances,
        double[] classProbabilities, double[] labels, GaussianNaiveBayesSumsHolder sumsHolder) {
        this.means = means;
        this.variances = variances;
        this.classProbabilities = classProbabilities;
        this.labels = labels;
        this.sumsHolder = sumsHolder;
    }

    /** {@inheritDoc} */
    @Override public <P> void saveModel(Exporter<GaussianNaiveBayesModel, P> exporter, P path) {
        exporter.save(this, path);
    }

    /**
     * Returns the label of the class to which the input most probably belongs.
     *
     * @param vector Feature vector to classify.
     * @return Element of {@code labels} whose class has the largest sum of log-prior and feature log-likelihoods.
     *      On ties the class with the smallest index wins.
     */
    @Override public Double predict(Vector vector) {
        double[] scores = probabilityPowers(vector);

        int max = 0;

        for (int i = 0; i < scores.length; i++) {
            // Add the log-prior to the log-likelihood to get the (unnormalised) log-posterior.
            scores[i] += Math.log(classProbabilities[i]);

            if (scores[i] > scores[max])
                max = i;
        }

        return labels[max];
    }

    /**
     * Computes, for every class, the sum over all features of the logarithm of the Gaussian density of the feature
     * value. Class priors are not included.
     *
     * <p>The log-density is evaluated analytically rather than as {@code log(exp(...))}: the density itself
     * underflows to zero for values far from the mean, which previously made such a feature contribute {@code 0}
     * and thereby favoured the least plausible class.
     *
     * @param vector Feature vector.
     * @return Array with one log-likelihood per class, indexed like {@code labels}.
     */
    @Override public double[] probabilityPowers(Vector vector) {
        int featuresCnt = vector.size();

        double[] probabilityPowers = new double[classProbabilities.length];

        for (int i = 0; i < classProbabilities.length; i++) {
            double logLikelihood = 0.;

            for (int j = 0; j < featuresCnt; j++)
                logLikelihood += logGauss(vector.get(j), means[i][j], variances[i][j]);

            probabilityPowers[i] = logLikelihood;
        }

        return probabilityPowers;
    }

    /** A getter for means. */
    public double[][] getMeans() {
        return means;
    }

    /** A getter for variances. */
    public double[][] getVariances() {
        return variances;
    }

    /** A getter for classProbabilities. */
    public double[] getClassProbabilities() {
        return classProbabilities;
    }

    /** A getter for labels. */
    public double[] getLabels() {
        return labels;
    }

    /** A getter for sumsHolder. */
    public GaussianNaiveBayesSumsHolder getSumsHolder() {
        return sumsHolder;
    }

    /**
     * Natural logarithm of the Gauss distribution density:
     * {@code -(x - mean)^2 / (2 * variance) - 0.5 * ln(2 * PI * variance)}.
     *
     * @param x Point at which the density is evaluated.
     * @param mean Mean of the distribution.
     * @param variance Variance of the distribution.
     * @return Log-density, or {@code 0} when the formula does not yield a finite number (zero, negative, NaN or
     *      infinite variance, non-finite arguments, arithmetic overflow), so that such a feature does not
     *      influence the class score.
     */
    private static double logGauss(double x, double mean, double variance) {
        double diff = x - mean;

        double res = -(diff * diff) / (2. * variance) - 0.5 * Math.log(2. * Math.PI * variance);

        // Degenerate inputs produce NaN or an infinity here; treat the feature as uninformative in that case.
        return Double.isNaN(res) || Double.isInfinite(res) ? 0. : res;
    }

    /**
     * {@inheritDoc}
     *
     * <p>This model has no dependencies, so the returned list is empty.
     */
    @Override public List<Object> getDependencies() {
        return Collections.emptyList();
    }
}