blob: e054902e7ad07d5b7a150b7ac293753e81bc78cf [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# ------------------------------------------
# Gaussian Mixture Model Predict
# ------------------------------------------
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
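# X                    Matrix[Double]  ---  Instances to be assigned to clusters (one row per instance)
# weight               Matrix[Double]  ---  Mixture weights of the learned model (one per component)
# mu                   Matrix[Double]  ---  Fitted cluster means (one row per component)
# precisions_cholesky  Matrix[Double]  ---  Cholesky factors of the fitted precision matrices of the mixture components
# model                String          ---  Covariance type of the fitted model: "VVV", "EEE", "VVI" or "VII"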
# ---------------------------------------------------------------------------------------------
# OUTPUT:
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# predict         Matrix[Double]  ---  Predicted cluster label (1-based component index) for each instance
# posterior_prob  Matrix[Double]  ---  Posterior probability of each instance belonging to each component
# ---------------------------------------------------------------------------------------------
# Compute posterior cluster probabilities for new instances, given the mixture weights, means and precision Cholesky factors of a fitted GMM
m_gmmPredict = function(Matrix[Double] X, Matrix[Double] weight,
    Matrix[Double] mu, Matrix[Double] precisions_cholesky, String model)
  return(Matrix[Double] predict, Matrix[Double] posterior_prob)
{
  # Joint log-likelihood per (instance, component): log N(x | component) + log(weight)
  log_joint = compute_log_gaussian_prob(X, mu, precisions_cholesky, model) + log(weight)

  # Hard assignment: the component with the largest joint log-likelihood.
  # Row-wise normalization does not change the argmax, so the unnormalized values are used.
  predict = rowIndexMax(log_joint)

  # Soft assignment: normalize in log space with a row-wise log-sum-exp (numerically
  # stable), then exponentiate to obtain per-row probabilities.
  posterior_prob = exp(log_joint - logSumExp(log_joint, "rows"))
}
compute_log_gaussian_prob = function(Matrix[Double] X, Matrix[Double] mu,
    Matrix[Double] prec_chol, String model)
  return(Matrix[Double] es_log_prob ) # nrow(X) * n_components
{
  # Log density of every instance (row of X) under every component (row of mu),
  # given the Cholesky factors P_k of the component precisions:
  #   log N(x | k) = -0.5 * (d * log(2*pi) + ||x P_k - mu_k P_k||^2) + sum(log(diag(P_k)))
  # Layout of prec_chol expected by each covariance type (as indexed below):
  #   "VVV": n_components stacked d x d blocks; rows (k-1)*d+1 .. k*d belong to component k
  #   "EEE": a single d x d factor shared by all components
  #   "VVI": n_components x d, one row of per-feature factors per component
  #   "VII": n_components x 1, one scalar factor per component
  # Any other model string is rejected with an error.
  n_components = nrow(mu)
  d = ncol(X)
  if(model == "VVV") {
    log_prob = matrix(0, nrow(X), n_components)   # squared distances in precision space
    log_det_chol = matrix(0, 1, n_components)     # sum of log-diagonals of each factor
    for(k in 1:n_components) {
      # d x d factor of component k
      prec = prec_chol[((k-1)*d+1):(k*d),]
      y = X %*% prec - mu[k,] %*% prec
      log_prob[, k] = rowSums(y*y)
      log_det_chol[1, k] = sum(log(diag(prec)))
    }
  }
  else if(model == "EEE") {
    log_prob = matrix(0, nrow(X), n_components)
    # shared factor: one log-determinant term, replicated for every component
    log_det_chol = matrix(1, 1, n_components) * sum(log(diag(prec_chol)))
    # projections do not depend on k, so compute them once outside the loop
    X_prec = X %*% prec_chol
    mu_prec = mu %*% prec_chol
    for(k in 1:n_components) {
      y = X_prec - mu_prec[k,]
      log_prob[, k] = rowSums(y*y)
    }
  }
  else if(model == "VVI") {
    precisions = prec_chol^2
    log_det_chol = t(rowSums(log(prec_chol)))
    # ones matrix used to expand the per-component row vector to nrow(X) rows
    bc_matrix = matrix(1, nrow(X), n_components)
    log_prob = (bc_matrix * t(rowSums(mu^2 * precisions))
      - 2 * (X %*% t(mu * precisions)) + X^2 %*% t(precisions))
  }
  else if(model == "VII") {
    precisions = prec_chol^2
    # a scalar factor per component contributes d identical log-diagonal entries
    log_det_chol = t(d * log(prec_chol))
    bc_matrix = matrix(1, nrow(X), n_components)
    log_prob = (bc_matrix * t(rowSums(mu^2) * precisions)
      - 2 * X %*% t(mu * precisions) + rowSums(X*X) %*% t(precisions))
  }
  else {
    # without this, log_prob / log_det_chol would be undefined below
    stop("compute_log_gaussian_prob: unsupported model '" + model + "'; expected one of VVV, EEE, VVI, VII")
  }
  es_log_prob = -.5 * (d * log(2 * pi) + log_prob) + log_det_chol
}