[SYSTEMDS-3525] Disable and add tests for binary inplace operations
This patch adds a flag that controls update-in-place for binary operations.
The flag is disabled by default. The patch also adds a test that exposes
a bug in binary in-place updates: a binary in-place operation that consumes
another in-place intermediate (e.g., the result of a right indexing
operation) leads to data corruption. In addition, binary in-place is
disabled whenever lineage-based reuse is enabled, to avoid corrupting the
cached intermediates.
Closes #1814
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 04585d7..74740f1 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -51,6 +51,7 @@
import org.apache.sysds.lops.SortKeys;
import org.apache.sysds.lops.Unary;
import org.apache.sysds.lops.UnaryCP;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -814,7 +815,8 @@
_etype = ExecType.SPARK;
}
- if( transitive && _etypeForced != ExecType.SPARK && _etypeForced != ExecType.FED && //
+ if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE &&
+ transitive && _etypeForced != ExecType.SPARK && _etypeForced != ExecType.FED &&
getDataType().isMatrix() // Output is a matrix
&& op == OpOp2.DIV // Operation is division
&& dt1.isMatrix() // Left hand side is a Matrix
@@ -823,6 +825,7 @@
&& memOfInputIsLessThanBudget() //
&& getInput().get(0).getExecType() != ExecType.SPARK // Is not already a spark operation
&& doesNotContainNanAndInf(getInput().get(1)) // Guaranteed not to densify the operation
+ && LineageCacheConfig.ReuseCacheType.isNone() // Inplace update corrupts the already cached input matrix block
) {
inplace = true;
_etype = ExecType.CP;
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 580ddcc..4ab7c33 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -215,7 +215,14 @@
*/
//TODO enabling it by default requires modifications in lineage-based reuse
public static boolean ALLOW_UNARY_UPDATE_IN_PLACE = false;
-
+
+ /**
+ * Enables update-in-place for binary operators if the first input
+ * has no consumers. In this case we directly write the output
+ * values back to the first input block.
+ */
+ public static boolean ALLOW_BINARY_UPDATE_IN_PLACE = false;
+
/**
* Replace eval second-order function calls with normal function call
* if the function name is a known string (after constant propagation).
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
index 20119ce..cff0650 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
@@ -53,6 +53,10 @@
inplace = false;
}
+ public boolean isInPlace() {
+ return inplace; // in-place flag of this instruction; LineageCacheConfig reads it for its updateInplace check
+ }
+
@Override
public void processInstruction(ExecutionContext ec) {
// Read input matrices
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index a483b6c..0ce6cf3 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -27,6 +27,7 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
@@ -228,6 +229,8 @@
|| (inst instanceof DataGenCPInstruction) && ((DataGenCPInstruction) inst).isMatrixCall());
boolean updateInplace = (inst instanceof MatrixIndexingCPInstruction)
&& ec.getMatrixObject(((ComputationCPInstruction)inst).input1).getUpdateType().isInPlace();
+ updateInplace = updateInplace || ((inst instanceof BinaryMatrixMatrixCPInstruction)
+ && ((BinaryMatrixMatrixCPInstruction) inst).isInPlace());
boolean federatedOutput = false;
return insttype && rightop && !updateInplace && !federatedOutput;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/updateinplace/BinaryUpdateInPlaceTest.java b/src/test/java/org/apache/sysds/test/functions/updateinplace/BinaryUpdateInPlaceTest.java
new file mode 100644
index 0000000..19bb2cd
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/updateinplace/BinaryUpdateInPlaceTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.updateinplace;
+
+ import java.util.ArrayList;
+ import java.util.HashMap;
+ import java.util.List;
+
+ import org.apache.sysds.common.Types;
+ import org.apache.sysds.hops.OptimizerUtils;
+ import org.apache.sysds.runtime.matrix.data.MatrixValue;
+ import org.apache.sysds.test.AutomatedTestBase;
+ import org.apache.sysds.test.TestConfiguration;
+ import org.apache.sysds.test.TestUtils;
+ import org.junit.Ignore;
+ import org.junit.Test;
+
+
+public class BinaryUpdateInPlaceTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "BinaryUpdateInplace";
+	private final static String TEST_DIR = "functions/updateinplace/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + BinaryUpdateInPlaceTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-3;
+	// The script writes a single output matrix R; it is produced once per flag setting and compared.
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
+	}
+
+	@Ignore // currently ignored: the script exposes a known corruption bug of binary update-in-place
+	@Test
+	public void testInPlace() {
+		runInPlaceTest(Types.ExecType.CP);
+	}
+
+	// Runs the script with and without binary update-in-place; both runs must produce the same R.
+	private void runInPlaceTest(Types.ExecType instType) {
+		Types.ExecMode platformOld = setExecMode(instType);
+		boolean oldFlag = OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE;
+
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			List<String> proArgs = new ArrayList<>();
+			proArgs.add("-args");
+			proArgs.add(output("R"));
+			programArgs = proArgs.toArray(new String[proArgs.size()]);
+			// first run: binary update-in-place enabled
+			OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE = true;
+			runTest(true, false, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> R_inplace = readDMLMatrixFromOutputDir("R");
+			OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE = false;
+			runTest(true, false, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> R = readDMLMatrixFromOutputDir("R");
+
+			// second run (flag disabled) is the baseline; the in-place result must match it
+			TestUtils.compareMatrices(R_inplace,R,eps,"with-Inplace","no_Inplace");
+		}
+		catch(Exception e) {
+			throw new AssertionError("BinaryUpdateInPlaceTest failed: " + e.getMessage(), e); // do not swallow failures
+		}
+		finally {
+			rtplatform = platformOld; // always restore global state, even on failure
+			OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE = oldFlag;
+		}
+	}
+}
+
diff --git a/src/test/scripts/functions/updateinplace/BinaryUpdateInplace.dml b/src/test/scripts/functions/updateinplace/BinaryUpdateInplace.dml
new file mode 100644
index 0000000..8283d0e
--- /dev/null
+++ b/src/test/scripts/functions/updateinplace/BinaryUpdateInplace.dml
@@ -0,0 +1,66 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+D = rand(rows=32, cols=100, min=0, max=20, seed=42)
+bs = 32; # number of rows per slice taken from D
+ep = 3; # number of epochs (passes over D)
+iter_ep = ceil(nrow(D)/bs); # iterations per epoch
+maxiter = ep * iter_ep;
+beg = 1;
+iter = 0;
+i = 1;
+R = matrix(0, rows=1, cols=maxiter+1); # one cell per iteration plus a final cell
+
+while (iter < maxiter) {
+ end = beg + bs - 1;
+ if (end>nrow(D))
+ end = nrow(D);
+ X = D[beg:end,]
+
+ # in-place binary after in-place indexing corrupts the dataset; record sum(D) in every iteration
+ R[1,iter+1] = sum(D);
+
+ # operation that is reusable across epochs
+ X = scale(X, FALSE, TRUE);
+ # pollute the cache with operations that are not reusable (they depend on i)
+ X = ((X + X) * i - X) / (i+1)
+ X = ((X + X) * i - X) / (i+1)
+ X = ((X + X) * i - X) / (i+1)
+ X = ((X + X) * i - X) / (i+1)
+ X = ((X + X) * i - X) / (i+1)
+ X = ((X + X) * i - X) / (i+1)
+ X = ((X + X) * i - X) / (i+1)
+ X = ((X + X) * i - X) / (i+1)
+ X = ((X + X) * i - X) / (i+1)
+ X = ((X + X) * i - X) / (i+1)
+
+ iter = iter + 1;
+ if (end == nrow(D))
+ beg = 1; # wrap around to start the next epoch
+ else
+ beg = end + 1;
+ i = i + 1;
+
+}
+#R = X;
+R[1,maxiter+1] = sum(X); # last cell holds the sum of the final X
+write(R, $1, format="text");
+