[SYSTEMDS-2722] Split Testing

Added tests (builtin and federated) for new split builtin.
diff --git a/scripts/builtin/split.dml b/scripts/builtin/split.dml
index b2d78d0..5e6f1c5 100644
--- a/scripts/builtin/split.dml
+++ b/scripts/builtin/split.dml
@@ -19,14 +19,15 @@
 #
 #-------------------------------------------------------------
 
-# Split input data X and y into contiguous or samples train/test sets
+# Split input data X and Y into contiguous or sampled train/test sets
 # ------------------------------------------------------------------------------
 # NAME   TYPE    DEFAULT  MEANING
 # ------------------------------------------------------------------------------
 # X      Matrix  ---      Input feature matrix
-# y      Matrix  ---      Input 
+# Y      Matrix  ---      Input Labels
 # f      Double  0.7      Train set fraction [0,1]
 # cont   Boolean TRUE     contiuous splits, otherwise sampled
+# seed   Integer -1       The seed to randomly select rows in sampled mode
 # ------------------------------------------------------------------------------
 # Xtrain Matrix  ---      Train split of feature matrix
 # Xtest  Matrix  ---      Test split of feature matrix
@@ -34,30 +35,30 @@
 # ytest  Matrix  ---      Test split of label matrix
 # ------------------------------------------------------------------------------
 
