[SYSTEMDS-2735] Builtin function gmmPredict for clustering instances

Closes #1108.
diff --git a/scripts/builtin/gmmPredict.dml b/scripts/builtin/gmmPredict.dml
new file mode 100644
index 0000000..e054902
--- /dev/null
+++ b/scripts/builtin/gmmPredict.dml
@@ -0,0 +1,108 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# ------------------------------------------
+# Gaussian Mixture Model Predict
+# ------------------------------------------
+
+# INPUT PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME                   TYPE     DEFAULT   MEANING
+# ---------------------------------------------------------------------------------------------
+# X                      Double   ---       Matrix X (instances to be clustered)
+# weight                 Double   ---       Weight of learned model
+# mu                     Double   ---       fitted clusters mean
+# precisions_cholesky    Double   ---       fitted precision matrix for each mixture
+# model                  String   ---       fitted model
+# ---------------------------------------------------------------------------------------------
+
+# OUTPUT:
+# ---------------------------------------------------------------------------------------------
+# NAME                   TYPE     DEFAULT   MEANING
+# ---------------------------------------------------------------------------------------------
+# predict                Double   ---       predicted cluster labels
+# posterior_prob         Double   ---       probabilities of belongingness
+# ---------------------------------------------------------------------------------------------
+
+# compute posterior probabilities for new instances given the variance and mean of fitted data
+
+m_gmmPredict = function(Matrix[Double] X, Matrix[Double] weight,
+  Matrix[Double] mu, Matrix[Double] precisions_cholesky, String model)
+  return(Matrix[Double] predict, Matrix[Double] posterior_prob)
+{
+  # compute the posterior probabilities for new instances
+  weighted_log_prob =  compute_log_gaussian_prob(X, mu, precisions_cholesky, model) + log(weight)
+  log_prob_norm = logSumExp(weighted_log_prob, "rows")
+  log_resp = weighted_log_prob - log_prob_norm
+  posterior_prob = exp(log_resp)
+  predict =  rowIndexMax(weighted_log_prob)
+}
+
+compute_log_gaussian_prob = function(Matrix[Double] X, Matrix[Double] mu,
+  Matrix[Double] prec_chol, String model)
+  return(Matrix[Double] es_log_prob ) # nrow(X) * n_components
+{
+  n_components = nrow(mu)
+  d = ncol(X)
+
+  if(model == "VVV") { 
+    log_prob = matrix(0, nrow(X), n_components) # log probabilities
+    log_det_chol = matrix(0, 1, n_components)   # log determinant 
+    i = 1
+    for(k in 1:n_components) {
+      prec = prec_chol[i:(k*ncol(X)),]
+      y = X %*% prec - mu[k,] %*% prec
+      log_prob[, k] = rowSums(y*y)
+      # compute log_det_cholesky
+      log_det = sum(log(diag(t(prec))))
+      log_det_chol[1,k] = log_det
+      i = i + ncol(X)
+    }
+  }
+  else if(model == "EEE") { 
+    log_prob = matrix(0, nrow(X), n_components) 
+    log_det_chol = as.matrix(sum(log(diag(prec_chol))))
+    prec = prec_chol
+    for(k in 1:n_components) {
+      y = X %*% prec - mu[k,] %*% prec
+      log_prob[, k] = rowSums(y*y) 
+    }
+  }
+  else if(model ==  "VVI") {
+    log_det_chol = t(rowSums(log(prec_chol)))
+    prec = prec_chol
+    precisions = prec^2
+    bc_matrix = matrix(1,nrow(X), nrow(mu))
+    log_prob = (bc_matrix*t(rowSums(mu^2 * precisions)) 
+      - 2 * (X %*% t(mu * precisions)) + X^2 %*% t(precisions))
+  }
+  else if (model == "VII") {
+    log_det_chol = t(d * log(prec_chol))
+    prec = prec_chol
+    precisions = prec^ 2
+    bc_matrix = matrix(1,nrow(X), nrow(mu))
+    log_prob = (bc_matrix * t(rowSums(mu^2) * precisions) 
+      - 2 * X %*% t(mu * precisions) + rowSums(X*X) %*% t(precisions) )
+  }
+  if(ncol(log_det_chol) == 1)
+    log_det_chol = matrix(1, 1, ncol(log_prob)) * log_det_chol
+
+  es_log_prob = -.5 * (ncol(X) * log(2 * pi) + log_prob) + log_det_chol
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 5010c21..a1f372c 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -116,6 +116,7 @@
 	GET_PERMUTATIONS("getPermutations", true),
 	GLM("glm", true),
 	GMM("gmm", true),
+	GMM_PREDICT("gmmPredict", true),
 	GNMF("gnmf", true),
 	GRID_SEARCH("gridSearch", true),
 	HYPERBAND("hyperband", true),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMPredictTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMPredictTest.java
new file mode 100644
index 0000000..f0d2cc7
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMPredictTest.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class BuiltinGMMPredictTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "GMM_Predict";
+	private final static String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinGMMPredictTest.class.getSimpleName() + "/";
+
+	private final static double eps = 2;
+	private final static double tol = 1e-3;
+	private final static double tol2 = 1e-5;
+
+	private final static String DATASET = SCRIPT_DIR + "functions/transform/input/iris/iris.csv";
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+	}
+
+	@Test
+	public void testGMMMPredictCP1() {
+		runGMMPredictTest(3, "VVI", "random", 10,
+			0.000000001, tol,42,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMMPredictCP2() {
+		runGMMPredictTest(3, "VII", "random", 50,
+			0.000001, tol2,42,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMMPredictCPKmean1() {
+		runGMMPredictTest(3, "VVV", "kmeans", 10,
+			0.0000001, tol,42,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMMPredictCPKmean2() {
+		runGMMPredictTest(3, "EEE", "kmeans", 150,
+			0.000001, tol,42,true, LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void testGMMMPredictCPKmean3() {
+		runGMMPredictTest(3, "VII", "kmeans", 50,
+			0.000001, tol2,42,true, LopProperties.ExecType.CP);
+	}
+
+//	@Test
+//	public void testGMMM1Spark() {
+//		runGMMPredictTest(3, "VVV", "random", 10,
+//		0.0000001, tol,42,true, LopProperties.ExecType.SPARK); }
+//
+//	@Test
+//	public void testGMMM2Spark() {
+//		runGMMPredictTest(3, "EEE", "random", 50,
+//			0.0000001, tol,42,true, LopProperties.ExecType.CP);
+//	}
+//
+//	@Test
+//	public void testGMMMS3Spark() {
+//		runGMMPredictTest(3, "VVI", "random", 100,
+//			0.000001, tol,42,true, LopProperties.ExecType.CP);
+//	}
+//
+//	@Test
+//	public void testGMMM4Spark() {
+//		runGMMPredictTest(3, "VII", "random", 100,
+//			0.000001, tol1,42,true, LopProperties.ExecType.CP);
+//	}
+//
+//	@Test
+//	public void testGMMM1KmeanSpark() {
+//		runGMMPredictTest(3, "VVV", "kmeans", 100,
+//			0.000001, tol2,42,false, LopProperties.ExecType.SPARK);
+//	}
+//
+//	@Test
+//	public void testGMMM2KmeanSpark() {
+//		runGMMPredictTest(3, "EEE", "kmeans", 50,
+//			0.00000001, tol1,42,false, LopProperties.ExecType.SPARK);
+//	}
+//
+//	@Test
+//	public void testGMMM3KmeanSpark() {
+//		runGMMPredictTest(3, "VVI", "kmeans", 100,
+//			0.000001, tol,42,false, LopProperties.ExecType.SPARK);
+//	}
+//
+//	@Test
+//	public void testGMMM4KmeanSpark() {
+//		runGMMPredictTest(3, "VII", "kmeans", 100,
+//			0.000001, tol,42,false, LopProperties.ExecType.SPARK);
+//	}
+
+	private void runGMMPredictTest(int G_mixtures, String model, String init_param, int iter,
+		double reg, double tol, int seed, boolean rewrite, LopProperties.ExecType instType) {
+
+		Types.ExecMode platformOld = setExecMode(instType);
+		boolean rewriteOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			String outFile = output("O");
+			System.out.println(outFile);
+			programArgs = new String[] {"-args", DATASET,
+				String.valueOf(G_mixtures), model, init_param, String.valueOf(iter), String.valueOf(reg),
+				String.valueOf(tol), String.valueOf(seed), outFile};
+
+			runTest(true, false, null, -1);
+			// compare results
+			double accuracy = TestUtils.readDMLScalar(outFile);
+			Assert.assertEquals(1, accuracy, eps);
+		}
+		finally {
+			resetExecMode(platformOld);
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/GMM_Predict.dml b/src/test/scripts/functions/builtin/GMM_Predict.dml
new file mode 100644
index 0000000..283db7c
--- /dev/null
+++ b/src/test/scripts/functions/builtin/GMM_Predict.dml
@@ -0,0 +1,54 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1, data_type = "frame", format = "csv", header=TRUE)
+X = X[ , 2:ncol(X) - 1]
+X = as.matrix(X)
+
+# divide in train and test set
+train = X[1:45,]
+train = rbind(train, X[52:95,])
+train = rbind(train, X[102:145,])
+
+test = X[46:51,]
+test = rbind(test, X[96:101,])
+test = rbind(test, X[146:150,])
+
+# train GMM
+[labels, prob, df, bic, mu, prec_chol, w] = gmm(X=train, n_components = $2,
+  model = $3, init_params = $4, iter = $5, reg_covar = $6, tol = $7, seed=$8, verbose=TRUE)
+ 
+# predict labels
+[pred, pp] = gmmPredict(test, w, mu, prec_chol, $3)  
+
+# expected clusters/predictions
+expected = matrix("6 6 5", 3, 1)
+
+resp = matrix(1, 17, 3) * t(seq(1,3))
+resp = resp == pred
+cluster = t(colSums(resp))
+
+cluster = order(target = cluster, by = 1, decreasing = FALSE, index.return=FALSE)
+correct_Predictions = order(target = expected, by = 1, decreasing = FALSE, index.return=FALSE)
+
+error = mean(abs(correct_Predictions - cluster))
+write(error, $9, format = "text")