[SYSTEMDS-1863] Full MLContext test for LinearReg
* Takes advantage of existing R algorithm scripts used for
codegen testing.
* This would improve the testing by allowing us to provide all
the necessary inputs into the script.
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index a66ee1e..0183e34 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -1649,6 +1649,8 @@
}
protected String getRScript() {
+ if(fullRScriptName != null)
+ return fullRScriptName;
return sourceDirectory + selectedTest + ".R";
}
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
index a5cddb8..0e45cb4 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
@@ -22,8 +22,13 @@
import static org.apache.sysds.api.mlcontext.ScriptFactory.dmlFromFile;
import org.apache.log4j.Logger;
-import org.junit.Test;
import org.apache.sysds.api.mlcontext.Script;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
public class MLContextLinregTest extends MLContextTestBase {
protected static Logger log = Logger.getLogger(MLContextLinregTest.class);
@@ -37,6 +42,11 @@
CG, DS,
}
+ private final static double eps = 1e-3;
+
+ private final static int rows = 2468;
+ private final static int cols = 507;
+
@Test
public void testLinregCGSparse() {
runLinregTestMLC(LinregType.CG, true);
@@ -59,24 +69,42 @@
private void runLinregTestMLC(LinregType type, boolean sparse) {
- double[][] X = getRandomMatrix(10, 3, 0, 1, sparse ? sparsity2 : sparsity1, 7);
- double[][] Y = getRandomMatrix(10, 1, 0, 10, 1.0, 3);
+ double[][] X = getRandomMatrix(rows, cols, 0, 1, sparse ? sparsity2 : sparsity1, 7);
+ double[][] Y = getRandomMatrix(rows, 1, 0, 10, 1.0, 3);
+
+ // Hack Alert
+ // overwrite baseDirectory to the place where test data is stored.
+ baseDirectory = "target/testTemp/functions/mlcontext/";
+
+ fullRScriptName = "src/test/scripts/functions/codegenalg/Algorithm_LinregCG.R";
+
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("y", Y, true);
+
+ rCmd = getRCmd(inputDir(), "0", "0.000001", "0", "0.001", expectedDir());
+ runRScript(true);
+
+ MatrixBlock outmat = new MatrixBlock();
switch (type) {
case CG:
Script lrcg = dmlFromFile(TEST_SCRIPT_CG);
lrcg.in("X", X).in("y", Y).in("$icpt", "0").in("$tol", "0.000001").in("$maxi", "0").in("$reg", "0.000001")
.out("beta_out");
- ml.execute(lrcg);
+ outmat = ml.execute(lrcg).getMatrix("beta_out").toMatrixBlock();
break;
case DS:
Script lrds = dmlFromFile(TEST_SCRIPT_DS);
lrds.in("X", X).in("y", Y).in("$icpt", "0").in("$reg", "0.000001").out("beta_out");
- ml.execute(lrds);
+ outmat = ml.execute(lrds).getMatrix("beta_out").toMatrixBlock();
break;
}
+
+ //compare matrices
+ HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromFS("w");
+ TestUtils.compareMatrices(rfile, outmat, eps);
}
}