[SYSTEMDS-2722] Split Testing
Added tests (builtin and federated) for new split builtin.
diff --git a/scripts/builtin/split.dml b/scripts/builtin/split.dml
index b2d78d0..5e6f1c5 100644
--- a/scripts/builtin/split.dml
+++ b/scripts/builtin/split.dml
@@ -19,14 +19,15 @@
#
#-------------------------------------------------------------
-# Split input data X and y into contiguous or samples train/test sets
+# Split input data X and Y into contiguous or samples train/test sets
# ------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ------------------------------------------------------------------------------
# X Matrix --- Input feature matrix
-# y Matrix --- Input
+# Y Matrix --- Input Labels
# f Double 0.7 Train set fraction [0,1]
# cont Boolean TRUE contiuous splits, otherwise sampled
+# seed Integer -1 The seed to reandomly select rows in sampled mode
# ------------------------------------------------------------------------------
# Xtrain Matrix --- Train split of feature matrix
# Xtest Matrix --- Test split of feature matrix
@@ -34,30 +35,30 @@
# ytest Matrix --- Test split of label matrix
# ------------------------------------------------------------------------------
-m_split = function(Matrix[Double] X, Matrix[Double] y, Double f=0.7, Boolean cont=TRUE)
- return (Matrix[Double] Xtrain, Matrix[Double] Xtest, Matrix[Double] ytrain, Matrix[Double] ytest)
+m_split = function(Matrix[Double] X, Matrix[Double] Y, Double f=0.7, Boolean cont=TRUE, Integer seed=-1)
+ return (Matrix[Double] Xtrain, Matrix[Double] Xtest, Matrix[Double] Ytrain, Matrix[Double] Ytest)
{
# basic sanity checks
if( f <= 0 | f >= 1 )
- print("Invalid train/test split configuration: f="+f);
- if( nrow(X) != nrow(y) )
- print("Mismatching number of rows X and y: "+nrow(X)+" "+nrow(y) )
+ stop("Invalid train/test split configuration: f="+f);
+ if( nrow(X) != nrow(Y) )
+ stop("Mismatching number of rows X and Y: "+nrow(X)+" "+nrow(Y) )
# contiguous train/test splits
if( cont ) {
Xtrain = X[1:f*nrow(X),];
- ytrain = y[1:f*nrow(X),];
+ Ytrain = Y[1:f*nrow(X),];
Xtest = X[(nrow(Xtrain)+1):nrow(X),];
- ytest = y[(nrow(Xtrain)+1):nrow(X),];
+ Ytest = Y[(nrow(Xtrain)+1):nrow(X),];
}
# sampled train/test splits
else {
- I = rand(rows=nrow(X), cols=1) <= f;
+ I = rand(rows=nrow(X), cols=1, seed=seed) <= f;
P1 = removeEmpty(target=diag(I), margin="rows", select=I);
P2 = removeEmpty(target=diag(I==0), margin="rows", select=I==0);
Xtrain = P1 %*% X;
- ytrain = P1 %*% y;
+ Ytrain = P1 %*% Y;
Xtest = P2 %*% X;
- ytest = P2 %*% y;
+ Ytest = P2 %*% Y;
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSplitTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSplitTest.java
new file mode 100644
index 0000000..ecf6e80
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSplitTest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class BuiltinSplitTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "split";
+ private final static String TEST_DIR = "functions/builtin/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinSplitTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B",}));
+ }
+
+ public double eps = 0.00001;
+ public int cols = 100;
+ public int rows = 10;
+
+ @Test
+ public void test_CP() {
+
+ runSplitTest(LopProperties.ExecType.CP);
+
+ }
+
+ @Test
+ public void test_Spark() {
+ runSplitTest(LopProperties.ExecType.SPARK);
+ }
+
+ private void runSplitTest(ExecType instType) {
+ ExecMode platformOld = setExecMode(instType);
+
+ try {
+ setOutputBuffering(true);
+
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-nvargs", "cols=" + cols, "rows=" + rows};
+
+ String out = runTest(null).toString();
+ Assert.assertTrue(out.contains("TRUE"));
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
index 6039354..3220e1a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.functions.federated.primitives;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -63,11 +64,13 @@
federatedMultiply(Types.ExecMode.SINGLE_NODE);
}
- /*
- * FIXME spark execution mode support
- *
- * @Test public void federatedMultiplySP() { federatedMultiply(Types.ExecMode.SPARK); }
- */
+
+ @Test
+ @Ignore
+ public void federatedMultiplySP() {
+ // TODO Fix me Spark execution error
+ federatedMultiply(Types.ExecMode.SPARK);
+ }
public void federatedMultiply(Types.ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
new file mode 100644
index 0000000..a13c93a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedSplitTest extends AutomatedTestBase {
+
+ private static final Log LOG = LogFactory.getLog(FederatedSplitTest.class.getName());
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME = "FederatedSplitTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedSplitTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public String cont;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {{152, 12, "TRUE"},{132, 11, "FALSE"}});
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Test
+ public void federatedSplitCP() {
+ federatedSplit(Types.ExecMode.SINGLE_NODE);
+ }
+
+ public void federatedSplit(Types.ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int halfRows = rows / 2;
+ // We have two matrices handled by a single federated worker
+ double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+ double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+ // And another two matrices handled by a single federated worker
+ double[][] Y1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 44);
+ double[][] Y2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 21);
+
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "Z=" + expected("Z"), "Cont=" + cont};
+ String out = runTest(null).toString();
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+ "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z"),
+ "Cont=" + cont};
+ String fedOut = runTest(null).toString();
+
+ LOG.error(out);
+ LOG.error(fedOut);
+ // compare via files
+ compareResults(1e-9);
+
+ TestUtils.shutdownThreads(t1, t2);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+}
diff --git a/src/test/scripts/functions/builtin/split.dml b/src/test/scripts/functions/builtin/split.dml
new file mode 100644
index 0000000..9ced733
--- /dev/null
+++ b/src/test/scripts/functions/builtin/split.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = rand(rows = $rows, cols=$cols, seed=1)
+Y = rand(rows = $rows, cols=1, seed=13)
+
+[Xtrain, Xtest, Ytrain, Ytest] = split(X=X,Y=Y, seed= 132)
+
+sumX = sum(X)
+sumY = sum(Y)
+
+sumXt = sum(Xtrain) + sum(Xtest)
+sumYt = sum(Ytrain) + sum(Ytest)
+
+sameXAndY = abs( sumX + sumY - sumXt - sumYt) < 0.001
+
+print(sameXAndY)
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/FederatedSplitTest.dml b/src/test/scripts/functions/federated/FederatedSplitTest.dml
new file mode 100644
index 0000000..44c59a9
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedSplitTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+
+
+[Xtr, Xte, Ytr, Yte] = split(X=X,Y=Y,f=0.95, cont=$Cont, seed = 13)
+write(Xte, $Z)
+print(toString(Xte))
diff --git a/src/test/scripts/functions/federated/FederatedSplitTestReference.dml b/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
new file mode 100644
index 0000000..4db8e1f
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Y = rbind(read($Y1), read($Y2))
+[Xtr, Xte, Ytr, Yte] = split(X=X,Y=Y, f=0.95 ,cont=$Cont, seed = 13)
+write(Xte, $Z)
+print(toString(Xte))