[SYSTEMDS-2656] Fix robustness spark transform encode (empty partitions)
This patch fixes an edge case of spark transform encode (specifically
recode and dummy code) when spark partitions are completely empty.
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
index df5f16b..5dd7b7f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
@@ -254,9 +254,11 @@
throws Exception
{
//build meta data (e.g., recode maps)
- if( _raEncoder != null )
+ if( _raEncoder != null ) {
+ _raEncoder.prepareBuildPartial();
while( iter.hasNext() )
_raEncoder.buildPartial(iter.next()._2());
+ }
//output recode maps as columnID - token pairs
ArrayList<Tuple2<Integer,Object>> ret = new ArrayList<>();
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index e195835..d8d524a 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -121,14 +121,16 @@
protected void putCode(HashMap<String,Long> map, String key) {
map.put(key, Long.valueOf(map.size()+1));
}
+
+ public void prepareBuildPartial() {
+ //ensure allocated partial recode map
+ if( _rcdMapsPart == null )
+ _rcdMapsPart = new HashMap<>();
+ }
public void buildPartial(FrameBlock in) {
if( !isApplicable() )
return;
-
- //ensure allocated partial recode map
- if( _rcdMapsPart == null )
- _rcdMapsPart = new HashMap<>();
//construct partial recode map (tokens w/o codes)
//iterate over columns for sequential access
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMiceTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMiceTest.java
index a647918..ab76824 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMiceTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinMiceTest.java
@@ -51,42 +51,34 @@
runMiceNominalTest(mask, 1, false, LopProperties.ExecType.CP);
}
-// @Test
-// public void testMiceMixSpark() {
-// double[][] mask = {{ 0.0, 0.0, 1.0, 1.0, 0.0}};
-// runMiceNominalTest(mask, 1, LopProperties.ExecType.SPARK);
-// }
-
@Test
public void testMiceNumberCP() {
double[][] mask = {{ 0.0, 0.0, 0.0, 0.0, 0.0}};
runMiceNominalTest(mask, 2, false, LopProperties.ExecType.CP);
}
-// @Test
-// public void testMiceNumberSpark() {
-// double[][] mask = {{ 0.0, 0.0, 0.0, 0.0, 0.0}};
-// runMiceNominalTest(mask, 2, LopProperties.ExecType.SPARK);
-// }
-
@Test
public void testMiceCategoricalCP() {
double[][] mask = {{ 1.0, 1.0, 1.0, 1.0, 1.0}};
runMiceNominalTest(mask, 3, false, LopProperties.ExecType.CP);
}
-// @Test
-// public void testMiceCategoricalSpark() {
-// double[][] mask = {{ 1.0, 1.0, 1.0, 1.0, 1.0}};
-// runMiceNominalTest(mask, 3, LopProperties.ExecType.SPARK);
-// }
-
@Test
public void testMiceMixLineageReuseCP() {
double[][] mask = {{ 0.0, 0.0, 1.0, 1.0, 0.0}};
runMiceNominalTest(mask, 1, true, LopProperties.ExecType.CP);
}
+ //added a single, relatively-fast spark test, others seem infeasible
+ //as forcing every operation to spark takes too long for complex,
+ //composite builtins like mice.
+
+ @Test
+ public void testMiceNumberSpark() {
+ double[][] mask = {{ 0.0, 0.0, 0.0, 0.0, 0.0}};
+ runMiceNominalTest(mask, 2, false, LopProperties.ExecType.SPARK);
+ }
+
private void runMiceNominalTest(double[][] mask, int testType, boolean lineage, LopProperties.ExecType instType) {
Types.ExecMode platformOld = setExecMode(instType);
try {
@@ -94,10 +86,10 @@
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[]{"-nvargs", "X=" + DATASET, "Mask="+input("M"),
- "iteration=" + iter, "dataN=" + output("N"), "dataC=" + output("C")};
+ "iteration=" + iter, "dataN=" + output("N"), "dataC=" + output("C")};
if (lineage) {
- String[] lin = new String[] {"-stats","-lineage", ReuseCacheType.REUSE_HYBRID.name().toLowerCase()};
- programArgs = (String[]) ArrayUtils.addAll(programArgs, lin);
+ programArgs = (String[]) ArrayUtils.addAll(programArgs, new String[] {
+ "-stats","-lineage", ReuseCacheType.REUSE_HYBRID.name().toLowerCase()});
}
writeInputMatrixWithMTD("M", mask, true);
@@ -125,18 +117,16 @@
}
}
- private void testNumericOutput()
- {
+ private void testNumericOutput() {
//compare matrices
HashMap<MatrixValue.CellIndex, Double> dmlfileN = readDMLMatrixFromHDFS("N");
HashMap<MatrixValue.CellIndex, Double> rfileN = readRMatrixFromFS("N");
// compare numerical imputations
TestUtils.compareMatrices(dmlfileN, rfileN, eps, "Stat-DML", "Stat-R");
-
}
- private void testCategoricalOutput()
- {
+
+ private void testCategoricalOutput() {
HashMap<MatrixValue.CellIndex, Double> dmlfileC = readDMLMatrixFromHDFS("C");
HashMap<MatrixValue.CellIndex, Double> rfileC = readRMatrixFromFS("C");
@@ -154,4 +144,4 @@
else
Assert.fail("categorical test fails, the true value count is less than 98%");
}
-}
\ No newline at end of file
+}