-m_split = function(Matrix[Double] X, Matrix[Double] y, Double f=0.7, Boolean cont=TRUE)
-  return (Matrix[Double] Xtrain, Matrix[Double] Xtest, Matrix[Double] ytrain, Matrix[Double] ytest) 
+m_split = function(Matrix[Double] X, Matrix[Double] Y, Double f=0.7, Boolean cont=TRUE, Integer seed=-1)
+  return (Matrix[Double] Xtrain, Matrix[Double] Xtest, Matrix[Double] Ytrain, Matrix[Double] Ytest) 
 {
   # basic sanity checks
   if( f <= 0 | f >= 1 )
-    print("Invalid train/test split configuration: f="+f);
-  if( nrow(X) != nrow(y) )
-    print("Mismatching number of rows X and y: "+nrow(X)+" "+nrow(y) )
+    stop("Invalid train/test split configuration: f="+f);  # only fractions strictly between 0 and 1 pass
+  if( nrow(X) != nrow(Y) )
+    stop("Mismatching number of rows X and Y: "+nrow(X)+" "+nrow(Y) )  # X and Y must have equal row counts
 
   # contiguous train/test splits
   if( cont ) {
     Xtrain = X[1:f*nrow(X),];
-    ytrain = y[1:f*nrow(X),];
+    Ytrain = Y[1:f*nrow(X),];  # same leading row range as Xtrain
     Xtest = X[(nrow(Xtrain)+1):nrow(X),];
-    ytest = y[(nrow(Xtrain)+1):nrow(X),];
+    Ytest = Y[(nrow(Xtrain)+1):nrow(X),];  # rows following the train range
   }
   # sampled train/test splits
   else {
-    I = rand(rows=nrow(X), cols=1) <= f;
+    I = rand(rows=nrow(X), cols=1, seed=seed) <= f;  # per-row train indicator, drawn with the given seed
     P1 = removeEmpty(target=diag(I), margin="rows", select=I);
     P2 = removeEmpty(target=diag(I==0), margin="rows", select=I==0);
     Xtrain = P1 %*% X;
-    ytrain = P1 %*% y;
+    Ytrain = P1 %*% Y;  # P1 is built from the rows flagged in I
     Xtest = P2 %*% X;
-    ytest = P2 %*% y;
+    Ytest = P2 %*% Y;  # P2 is built from the rows where I==0
   }
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSplitTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSplitTest.java
new file mode 100644
index 0000000..ecf6e80
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSplitTest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Runs the split.dml test script for the CP and SPARK execution types and expects it to print TRUE. */
+public class BuiltinSplitTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "split";
+	private final static String TEST_DIR = "functions/builtin/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinSplitTest.class.getSimpleName() + "/";
+
+	public double eps = 0.00001;
+	public int cols = 100;
+	public int rows = 10;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		TestConfiguration splitConfig = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B",});
+		addTestConfiguration(TEST_NAME, splitConfig);
+	}
+
+	@Test
+	public void test_CP() {
+		runSplitTest(LopProperties.ExecType.CP);
+	}
+
+	@Test
+	public void test_Spark() {
+		runSplitTest(LopProperties.ExecType.SPARK);
+	}
+
+	/**
+	 * Runs the split test script with the given execution type and expects "TRUE" in its output.
+	 *
+	 * @param instType execution type handed to setExecMode
+	 */
+	private void runSplitTest(ExecType instType) {
+		ExecMode previousMode = setExecMode(instType);
+		try {
+			setOutputBuffering(true);
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+			fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
+			programArgs = new String[] {"-nvargs", "cols=" + cols, "rows=" + rows};
+			String stdout = runTest(null).toString();
+			Assert.assertTrue(stdout.contains("TRUE"));
+		}
+		finally {
+			rtplatform = previousMode;
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
index 6039354..3220e1a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.test.functions.federated.primitives;
 
+import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -63,11 +64,13 @@
 		federatedMultiply(Types.ExecMode.SINGLE_NODE);
 	}
 
-	/*
-	 * FIXME spark execution mode support
-	 * 
-	 * @Test public void federatedMultiplySP() { federatedMultiply(Types.ExecMode.SPARK); }
-	 */
+
+	@Test
+	@Ignore("Spark execution mode currently fails for federated multiply")
+	public void federatedMultiplySP() {
+		// TODO fix the Spark execution error, then remove the @Ignore on this test
+		federatedMultiply(Types.ExecMode.SPARK);
+	}
 
 	public void federatedMultiply(Types.ExecMode execMode) {
 		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
new file mode 100644
index 0000000..a13c93a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/** Runs FederatedSplitTest.dml against federated workers and compares it with FederatedSplitTestReference.dml. */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedSplitTest extends AutomatedTestBase {
+    private static final Log LOG = LogFactory.getLog(FederatedSplitTest.class.getName());
+    private final static String TEST_DIR = "functions/federated/";
+    private final static String TEST_NAME = "FederatedSplitTest";
+    private final static String TEST_CLASS_DIR = TEST_DIR + FederatedSplitTest.class.getSimpleName() + "/";
+    private final static int blocksize = 1024;
+
+    @Parameterized.Parameter()
+    public int rows;
+    @Parameterized.Parameter(1)
+    public int cols;
+    @Parameterized.Parameter(2)
+    public String cont;
+
+    @Parameterized.Parameters
+    public static Collection<Object[]> data() {
+        return Arrays.asList(new Object[][] {{152, 12, "TRUE"},{132, 11, "FALSE"}});
+    }
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+    }
+
+    @Test
+    public void federatedSplitCP() {
+        federatedSplit(Types.ExecMode.SINGLE_NODE);
+    }
+
+    /**
+     * Runs the reference and the federated script, compares their "Z" outputs, and always cleans up.
+     * @param execMode execution mode assigned to rtplatform for both runs
+     */
+    public void federatedSplit(Types.ExecMode execMode) {
+        boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+        Types.ExecMode platformOld = rtplatform;
+        rtplatform = execMode;
+        if(rtplatform == Types.ExecMode.SPARK) {
+            DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+        }
+        try {
+            getAndLoadTestConfiguration(TEST_NAME);
+            String HOME = SCRIPT_DIR + TEST_DIR;
+            // Each of the two workers serves one half of X and one half of Y
+            int halfRows = rows / 2;
+            double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+            double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+            double[][] Y1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 44);
+            double[][] Y2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 21);
+            writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+            writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+            writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+            writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+            TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+            loadTestConfiguration(config);
+
+            int port1 = getRandomAvailablePort();
+            int port2 = getRandomAvailablePort();
+            Thread t1 = startLocalFedWorkerThread(port1);
+            Thread t2 = startLocalFedWorkerThread(port2);
+            try {
+                // Run reference dml script with normal matrix
+                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+                    "Y2=" + input("Y2"), "Z=" + expected("Z"), "Cont=" + cont};
+                String out = runTest(null).toString();
+                LOG.debug(out);
+
+                // Run actual dml script with federated matrix
+                fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                programArgs = new String[] {"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                    "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                    "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+                    "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z"),
+                    "Cont=" + cont};
+                String fedOut = runTest(null).toString();
+                LOG.debug(fedOut);
+                compareResults(1e-9);
+            } finally {
+                TestUtils.shutdownThreads(t1, t2);
+            }
+        } finally {
+            rtplatform = platformOld;
+            DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+        }
+    }
+}
diff --git a/src/test/scripts/functions/builtin/split.dml b/src/test/scripts/functions/builtin/split.dml
new file mode 100644
index 0000000..9ced733
--- /dev/null
+++ b/src/test/scripts/functions/builtin/split.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Random feature matrix and a single label column with the same number of rows
+X = rand(rows=$rows, cols=$cols, seed=1)
+Y = rand(rows=$rows, cols=1, seed=13)
+[Xtrain, Xtest, Ytrain, Ytest] = split(X=X, Y=Y, seed=132)
+
+# Totals of the inputs
+totalX = sum(X)
+totalY = sum(Y)
+# Totals over the returned train and test parts
+partsX = sum(Xtrain) + sum(Xtest)
+partsY = sum(Ytrain) + sum(Ytest)
+
+sumsAgree = abs(totalX + totalY - partsX - partsY) < 0.001
+print(sumsAgree)
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/FederatedSplitTest.dml b/src/test/scripts/functions/federated/FederatedSplitTest.dml
new file mode 100644
index 0000000..44c59a9
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedSplitTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X and Y each use two addresses: the first covers (0, 0) to (r/2, c), the second (r/2, 0) to (r, c)
+X = federated(addresses=list($X1, $X2),
+    ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+    ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+
+[Xtr, Xte, Ytr, Yte] = split(X=X,Y=Y,f=0.95, cont=$Cont, seed = 13)
+write(Xte, $Z)
+print(toString(Xte))
diff --git a/src/test/scripts/functions/federated/FederatedSplitTestReference.dml b/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
new file mode 100644
index 0000000..4db8e1f
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))  # stack the X1 and X2 parts row-wise
+Y = rbind(read($Y1), read($Y2))  # stack the Y1 and Y2 parts row-wise
+[Xtr, Xte, Ytr, Yte] = split(X=X,Y=Y, f=0.95 ,cont=$Cont, seed = 13)
+write(Xte, $Z)  # only the test part of X is written out
+print(toString(Xte))