[SYSTEMDS-2735] Builtin function gmmPredict for clustering instances
Closes #1108.
diff --git a/scripts/builtin/gmmPredict.dml b/scripts/builtin/gmmPredict.dml
new file mode 100644
index 0000000..e054902
--- /dev/null
+++ b/scripts/builtin/gmmPredict.dml
@@ -0,0 +1,108 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# ------------------------------------------
+# Gaussian Mixture Model Predict
+# ------------------------------------------
+
+# INPUT PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME                 TYPE            DEFAULT  MEANING
+# ---------------------------------------------------------------------------------------------
+# X                    Matrix[Double]  ---      Matrix X (instances to be clustered, one per row)
+# weight               Matrix[Double]  ---      Mixture weights of the fitted model
+# mu                   Matrix[Double]  ---      Means of the fitted clusters (one row per cluster)
+# precisions_cholesky  Matrix[Double]  ---      Fitted precision matrix for each mixture
+# model                String          ---      Model type used for fitting: "VVV", "EEE", "VVI" or "VII"
+# ---------------------------------------------------------------------------------------------
+
+# OUTPUT:
+# ---------------------------------------------------------------------------------------------
+# NAME                 TYPE            DEFAULT  MEANING
+# ---------------------------------------------------------------------------------------------
+# predict              Matrix[Double]  ---      Predicted cluster label for each instance
+# posterior_prob       Matrix[Double]  ---      Posterior probability of each instance for each cluster
+# ---------------------------------------------------------------------------------------------
+
+# compute posterior probabilities for new instances given the variance and mean of fitted data
+
+m_gmmPredict = function(Matrix[Double] X, Matrix[Double] weight,
+    Matrix[Double] mu, Matrix[Double] precisions_cholesky, String model)
+  return(Matrix[Double] predict, Matrix[Double] posterior_prob)
+{
+  # log density of every row of X under every component, shifted by the log mixture weights
+  joint_ll = compute_log_gaussian_prob(X, mu, precisions_cholesky, model) + log(weight)
+  # label = column (component) holding the largest weighted log-probability of each row
+  predict = rowIndexMax(joint_ll)
+  # normalise in log space with a row-wise logSumExp, then exponentiate to get posteriors
+  posterior_prob = exp(joint_ll - logSumExp(joint_ll, "rows"))
+}
+
+compute_log_gaussian_prob = function(Matrix[Double] X, Matrix[Double] mu,
+    Matrix[Double] prec_chol, String model)
+  return(Matrix[Double] es_log_prob ) # nrow(X) * n_components
+{
+  # Log Gaussian density of every row of X under every component (one row of mu each).
+  # model selects how prec_chol is read: "VVV" ncol(X) stacked rows per component, "EEE" one
+  # shared block, "VVI"/"VII" one row per component; any other value stops with an error.
+  n_components = nrow(mu)
+  d = ncol(X)
+  if(model == "VVV") {
+    log_prob = matrix(0, nrow(X), n_components) # row sums of squared transformed deviations
+    log_det_chol = matrix(0, 1, n_components) # log determinant per component
+    for(k in 1:n_components) {
+      prec = prec_chol[((k-1)*d+1):(k*d),] # k-th block of d rows
+      y = X %*% prec - mu[k,] %*% prec
+      log_prob[, k] = rowSums(y*y)
+      log_det_chol[1,k] = sum(log(diag(t(prec))))
+    }
+  }
+  else if(model == "EEE") {
+    log_prob = matrix(0, nrow(X), n_components)
+    log_det_chol = as.matrix(sum(log(diag(prec_chol))))
+    XP = X %*% prec_chol # identical for all components, so computed once
+    for(k in 1:n_components) {
+      y = XP - mu[k,] %*% prec_chol
+      log_prob[, k] = rowSums(y*y)
+    }
+  }
+  else if(model == "VVI") {
+    log_det_chol = t(rowSums(log(prec_chol)))
+    precisions = prec_chol^2
+    bc_matrix = matrix(1,nrow(X), nrow(mu))
+    log_prob = (bc_matrix*t(rowSums(mu^2 * precisions))
+      - 2 * (X %*% t(mu * precisions)) + X^2 %*% t(precisions))
+  }
+  else if (model == "VII") {
+    log_det_chol = t(d * log(prec_chol))
+    precisions = prec_chol^2
+    bc_matrix = matrix(1,nrow(X), nrow(mu))
+    log_prob = (bc_matrix * t(rowSums(mu^2) * precisions)
+      - 2 * X %*% t(mu * precisions) + rowSums(X*X) %*% t(precisions) )
+  }
+  else {
+    # without this branch log_prob and log_det_chol would never be assigned for such a model
+    stop("gmmPredict: unsupported model '" + model + "', expected VVV, EEE, VVI or VII")
+  }
+  # a single (shared) log-determinant is replicated to one value per component
+  if(ncol(log_det_chol) == 1)
+    log_det_chol = matrix(1, 1, ncol(log_prob)) * log_det_chol
+  es_log_prob = -.5 * (d * log(2 * pi) + log_prob) + log_det_chol
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 5010c21..a1f372c 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -116,6 +116,7 @@
GET_PERMUTATIONS("getPermutations", true),
GLM("glm", true),
GMM("gmm", true),
+ GMM_PREDICT("gmmPredict", true),
GNMF("gnmf", true),
GRID_SEARCH("gridSearch", true),
HYPERBAND("hyperband", true),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMPredictTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMPredictTest.java
new file mode 100644
index 0000000..f0d2cc7
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMPredictTest.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Fits a Gaussian mixture model on part of the iris data with the gmm builtin and checks,
+ * via GMM_Predict.dml, that gmmPredict places the held-out rows into clusters of the
+ * expected sizes.
+ */
+public class BuiltinGMMPredictTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "GMM_Predict";
+	private final static String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinGMMPredictTest.class.getSimpleName() + "/";
+
+	// Largest value of the script's result that still passes. The script writes a mean of
+	// absolute differences (never negative), so [0, 3] is the same window that the former
+	// assertEquals(1, value, 2) accepted.
+	private final static double MAX_ERROR = 3;
+	// tolerance values handed to gmm as its tol argument
+	private final static double tol = 1e-3;
+	private final static double tol2 = 1e-5;
+
+	// csv file read by the script as its first argument
+	private final static String DATASET = SCRIPT_DIR + "functions/transform/input/iris/iris.csv";
+
+	// registers the single script configuration shared by all tests of this class
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+	}
+
+	/** VVI model, random initialisation. */
+	@Test
+	public void testGMMMPredictCP1() {
+		runGMMPredictTest(3, "VVI", "random", 10,
+			0.000000001, tol,42,true, LopProperties.ExecType.CP);
+	}
+
+	/** VII model, random initialisation. */
+	@Test
+	public void testGMMMPredictCP2() {
+		runGMMPredictTest(3, "VII", "random", 50,
+			0.000001, tol2,42,true, LopProperties.ExecType.CP);
+	}
+
+	/** VVV model, k-means initialisation. */
+	@Test
+	public void testGMMMPredictCPKmean1() {
+		runGMMPredictTest(3, "VVV", "kmeans", 10,
+			0.0000001, tol,42,true, LopProperties.ExecType.CP);
+	}
+
+	/** EEE model, k-means initialisation. */
+	@Test
+	public void testGMMMPredictCPKmean2() {
+		runGMMPredictTest(3, "EEE", "kmeans", 150,
+			0.000001, tol,42,true, LopProperties.ExecType.CP);
+	}
+
+	/** VII model, k-means initialisation. */
+	@Test
+	public void testGMMMPredictCPKmean3() {
+		runGMMPredictTest(3, "VII", "kmeans", 50,
+			0.000001, tol2,42,true, LopProperties.ExecType.CP);
+	}
+
+	// Spark variants are disabled. Drafted configurations (3 mixtures, seed 42 each);
+	// the columns mirror the parameters of runGMMPredictTest:
+	//   name                 model  init    iter  reg   tol   rewrite  exec
+	//   testGMMM1Spark       VVV    random  10    1e-7  tol   true     SPARK
+	//   testGMMM2Spark       EEE    random  50    1e-7  tol   true     CP
+	//   testGMMMS3Spark      VVI    random  100   1e-6  tol   true     CP
+	//   testGMMM4Spark       VII    random  100   1e-6  tol1  true     CP
+	//   testGMMM1KmeanSpark  VVV    kmeans  100   1e-6  tol2  false    SPARK
+	//   testGMMM2KmeanSpark  EEE    kmeans  50    1e-8  tol1  false    SPARK
+	//   testGMMM3KmeanSpark  VVI    kmeans  100   1e-6  tol   false    SPARK
+	//   testGMMM4KmeanSpark  VII    kmeans  100   1e-6  tol   false    SPARK
+	// NOTE(review): tol1 is not declared in this class, and three drafts request CP
+	// despite their Spark names; resolve both before turning these into real tests.
+
+	/**
+	 * Runs GMM_Predict.dml once and asserts on the error value it writes.
+	 * The previous execution mode and rewrite flag are restored afterwards.
+	 *
+	 * @param G_mixtures number of mixture components (script argument $2)
+	 * @param model      model name forwarded to gmm and gmmPredict ($3)
+	 * @param init_param initialisation strategy forwarded to gmm ($4)
+	 * @param iter       iter argument of gmm ($5)
+	 * @param reg        reg_covar argument of gmm ($6)
+	 * @param tol        tol argument of gmm ($7)
+	 * @param seed       seed argument of gmm ($8)
+	 * @param rewrite    value of OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION during the run
+	 * @param instType   execution type handed to setExecMode
+	 */
+	private void runGMMPredictTest(int G_mixtures, String model, String init_param, int iter,
+		double reg, double tol, int seed, boolean rewrite, LopProperties.ExecType instType) {
+		// switch execution mode and rewrite flag, remembering the old values
+		Types.ExecMode platformOld = setExecMode(instType);
+		boolean rewriteOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			String outFile = output("O");
+			// positional script arguments $1..$9
+			programArgs = new String[] {"-args", DATASET,
+				String.valueOf(G_mixtures), model, init_param, String.valueOf(iter), String.valueOf(reg),
+				String.valueOf(tol), String.valueOf(seed), outFile};
+
+			runTest(true, false, null, -1);
+			// the script reports an error measure (0 = all cluster sizes match), not an accuracy
+			double error = TestUtils.readDMLScalar(outFile);
+			Assert.assertTrue("cluster size error out of range: " + error,
+				error >= 0 && error <= MAX_ERROR);
+		}
+		finally {
+			// restore the global settings changed above
+			resetExecMode(platformOld);
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/GMM_Predict.dml b/src/test/scripts/functions/builtin/GMM_Predict.dml
new file mode 100644
index 0000000..283db7c
--- /dev/null
+++ b/src/test/scripts/functions/builtin/GMM_Predict.dml
@@ -0,0 +1,54 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Arguments: $1 csv path, $2 n_components, $3 model, $4 init_params, $5 iter,
+#            $6 reg_covar, $7 tol, $8 seed, $9 output path for the error scalar
+
+# read the csv as a frame, keep the column range 2 .. ncol - 1 and convert it to a matrix
+raw = read($1, data_type = "frame", format = "csv", header=TRUE)
+X = as.matrix(raw[ , 2:ncol(raw) - 1])
+
+# 17 rows (46-51, 96-101, 146-150) are held out for prediction; training uses the ranges below
+train = rbind(X[1:45,], X[52:95,], X[102:145,])
+test = rbind(X[46:51,], X[96:101,], X[146:150,])
+
+# fit the mixture model on the training rows
+[labels, prob, df, bic, mu, prec_chol, w] = gmm(X=train, n_components = $2,
+  model = $3, init_params = $4, iter = $5, reg_covar = $6, tol = $7, seed=$8, verbose=TRUE)
+
+# assign every held-out row to a component of the fitted model
+[pred, pp] = gmmPredict(test, w, mu, prec_chol, $3)
+
+# expected number of held-out rows per cluster
+expected = matrix("6 6 5", 3, 1)
+
+# 17 x 3 grid whose column j holds the value j; comparing it with pred marks each row's
+# cluster, and the column sums of those marks are the predicted cluster sizes
+ids = matrix(1, 17, 3) * t(seq(1,3))
+sizes = t(colSums(ids == pred))
+
+# both size vectors are sorted before they are compared
+sizes = order(target = sizes, by = 1, decreasing = FALSE, index.return=FALSE)
+target_sizes = order(target = expected, by = 1, decreasing = FALSE, index.return=FALSE)
+
+error = mean(abs(target_sizes - sizes)) # 0 when all cluster sizes match
+write(error, $9, format = "text")