blob: 805b6eead5d66639c335e7c85d4286b122a5f400 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.s4.model;
import org.apache.s4.util.MatrixOps;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
/**
 * A multivariate Gaussian model with parameters mean (mu) and variance (sigma
 * squared). Only diagonal covariance matrices are supported.
 *
 * <p>
 * When constructed in training mode the model accumulates the sufficient
 * statistics sum(x), sum(x^2) and the (possibly fractional) sample count via
 * the {@code update} methods; {@link #estimate()} turns them into a new mean
 * and variance.
 *
 * <p>
 * Scoring and updating share one internal scratch vector, so calls on the same
 * instance from several threads at once would interfere with each other.
 *
 * @author Leo Neumeyer
 *
 */
public class GaussianModel extends Model {

    /**
     * Floor applied to every variance component, and the default variance of
     * a freshly created model. The value is written as a float literal, so
     * the actual double is the float nearest to 0.01 (about 0.0099999998); it
     * is kept that way so the public constant value does not change.
     */
    public static final double SMALL_VARIANCE = 0.01f;

    /**
     * {@link #estimate()} only re-estimates the parameters when the
     * accumulated sample count is strictly greater than this value. Float
     * literal kept for the same reason as {@link #SMALL_VARIANCE}.
     */
    public static final double minNumSamples = 0.01f;

    private final boolean isDiagonal = true; // Full covariance not yet supported.

    /* Sufficient statistics; only allocated when constructed with train == true. */
    private D1Matrix64F sumx;
    private D1Matrix64F sumxsq;
    private double numSamples;

    /* Scratch vector reused by logProb(), update() and the constant update. */
    private final D1Matrix64F tmpArray;

    private D1Matrix64F mean;
    private D1Matrix64F variance; // ==> sigma squared
    private final int numElements;

    private final double const1; // -(N/2)log(2PI) Depends only on numElements.
    private double const2; // const1 - sum(log sigma_i) Also depends on
                           // variance.

    /**
     * Create a model with a zero mean and every variance component set to
     * {@link #SMALL_VARIANCE}.
     *
     * @param numElements
     *            the model dimension.
     * @param train
     *            allocate training arrays when true.
     */
    public GaussianModel(int numElements, boolean train) {
        this(numElements, null, null, train);
    }

    /**
     * Initialize model, no allocation of training arrays.
     *
     * @param numElements
     *            the model dimension.
     * @param mean
     *            model parameter.
     * @param variance
     *            model parameter.
     */
    public GaussianModel(int numElements, D1Matrix64F mean, D1Matrix64F variance) {
        this(numElements, mean, variance, false);
    }

    /**
     * Initialize model, optionally allocating the training arrays.
     *
     * <p>
     * The supplied matrices are stored by reference, not copied, and
     * {@link #estimate()} writes its results into them in place.
     *
     * @param numElements
     *            the model dimension.
     * @param mean
     *            model parameter; when null a zero vector is allocated.
     * @param variance
     *            model parameter; when null a vector filled with
     *            {@link #SMALL_VARIANCE} is allocated.
     * @param train
     *            allocate training arrays when true.
     */
    public GaussianModel(int numElements, D1Matrix64F mean,
            D1Matrix64F variance, boolean train) {
        super();
        this.numElements = numElements;
        tmpArray = new DenseMatrix64F(numElements, 1);

        this.mean = (mean != null) ? mean : new DenseMatrix64F(numElements, 1);

        if (variance != null) {
            this.variance = variance;
        } else {
            this.variance = new DenseMatrix64F(numElements, 1);
            CommonOps.set(this.variance, SMALL_VARIANCE);
        }

        /*
         * Computed entirely in double precision. The previous float cast
         * truncated this term to roughly seven significant digits, which then
         * leaked into every log probability.
         */
        const1 = -0.5 * numElements * Math.log(2.0 * Math.PI);
        updateLogConstant();

        /* Allocate arrays needed for estimation. */
        setTrain(train);
        if (train) {
            sumx = new DenseMatrix64F(numElements, 1);
            sumxsq = new DenseMatrix64F(numElements, 1);
            clearStatistics();
        }
    }

