#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Prediction function for a Gaussian Mixture Model (GMM).
# Computes posterior probabilities for new instances given the weights, means, and
# precision Cholesky factors of a fitted model.
#
# INPUT:
# ------------------------------------------------------------------------------------------
# X                    Dataset input to predict the labels from
# weight               Mixture weights of the fitted model:
#                      a row vector whose k-th entry is the weight of the k-th component
# mu                   Fitted cluster means, one row per component
# precisions_cholesky  Cholesky factors of the fitted precision matrices
#                      (layout depends on the chosen covariance model)
# model                Covariance model of the fitted GMM:
#                      "VVV": unequal variance (full), each component has its own general covariance matrix
#                      "EEE": equal variance (tied), all components share the same general covariance matrix
#                      "VVI": unequal volume (diag), each component has its own diagonal covariance matrix
#                      "VII": spherical, equal volume (spherical), each component has its own single variance
# ------------------------------------------------------------------------------------------
#
# OUTPUT:
# ---------------------------------------------------------------------------------------------------
# labels          Cluster labels predicted by the Gaussian mixture model for each row of X
# predict_prob    Posterior probability of each mixture component for each row of X
# ---------------------------------------------------------------------------------------------------
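#
# Usage sketch (illustrative; the weights, means, and precision Cholesky factors are
# assumed to come from a previously fitted GMM, e.g., the outputs of the gmm builtin):
#
#   [labels, predict_prob] = gmmPredict(X=Xtest, weight=weight, mu=mu,
#     precisions_cholesky=precisions_cholesky, model="VVV")
#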
m_gmmPredict = function(Matrix[Double] X, Matrix[Double] weight,
  Matrix[Double] mu, Matrix[Double] precisions_cholesky, String model = "VVV")
  return(Matrix[Double] labels, Matrix[Double] predict_prob)
{
  # per-component log densities of the new instances, weighted by the (log) mixture weights
  weighted_log_prob = compute_log_gaussian_prob(X, mu, precisions_cholesky, model) + log(weight)
  # normalize in log-space to obtain the log responsibilities (posterior probabilities)
  log_prob_norm = logSumExp(weighted_log_prob, "rows")
  log_resp = weighted_log_prob - log_prob_norm
  predict_prob = exp(log_resp)
  # predicted label = index of the component with the highest weighted log probability
  labels = rowIndexMax(weighted_log_prob)
}

compute_log_gaussian_prob = function(Matrix[Double] X, Matrix[Double] mu,
  Matrix[Double] prec_chol, String model)
  return(Matrix[Double] es_log_prob) # nrow(X) x n_components
{
  n_components = nrow(mu)
  d = ncol(X)
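  # Each branch below computes, for every row x of X and every component k with mean
  # mu_k and precision Cholesky factor P_k, the squared term ||(x - mu_k) %*% P_k||^2
  # (log_prob) and the log-determinant log|P_k| (log_det_chol); these are combined at
  # the end into
  #   log N(x | mu_k, Sigma_k) = -1/2 * (d * log(2*pi) + ||(x - mu_k) %*% P_k||^2) + log|P_k|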

  if(model == "VVV") {
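    # full covariance (VVV): prec_chol vertically stacks one d x d Cholesky factor per
    # component, so component k occupies rows ((k-1)*d+1) : (k*d)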
    log_prob = matrix(0, nrow(X), n_components) # log probabilities
    log_det_chol = matrix(0, 1, n_components) # log determinant
    i = 1
    for(k in 1:n_components) {
      prec = prec_chol[i:(k*ncol(X)),]
      y = X %*% prec - mu[k,] %*% prec
      log_prob[, k] = rowSums(y*y)
      # compute log_det_cholesky
      log_det = sum(log(diag(t(prec))))
      log_det_chol[1,k] = log_det
      i = i + ncol(X)
    }
  }
  else if(model == "EEE") {
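    # tied covariance (EEE): all components share a single d x d precision Cholesky factor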
    log_prob = matrix(0, nrow(X), n_components)
    log_det_chol = as.matrix(sum(log(diag(prec_chol))))
    prec = prec_chol
    for(k in 1:n_components) {
      y = X %*% prec - mu[k,] %*% prec
      log_prob[, k] = rowSums(y*y)
    }
  }
  else if(model == "VVI") {
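    # diagonal covariance (VVI): prec_chol holds one row per component with the d
    # diagonal entries of that component's precision Cholesky factor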
    log_det_chol = t(rowSums(log(prec_chol)))
    prec = prec_chol
    precisions = prec^2
    bc_matrix = matrix(1, nrow(X), nrow(mu))
    log_prob = (bc_matrix * t(rowSums(mu^2 * precisions))
      - 2 * (X %*% t(mu * precisions)) + X^2 %*% t(precisions))
  }
  else if(model == "VII") {
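    # spherical covariance (VII): prec_chol holds a single precision value per component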
    log_det_chol = t(d * log(prec_chol))
    prec = prec_chol
    precisions = prec^2
    bc_matrix = matrix(1, nrow(X), nrow(mu))
    log_prob = (bc_matrix * t(rowSums(mu^2) * precisions)
      - 2 * X %*% t(mu * precisions) + rowSums(X*X) %*% t(precisions))
  }
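  # if only a single log-determinant was produced (tied model), replicate it for all components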
  if(ncol(log_det_chol) == 1)
    log_det_chol = matrix(1, 1, ncol(log_prob)) * log_det_chol

  es_log_prob = -.5 * (ncol(X) * log(2 * pi) + log_prob) + log_det_chol
}