[SYSTEMDS-2647] Python API MultiLogReg Algorithm
diff --git a/src/main/python/systemds/operator/algorithm.py b/src/main/python/systemds/operator/algorithm.py
index 77c59a5..a2a8e83 100644
--- a/src/main/python/systemds/operator/algorithm.py
+++ b/src/main/python/systemds/operator/algorithm.py
@@ -85,7 +85,8 @@
.format(s=x._np_array.shape))
if 'k' in kwargs.keys() and kwargs.get('k') < 1:
- raise ValueError("Invalid number of clusters in K means, number must be integer above 0")
+ raise ValueError(
+ "Invalid number of clusters in K-Means, number must be integer above 0")
params_dict = {'X': x}
params_dict.update(kwargs)
@@ -108,9 +109,10 @@
.format(s=x._np_array.shape))
if 'K' in kwargs.keys() and kwargs.get('K') < 1:
- raise ValueError("Invalid number of clusters in K means, number must be integer above 0")
+ raise ValueError(
+ "Invalid number of clusters in K means, number must be integer above 0")
- if 'scale'in kwargs.keys():
+ if 'scale' in kwargs.keys():
if kwargs.get('scale') == True:
kwargs.set('scale', "TRUE")
elif kwargs.get('scale' == False):
@@ -126,3 +128,34 @@
params_dict.update(kwargs)
return OperationNode(x.sds_context, 'pca', named_input_nodes=params_dict)
+
+def multiLogReg(x: DAGNode, y: DAGNode, **kwargs: Dict[str, VALID_INPUT_TYPES]) -> OperationNode:
+ """
+ Performs Multiclass Logistic Regression on the matrix input
+ using Trust Region method.
+
+ See: Trust Region Newton Method for Logistic Regression, Lin, Weng and Keerthi, JMLR 9 (2008) 627-650)
+
+ :param x: Input dataset to perform logstic regression on
+ :param y: Labels rowaligned with the input dataset
+ :param icpt: Intercept, default 2, Intercept presence, shifting and rescaling X columns:
+ 0 = no intercept, no shifting, no rescaling;
+ 1 = add intercept, but neither shift nor rescale X;
+ 2 = add intercept, shift & rescale X columns to mean = 0, variance = 1
+ :param tol: float tolerance for the algorithm.
+ :param reg: Regularization parameter (lambda = 1/C); intercept settings are not regularized.
+ :param maxi: Maximum outer iterations of the algorithm
+ :param maxii: Maximum inner iterations of the algorithm
+ """
+
+ x._check_matrix_op()
+ if x._np_array.size == 0:
+ raise ValueError("Found array with 0 feature(s) (shape={s}) while a minimum of 1 is required."
+ .format(s=x._np_array.shape))
+ if y._np_array.size == 0:
+ raise ValueError("Found array with 0 feature(s) (shape={s}) while a minimum of 1 is required."
+ .format(s=y._np_array.shape))
+
+ params_dict = {'X': x, 'Y': y}
+ params_dict.update(kwargs)
+ return OperationNode(x.sds_context, 'multiLogReg', named_input_nodes=params_dict)
diff --git a/src/main/python/tests/algorithms/test_multiLogReg.py b/src/main/python/tests/algorithms/test_multiLogReg.py
new file mode 100644
index 0000000..6c7e297
--- /dev/null
+++ b/src/main/python/tests/algorithms/test_multiLogReg.py
@@ -0,0 +1,70 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+
+import numpy as np
+from systemds.context import SystemDSContext
+from systemds.matrix import Matrix
+from systemds.operator.algorithm import multiLogReg
+
+
+class TestMultiLogReg(unittest.TestCase):
+
+ sds: SystemDSContext = None
+
+ @classmethod
+ def setUpClass(cls):
+ cls.sds = SystemDSContext()
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.sds.close()
+
+ def test_simple(self):
+ """
+ Test simple, if the log reg splits a dataset where everything over 1 is label 1 and under 1 is 0.
+ """
+ # Generate data
+ mu, sigma = 1, 0.1
+ X = np.reshape(np.random.normal(mu, sigma, 500), (2,250))
+ # All over 1 is true
+ f = lambda x: x[0] > 1
+ labels = f(X)
+ # Y labels as double
+ Y = np.array(labels, dtype=np.double)
+ # Transpose X to fit input format.
+ X = X.transpose()
+
+ # Call algorithm
+ bias = multiLogReg(Matrix(self.sds,X),Matrix(self.sds,Y)).compute()
+
+ # Calculate result.
+ res = np.reshape(np.dot(X, bias[:len(X[0])]) + bias[len(X[0])], (250))
+
+ f2 = lambda x: x > 0
+ accuracy = np.sum(labels == f2(res)) / 250 * 100
+
+ self.assertTrue(accuracy > 98)
+
+
+if __name__ == "__main__":
+ unittest.main(exit=False)