    /**
     * @return a new model with the same dimension and training flag as this
     *         one, using the default parameters of
     *         {@link #GaussianModel(int, boolean)}.
     */
    public Model create() {
        return new GaussianModel(numElements, isTrain);
    }

    /**
     * Computes const2 - 1/2 * sum_i((mean_i - obs_i)^2 / variance_i) using the
     * shared scratch vector.
     *
     * @param obs
     *            the observed data vector.
     * @return the log probability.
     */
    public double logProb(D1Matrix64F obs) {
        CommonOps.sub(mean, obs, tmpArray);
        MatrixOps.elementSquare(tmpArray);
        CommonOps.elementDiv(tmpArray, variance);
        return const2 - CommonOps.elementSum(tmpArray) / 2.0;
    }

    /**
     * @param obs
     *            the observed data vector.
     * @return the log probability.
     */
    public double logProb(float[] obs) {
        return logProb(MatrixOps.floatArrayToMatrix(obs));
    }

    /**
     * @param obs
     *            the observed data vector.
     * @return the log probability.
     */
    public double logProb(double[] obs) {
        return logProb(MatrixOps.doubleArrayToMatrix(obs));
    }

    /**
     * Evaluate using double array.
     *
     * NOTE(review): the original comment on this method pointed at
     * org.apache.s4.model.Model#evaluate(double[]); confirm which Model
     * method this corresponds to.
     *
     * @param obs
     *            the observed data vector.
     * @return the probability.
     */
    public double prob(double[] obs) {
        return prob(MatrixOps.doubleArrayToMatrix(obs));
    }

    /**
     * Evaluate using float array.
     *
     * @param obs
     *            the observed data vector.
     * @return the probability.
     */
    public double prob(float[] obs) {
        return prob(MatrixOps.floatArrayToMatrix(obs));
    }

    /**
     * @param obs
     *            the observed data vector.
     * @return the probability, i.e. the exponential of
     *         {@link #logProb(D1Matrix64F)}.
     */
    public double prob(D1Matrix64F obs) {
        return Math.exp(logProb(obs));
    }

    /**
     * Update using double array.
     *
     * @param obs
     *            the observed data vector.
     * @see org.apache.s4.model.Model#update(double[])
     */
    public void update(double[] obs) {
        update(MatrixOps.doubleArrayToMatrix(obs));
    }

    /**
     * Update using float array.
     *
     * @param obs
     *            the observed data vector.
     */
    public void update(float[] obs) {
        update(MatrixOps.floatArrayToMatrix(obs));
    }

    /**
     * Update sufficient statistics with one observation of unit weight. Does
     * nothing unless the model is in training mode.
     *
     * @param obs
     *            the observed data vector.
     */
    public void update(D1Matrix64F obs) {
        if (!isTrain()) {
            return;
        }
        /* Update sufficient statistics. */
        CommonOps.add(obs, sumx, sumx);
        MatrixOps.elementSquare(obs, tmpArray);
        CommonOps.add(tmpArray, sumxsq, sumxsq);
        numSamples++;
    }

    /**
     * Update sufficient statistics with a weighted observation. Does nothing
     * unless the model is in training mode.
     *
     * @param obs
     *            the observed data vector.
     * @param weight
     *            the weight assigned to this observation; it scales the
     *            contribution to both sums and is added to the sample count.
     */
    public void update(D1Matrix64F obs, double weight) {
        if (!isTrain()) {
            return;
        }
        /* sumx += weight * obs */
        CommonOps.scale(weight, obs, tmpArray);
        CommonOps.add(tmpArray, sumx, sumx);
        /* sumxsq += weight * obs^2 */
        MatrixOps.elementSquare(obs, tmpArray);
        CommonOps.scale(weight, tmpArray);
        CommonOps.add(tmpArray, sumxsq, sumxsq);
        numSamples += weight;
    }

