[SYSTEMDS-2701] KMeans Predict builtin
This commit include tests in dml and the addition of
the predict in python.
diff --git a/scripts/builtin/kmeansPredict.dml b/scripts/builtin/kmeansPredict.dml
new file mode 100644
index 0000000..ab8722c
--- /dev/null
+++ b/scripts/builtin/kmeansPredict.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-----------------------------------------------------------------------------
+
+# Builtin function that does predictions based on a set of centroids provided.
+#
+# INPUT PARAMETERS:
+# ----------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+# ----------------------------------------------------------------------------
+# X Double --- The input Matrix to do KMeans on.
+# C Double --- The input Centroids to map X onto.
+#
+# RETURN VALUES
+# ----------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+# ----------------------------------------------------------------------------
+# Y String "Y.mtx" The mapping of records to centroids
+# ----------------------------------------------------------------------------
+
+
+m_kmeansPredict = function(Matrix[Double] X, Matrix[Double] C)
+ return (Matrix[Double] Y)
+{
+
+ D = -2 * (X %*% t(C)) + t(rowSums (C ^ 2));
+ P = (D <= rowMins (D));
+ aggr_P = t(cumsum (t(P)));
+ Y = rowSums (aggr_P == 0) + 1
+
+}
+
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index a9e023e..2c08ef4 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -119,6 +119,7 @@
ISNAN("is.nan", false),
ISINF("is.infinite", false),
KMEANS("kmeans", true),
+ KMEANSPREDICT("kmeansPredict", true),
L2SVM("l2svm", true),
LASSO("lasso", true),
LENGTH("length", false),
diff --git a/src/main/python/systemds/operator/algorithm.py b/src/main/python/systemds/operator/algorithm.py
index b30df47..7833030 100644
--- a/src/main/python/systemds/operator/algorithm.py
+++ b/src/main/python/systemds/operator/algorithm.py
@@ -92,6 +92,21 @@
params_dict.update(kwargs)
return OperationNode(x.sds_context, 'kmeans', named_input_nodes=params_dict, output_type=OutputType.LIST, number_of_outputs=2)
+def kmeansPredict(X: OperationNode, C: OperationNode) -> OperationNode:
+ """
+ Perform Kmeans Predict, note that the Ids returned are 1 indexed.
+
+ :param X: The matrix to classify.
+ :param Y: The Clusters to use for classification into.
+ :return: `OperationNode` containing a matrix of classifications of Id's of specific clusters in C.
+ """
+ X._check_matrix_op()
+ C._check_matrix_op()
+
+ params_dict = {'X' : X, 'C' : C}
+ return OperationNode(X.sds_context, 'kmeansPredict', named_input_nodes=params_dict, output_type=OutputType.MATRIX, shape=(1, X.shape[0]))
+
+
def pca(x: OperationNode, **kwargs: Dict[str, VALID_INPUT_TYPES]) -> OperationNode:
"""
diff --git a/src/main/python/tests/algorithms/test_kmeans.py b/src/main/python/tests/algorithms/test_kmeans.py
index ebf2264..426c40b 100644
--- a/src/main/python/tests/algorithms/test_kmeans.py
+++ b/src/main/python/tests/algorithms/test_kmeans.py
@@ -24,7 +24,7 @@
import numpy as np
from systemds.context import SystemDSContext
from systemds.matrix import Matrix
-from systemds.operator.algorithm import kmeans
+from systemds.operator.algorithm import kmeans, kmeansPredict
class TestKMeans(unittest.TestCase):
@@ -59,6 +59,29 @@
corners.add("nn")
self.assertTrue(len(corners) == 4)
+ def test_500x2(self):
+ """
+ This test is based on statistics, that if we run kmeans, on a normal distributed dataset, centered around 0
+ and use 4 clusters then they will be located in each one corner.
+ This test uses the prediction builtin.
+ """
+ features = self.generate_matrices_for_k_means((500, 2), seed=1304)
+ [c, _] = kmeans(features, k=4).compute()
+ C = Matrix(self.sds, c)
+ elm = Matrix(self.sds, np.array([[1, 1], [-1, 1], [-1, -1], [1, -1]]))
+ res = kmeansPredict(elm, C).compute()
+ corners = set()
+ for x in res:
+ if x == 1:
+ corners.add("pp")
+ elif x == 2:
+ corners.add("pn")
+ elif x == 3:
+ corners.add("np")
+ else:
+ corners.add("nn")
+ self.assertTrue(len(corners) == 4)
+
def test_invalid_input_1(self):
features = Matrix(self.sds, np.array([]))
with self.assertRaises(ValueError) as context:
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKmeansPredictTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKmeansPredictTest.java
new file mode 100644
index 0000000..bc1f074
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKmeansPredictTest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import java.util.HashMap;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class BuiltinKmeansPredictTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "kmeansPredict";
+ private final static String TEST_DIR = "functions/builtin/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinKmeansPredictTest.class.getSimpleName() + "/";
+ private final static double eps = 1e-10;
+ private final static int rows = 1320;
+ private final static int cols = 32;
+ private final static double spSparse = 0.3;
+ private final static double spDense = 0.7;
+ private final static double max_iter = 50;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"res"}));
+ }
+
+ @Test
+ public void testKMeansDenseBinSingleRewritesCP() {
+ runKMeansTest(false, 2, 1, true, LopProperties.ExecType.CP);
+ }
+
+ private void runKMeansTest(boolean sparse, int centroids, int runs, boolean rewrites,
+ LopProperties.ExecType instType) {
+ Types.ExecMode platformOld = setExecMode(instType);
+
+ boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+
+ try {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ double sparsity = sparse ? spSparse : spDense;
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-nvargs", "X=" + input("X"), "res=" + output("res"), "k=" + centroids,
+ "runs=" + runs, "eps=" + eps, "max_iter=" + max_iter};
+
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+ // generate actual datasets
+ double[][] X = getRandomMatrix(rows, cols, 0, 1, sparsity, 714);
+ writeInputMatrixWithMTD("X", X, true);
+
+ runTest(null);
+ HashMap<CellIndex, Double> res = readDMLScalarFromHDFS("res");
+ Assert.assertTrue(res.values().size() == 1);
+ Assert.assertEquals(res.values().toArray()[0] , 1.);
+ }
+ finally {
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+ OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true;
+ OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
+ }
+ }
+}
diff --git a/src/test/scripts/functions/builtin/kmeansPredict.dml b/src/test/scripts/functions/builtin/kmeansPredict.dml
new file mode 100644
index 0000000..a96fc28
--- /dev/null
+++ b/src/test/scripts/functions/builtin/kmeansPredict.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($X)
+
+[C, Y] = kmeans(X, $k, $runs, $max_iter, $eps, TRUE, 50)
+Y_1 = kmeansPredict(X, C)
+
+res = mean(Y==Y_1)
+write(res, $res)