[SYSTEMDS-211] Fix incorrect size propagation list-matrix rbind/cbind
This patch fixes size propagation issues during parsing and
recompilation for rbind/cbind operations over lists into a single
matrix. Together with other rewrites, the incorrect size propagation led
to invalid runtime plans.
However, the additional tests with CV-lm still require an assertion to
allow function inlining as a precondition for the fold-rewrite to
eliminate redundancy. Solving this remaining issue requires a principled
size propagation approach for matrix objects in lists.
diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java b/src/main/java/org/apache/sysds/hops/NaryOp.java
index d7b66dd..6419dbc 100644
--- a/src/main/java/org/apache/sysds/hops/NaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/NaryOp.java
@@ -214,14 +214,18 @@
public void refreshSizeInformation() {
switch( _op ) {
case CBIND:
- setDim1(HopRewriteUtils.getMaxInputDim(this, true));
- setDim2(HopRewriteUtils.getSumValidInputDims(this, false));
- setNnz(HopRewriteUtils.getSumValidInputNnz(this));
+ if( !getInput().get(0).getDataType().isList() ) {
+ setDim1(HopRewriteUtils.getMaxInputDim(this, true));
+ setDim2(HopRewriteUtils.getSumValidInputDims(this, false));
+ setNnz(HopRewriteUtils.getSumValidInputNnz(this));
+ }
break;
case RBIND:
- setDim1(HopRewriteUtils.getSumValidInputDims(this, true));
- setDim2(HopRewriteUtils.getMaxInputDim(this, false));
- setNnz(HopRewriteUtils.getSumValidInputNnz(this));
+ if( !getInput().get(0).getDataType().isList() ) {
+ setDim1(HopRewriteUtils.getSumValidInputDims(this, true));
+ setDim2(HopRewriteUtils.getMaxInputDim(this, false));
+ setNnz(HopRewriteUtils.getSumValidInputNnz(this));
+ }
break;
case MIN:
case MAX:
diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index b358396..eeb44f9 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -826,8 +826,8 @@
appendDim2 = (m2clen>=0) ? m2clen : appendDim2;
}
}
- //TODO: calculate output dimensions of List
- if( output.getDataType() == DataType.LIST ) {
+
+ if( id.getDataType() == DataType.LIST ) {
appendDim1 = -1;
appendDim2 = -1;
}
diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameReaderProto.java b/src/main/java/org/apache/sysds/runtime/io/FrameReaderProto.java
index 34f0dbb..a62f712 100644
--- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderProto.java
+++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderProto.java
@@ -56,7 +56,7 @@
return outputFrameBlock;
}
- private void readProtoFrameFromHDFS(Path path, FileSystem fileSystem, FrameBlock dest, long rlen, long clen)
+ private static void readProtoFrameFromHDFS(Path path, FileSystem fileSystem, FrameBlock dest, long rlen, long clen)
throws IOException {
SysdsProtos.Frame frame = readProtoFrameFromFile(path, fileSystem);
for(int row = 0; row < rlen; row++) {
@@ -69,7 +69,7 @@
IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(fileSystem, path);
}
- private SysdsProtos.Frame readProtoFrameFromFile(Path path, FileSystem fileSystem) throws IOException {
+ private static SysdsProtos.Frame readProtoFrameFromFile(Path path, FileSystem fileSystem) throws IOException {
FSDataInputStream fsDataInputStream = fileSystem.open(path);
try {
return SysdsProtos.Frame.newBuilder().mergeFrom(fsDataInputStream).build();
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
index 30cb46b..96a8407 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceDedupTest.java
@@ -160,7 +160,7 @@
String dedup_trace = readDMLLineageFromHDFS("R");
LineageItem dedup_li = LineageParser.parseLineageTrace(dedup_trace);
-
+
//check lineage DAG
assertEquals(dedup_li, li);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/misc/RewriteListTsmmCVTest.java b/src/test/java/org/apache/sysds/test/functions/misc/RewriteListTsmmCVTest.java
index c5201aa..cc1636a 100644
--- a/src/test/java/org/apache/sysds/test/functions/misc/RewriteListTsmmCVTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/misc/RewriteListTsmmCVTest.java
@@ -39,7 +39,8 @@
*/
public class RewriteListTsmmCVTest extends AutomatedTestBase
{
- private static final String TEST_NAME1 = "RewriteListTsmmCV";
+ private static final String TEST_NAME1 = "RewriteListTsmmCV1";
+ private static final String TEST_NAME2 = "RewriteListTsmmCV2";
private static final String TEST_DIR = "functions/misc/";
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteListTsmmCVTest.class.getSimpleName() + "/";
@@ -50,29 +51,50 @@
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}));
}
@Test
- public void testListTsmmRewriteCP() {
+ public void testListTsmm1RewriteCP() {
testListTsmmCV(TEST_NAME1, true, false, ExecType.CP);
}
@Test
- public void testListTsmmRewriteSP() {
+ public void testListTsmm1RewriteSP() {
testListTsmmCV(TEST_NAME1, true, false, ExecType.SPARK);
}
@Test
- public void testListTsmmRewriteLineageCP() {
+ public void testListTsmm1RewriteLineageCP() {
testListTsmmCV(TEST_NAME1, true, true, ExecType.CP);
}
@Test
- public void testListTsmmRewriteLineageSP() {
+ public void testListTsmm1RewriteLineageSP() {
testListTsmmCV(TEST_NAME1, true, true, ExecType.SPARK);
}
+ @Test
+ public void testListTsmm2RewriteCP() {
+ testListTsmmCV(TEST_NAME2, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testListTsmm2RewriteSP() {
+ testListTsmmCV(TEST_NAME2, true, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testListTsmm2RewriteLineageCP() {
+ testListTsmmCV(TEST_NAME2, true, true, ExecType.CP);
+ }
+
+ @Test
+ public void testListTsmm2RewriteLineageSP() {
+ testListTsmmCV(TEST_NAME2, true, true, ExecType.SPARK);
+ }
+
private void testListTsmmCV( String testname, boolean rewrites, boolean lineage, ExecType instType )
{
ExecMode platformOld = setExecMode(instType);
diff --git a/src/test/scripts/functions/misc/RewriteListTsmmCV.dml b/src/test/scripts/functions/misc/RewriteListTsmmCV1.dml
similarity index 100%
rename from src/test/scripts/functions/misc/RewriteListTsmmCV.dml
rename to src/test/scripts/functions/misc/RewriteListTsmmCV1.dml
diff --git a/src/test/scripts/functions/misc/RewriteListTsmmCV2.dml b/src/test/scripts/functions/misc/RewriteListTsmmCV2.dml
new file mode 100644
index 0000000..84abda8
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteListTsmmCV2.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+crossV = function(Matrix[double] X, Matrix[double] y, double lamda, Integer k) return (Matrix[double] R)
+{
+ #create empty lists
+ dataset_X = list(); #empty list
+ dataset_y = list();
+ fs = ceil(nrow(X)/k);
+ off = fs - 1;
+ #devide X, y into lists of k matrices
+ for (i in seq(1, k)) {
+ dataset_X = append(dataset_X, X[i*fs-off : min(i*fs, nrow(X)),]);
+ dataset_y = append(dataset_y, y[i*fs-off : min(i*fs, nrow(y)),]);
+ }
+
+ beta_list = list();
+ #keep one fold for testing in each iteration
+ for (i in seq(1, k)) {
+ [tmpX, testX] = remove(dataset_X, i);
+ [tmpy, testy] = remove(dataset_y, i);
+ trainX = rbind(tmpX);
+ trainy = rbind(tmpy);
+ trainX = trainX[,1:ncol(X)] # TODO improve list size propagation
+ beta = lm(X=trainX, y=trainy, reg=lamda, verbose=FALSE);
+ beta_list = append(beta_list, beta);
+ }
+
+ R = cbind(beta_list);
+}
+
+X = rand(rows=$1, cols=$2);
+y = X %*% rand(rows=$2, cols=1);
+
+R = crossV(X, y, 0.001, 7);
+
+r = as.matrix(sum(R!=0));
+write(r, $3);
+#expected: "Result: $2*7