[SYSTEMDS-2572] Additional mlcontext test for nn-library imports
The bug reported in SYSTEMDS-2572 was non-reproducible both in a local
environment as well as through spark-shell. However, as the mlcontext
tests did not include a test for sourcing (importing) dml scripts, we
add the related test script accordingly.
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
index 3e07b15..697e9e9 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTest.java
@@ -1904,5 +1904,17 @@
Assert.assertEquals(true, c);
Assert.assertEquals("yes it's TRUE", d);
}
-
+
+ @Test
+ public void testNNImport() {
+ System.out.println("MLContextTest - NN import");
+ String s = "source(\"scripts/nn/layers/relu.dml\") as relu;\n"
+ + "X = rand(rows=100, cols=10, min=-1, max=1);\n"
+ + "R1 = relu::forward(X);\n"
+ + "R2 = max(X, 0);\n"
+ + "R = sum(R1==R2);\n";
+ double ret = ml.execute(dml(s).out("R"))
+ .getScalarObject("R").getDoubleValue();
+ Assert.assertEquals(1000, ret, 1e-20);
+ }
}