#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Prediction function for a Gaussian Mixture Model (GMM).
# Computes posterior probabilities and cluster labels for new instances, given the weights,
# means, and precision Cholesky factors of a fitted model.
#
# INPUT:
# ------------------------------------------------------------------------------------------
# X                    Dataset input to predict the labels from
# weight               Weights of the learned model: a row vector whose kth entry is the
#                      mixture weight (prior probability) of the kth component
# mu                   Fitted cluster means, one row per component
# precisions_cholesky  Cholesky factors of the fitted precision matrices, one per mixture component
# model                "VVV": unequal variance (full), each component has its own general covariance matrix
#                      "EEE": equal variance (tied), all components share the same general covariance matrix
#                      "VVI": spherical, unequal volume (diag), each component has its own diagonal
#                      covariance matrix
#                      "VII": spherical, equal volume (spherical), each component has its own single variance
# ------------------------------------------------------------------------------------------
#
# OUTPUT:
# ---------------------------------------------------------------------------------------------------
# labels        Cluster labels predicted by the Gaussian mixture model for the input dataset X
# predict_prob  Posterior probability of each component for every row of the input dataset X
# ---------------------------------------------------------------------------------------------------
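#
# Example (a minimal usage sketch, assuming the parameters W, MU and PC were obtained from a
# previous fit, e.g. via the gmm builtin, with a matching covariance model):
#
#   [labels, predict_prob] = gmmPredict(X=Xtest, weight=W, mu=MU,
#     precisions_cholesky=PC, model="VVV")
#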
m_gmmPredict = function(Matrix[Double] X, Matrix[Double] weight,
Matrix[Double] mu, Matrix[Double] precisions_cholesky, String model = "VVV")
return(Matrix[Double] labels, Matrix[Double] predict_prob)
{
# compute the posterior probabilities for new instances
weighted_log_prob = compute_log_gaussian_prob(X, mu, precisions_cholesky, model) + log(weight)
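  # normalize in log-space: the per-row logSumExp is log p(x_i); subtracting it gives log responsibilities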
log_prob_norm = logSumExp(weighted_log_prob, "rows")
log_resp = weighted_log_prob - log_prob_norm
predict_prob = exp(log_resp)
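  # hard label: index of the component with the highest weighted log-probability per row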
labels = rowIndexMax(weighted_log_prob)
}
compute_log_gaussian_prob = function(Matrix[Double] X, Matrix[Double] mu,
Matrix[Double] prec_chol, String model)
  return(Matrix[Double] es_log_prob) # nrow(X) x n_components
{
n_components = nrow(mu)
d = ncol(X)
if(model == "VVV") {
log_prob = matrix(0, nrow(X), n_components) # log probabilities
log_det_chol = matrix(0, 1, n_components) # log determinant
i = 1
for(k in 1:n_components) {
prec = prec_chol[i:(k*ncol(X)),]
y = X %*% prec - mu[k,] %*% prec
log_prob[, k] = rowSums(y*y)
# compute log_det_cholesky
log_det = sum(log(diag(t(prec))))
log_det_chol[1,k] = log_det
i = i + ncol(X)
}
}
else if(model == "EEE") {
log_prob = matrix(0, nrow(X), n_components)
log_det_chol = as.matrix(sum(log(diag(prec_chol))))
prec = prec_chol
for(k in 1:n_components) {
y = X %*% prec - mu[k,] %*% prec
log_prob[, k] = rowSums(y*y)
}
}
else if(model == "VVI") {
log_det_chol = t(rowSums(log(prec_chol)))
prec = prec_chol
precisions = prec^2
bc_matrix = matrix(1,nrow(X), nrow(mu))
log_prob = (bc_matrix*t(rowSums(mu^2 * precisions))
- 2 * (X %*% t(mu * precisions)) + X^2 %*% t(precisions))
}
else if (model == "VII") {
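    # spherical covariance: prec_chol holds a single precision Cholesky value (1/sigma) per component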
log_det_chol = t(d * log(prec_chol))
prec = prec_chol
    precisions = prec^2
bc_matrix = matrix(1,nrow(X), nrow(mu))
log_prob = (bc_matrix * t(rowSums(mu^2) * precisions)
- 2 * X %*% t(mu * precisions) + rowSums(X*X) %*% t(precisions) )
}
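  # for the tied model the log-determinant is a scalar; replicate it to one entry per component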
if(ncol(log_det_chol) == 1)
log_det_chol = matrix(1, 1, ncol(log_prob)) * log_det_chol
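  # Gaussian log-density: -0.5 * (d * log(2*pi) + squared Mahalanobis distance) + log|L_k|,
  # where L_k is the precision Cholesky factor of component k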
es_log_prob = -.5 * (ncol(X) * log(2 * pi) + log_prob) + log_det_chol
}