    /**
     * Re-estimate mean and variance from the accumulated statistics and
     * refresh the log Gaussian constant.
     *
     * <p>
     * When the sample count does not exceed {@link #minNumSamples} the
     * parameters are reset instead: mean to zero and variance to
     * {@link #SMALL_VARIANCE}. Results are written into the current mean and
     * variance matrices in place.
     *
     * @see org.apache.s4.model.Model#estimate()
     */
    public void estimate() {
        if (numSamples > minNumSamples) {
            final double invN = 1.0 / numSamples;

            /* Estimate the mean. */
            CommonOps.scale(invN, sumx, mean);

            /*
             * Estimate the variance. sigma_sq = 1/n (sumxsq - 1/n sumx^2) or
             * 1/n sumxsq - mean^2. The variance matrix itself holds the
             * intermediate 1/n sumxsq term.
             */
            MatrixOps.elementSquare(mean, tmpArray);
            CommonOps.scale(invN, sumxsq, variance);
            CommonOps.sub(variance, tmpArray, variance);
            MatrixOps.elementFloor(SMALL_VARIANCE, variance, variance);
        } else {
            /* Not enough training sample. */
            CommonOps.set(variance, SMALL_VARIANCE);
            CommonOps.set(mean, 0.0);
        }
        updateLogConstant();
    }

    /**
     * Reset the accumulated sums and the sample count to zero. Does nothing
     * unless the model is in training mode.
     *
     * @see org.apache.s4.model.Model#clearStatistics()
     */
    public void clearStatistics() {
        if (isTrain()) {
            CommonOps.set(sumx, 0.0);
            CommonOps.set(sumxsq, 0.0);
            numSamples = 0;
        }
    }

    /** @return the mean (mu) of the Gaussian density. */
    public double[] getMean() {
        return new DenseMatrix64F(mean).getData();
    }

    /** @return the variance (sigma squared) of the Gaussian density. */
    public double[] getVariance() {
        return new DenseMatrix64F(variance).getData();
    }

    /**
     * Replace the mean. The matrix is stored by reference.
     *
     * @param mean
     *            the new mean (mu).
     */
    public void setMean(D1Matrix64F mean) {
        this.mean = mean;
    }

    /**
     * Replace the variance and refresh the log Gaussian constant that depends
     * on it. The matrix is stored by reference.
     *
     * @param variance
     *            the new variance (sigma squared).
     */
    public void setVariance(D1Matrix64F variance) {
        this.variance = variance;
        updateLogConstant();
    }

    /** @return the standard deviation (sigma) of the Gaussian density. */
    public double[] getStd() {
        DenseMatrix64F std = new DenseMatrix64F(numElements, 1);
        MatrixOps.elementSquareRoot(variance, std);
        return std.getData();
    }

    /**
     * @return the sum of the observed vectors.
     * @throws IllegalStateException
     *             if the training arrays were never allocated.
     */
    public double[] getSumX() {
        return copyStatistics(sumx, "sumx");
    }

    /**
     * @return the sum of the observed vectors squared.
     * @throws IllegalStateException
     *             if the training arrays were never allocated.
     */
    public double[] getSumXSq() {
        return copyStatistics(sumxsq, "sumxsq");
    }

    /** @return the number of observations. */
    public double getNumSamples() {
        return numSamples;
    }

    /** @return the dimensionality. */
    public int getNumElements() {
        return numElements;
    }

    /** @return true if the covariance matrix is diagonal. */
    public boolean isDiagonal() {
        return isDiagonal;
    }

    /**
     * @return the value of the parameters and sufficient statistics of this
     *         model in a printable format. The statistics print as
     *         {@code null} when the training arrays were never allocated.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Gaussian Model\n");
        sb.append("const: ").append(const2).append('\n');
        sb.append("num samp: ").append(numSamples).append('\n');
        sb.append("mean: ").append(mean).append('\n');
        sb.append("var: ").append(variance).append('\n');
        /* append(Object) is null-safe, unlike calling toString() directly. */
        sb.append("sumx: ").append(sumx).append('\n');
        sb.append("sumxsq: ").append(sumxsq).append('\n');
        return sb.toString();
    }

    /** Recompute const2 = const1 - 1/2 * sum(log variance_i) from the current variance. */
    private void updateLogConstant() {
        MatrixOps.elementLog(variance, tmpArray);
        const2 = const1 - CommonOps.elementSum(tmpArray) / 2.0;
    }

    /** Copy a sufficient-statistics vector, failing clearly when it was never allocated. */
    private static double[] copyStatistics(D1Matrix64F stats, String name) {
        if (stats == null) {
            throw new IllegalStateException(name
                    + " is not available: the model was created without training arrays");
        }
        return new DenseMatrix64F(stats).getData();
    }
}