| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| # ------------------------------------------ |
| # Gaussian Mixture Model Predict |
| # ------------------------------------------ |
| |
| # INPUT PARAMETERS: |
| # --------------------------------------------------------------------------------------------- |
| # NAME TYPE DEFAULT MEANING |
| # --------------------------------------------------------------------------------------------- |
| # X Double --- Matrix X (instances to be clustered) |
| # weight Double --- Weight of learned model |
| # mu Double --- fitted clusters mean |
| # precisions_cholesky Double --- fitted precision matrix for each mixture |
| # model String --- type of the fitted model, one of "VVV", "EEE", "VVI" or "VII" |
| # --------------------------------------------------------------------------------------------- |
| |
| # OUTPUT: |
| # --------------------------------------------------------------------------------------------- |
| # NAME TYPE DEFAULT MEANING |
| # --------------------------------------------------------------------------------------------- |
| # predict Double --- predicted cluster labels |
| # posterior_prob Double --- posterior probability of each instance belonging to each cluster |
| # --------------------------------------------------------------------------------------------- |
| |
| # compute posterior probabilities for new instances given the variance and mean of fitted data |
| |
| m_gmmPredict = function(Matrix[Double] X, Matrix[Double] weight, |
|   Matrix[Double] mu, Matrix[Double] precisions_cholesky, String model) |
|   return(Matrix[Double] predict, Matrix[Double] posterior_prob) |
| { |
|   # Score every row of X against every fitted component: the component log density |
|   # returned by compute_log_gaussian_prob plus the log of the component weight. |
|   log_joint = compute_log_gaussian_prob(X, mu, precisions_cholesky, model) + log(weight) |
| |
|   # The predicted label of a row is the index of its best-scoring component. |
|   predict = rowIndexMax(log_joint) |
| |
|   # Normalise each row in log space (subtract the row-wise log-sum-exp) and |
|   # exponentiate to obtain the posterior probabilities. |
|   posterior_prob = exp(log_joint - logSumExp(log_joint, "rows")) |
| } |
| |
| compute_log_gaussian_prob = function(Matrix[Double] X, Matrix[Double] mu, |
|   Matrix[Double] prec_chol, String model) |
|   return(Matrix[Double] es_log_prob ) # nrow(X) * n_components |
| { |
|   # Computes, for every row of X and every mixture component, the Gaussian log density |
|   # given the component mean and the Cholesky factor of its precision. |
|   # |
|   #   X          instances, one per row |
|   #   mu         component means, one row per component |
|   #   prec_chol  Cholesky factors of the precisions; the layout read depends on model: |
|   #                "VVV"  one block of ncol(X) rows per component, stacked by rows |
|   #                "EEE"  a single matrix shared by all components |
|   #                "VVI"  combined elementwise with mu (assumed same shape as mu - TODO confirm) |
|   #                "VII"  assumed one value per component in a column - TODO confirm |
|   #   model      "VVV", "EEE", "VVI" or "VII"; any other value stops with an error |
|   # |
|   # Returns es_log_prob with nrow(X) rows and one column per component. |
|   n_components = nrow(mu) |
|   d = ncol(X) |
| |
|   if(model == "VVV") { |
|     log_prob = matrix(0, nrow(X), n_components) # squared norms of the transformed deviations |
|     log_det_chol = matrix(0, 1, n_components) # log determinant of each factor |
|     for(k in 1:n_components) { |
|       # rows (k-1)*d+1 .. k*d hold the factor of component k |
|       prec = prec_chol[((k-1)*d+1):(k*d),] |
|       y = X %*% prec - mu[k,] %*% prec |
|       log_prob[, k] = rowSums(y*y) |
|       # the sum of the log-diagonal is the log determinant of the factor |
|       log_det_chol[1,k] = sum(log(diag(prec))) |
|     } |
|   } |
|   else if(model == "EEE") { |
|     log_prob = matrix(0, nrow(X), n_components) |
|     log_det_chol = as.matrix(sum(log(diag(prec_chol)))) |
|     # the factor is shared, so the projection of X is the same for every component |
|     X_proj = X %*% prec_chol |
|     for(k in 1:n_components) { |
|       y = X_proj - mu[k,] %*% prec_chol |
|       log_prob[, k] = rowSums(y*y) |
|     } |
|   } |
|   else if(model == "VVI") { |
|     log_det_chol = t(rowSums(log(prec_chol))) |
|     precisions = prec_chol^2 |
|     bc_matrix = matrix(1,nrow(X), nrow(mu)) |
|     # expanded square: mu^2*p - 2*x*mu*p + x^2*p, summed over the features |
|     log_prob = (bc_matrix*t(rowSums(mu^2 * precisions)) |
|       - 2 * (X %*% t(mu * precisions)) + X^2 %*% t(precisions)) |
|   } |
|   else if (model == "VII") { |
|     log_det_chol = t(d * log(prec_chol)) |
|     precisions = prec_chol^2 |
|     bc_matrix = matrix(1,nrow(X), nrow(mu)) |
|     log_prob = (bc_matrix * t(rowSums(mu^2) * precisions) |
|       - 2 * X %*% t(mu * precisions) + rowSums(X*X) %*% t(precisions) ) |
|   } |
|   else { |
|     # fail early with a clear message instead of hitting unassigned variables below |
|     stop("gmmPredict: unsupported model '" + model + "', expected one of VVV, EEE, VVI, VII") |
|   } |
| |
|   # a single log determinant (shared factor) is replicated to one entry per component |
|   if(ncol(log_det_chol) == 1) |
|     log_det_chol = matrix(1, 1, ncol(log_prob)) * log_det_chol |
| |
|   es_log_prob = -.5 * (d * log(2 * pi) + log_prob) + log_det_chol |
| } |