[SYSTEMDS-3525] Disable and add tests for binary inplace operations

This patch adds a flag for the update inplace for binary operations.
The flag is disabled by default. This patch also adds a test which
exposes a bug of binary inplace. A binary inplace operation consuming
another inplaced intermediate (e.g. right index) leads to corruption.
We also disable binary inplace if lineage-based reuse is enabled to
avoid corrupting the cached intermediates.

Closes #1814
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 04585d7..74740f1 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -51,6 +51,7 @@
 import org.apache.sysds.lops.SortKeys;
 import org.apache.sysds.lops.Unary;
 import org.apache.sysds.lops.UnaryCP;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
@@ -814,7 +815,8 @@
 			_etype = ExecType.SPARK;
 		}
 
-		if( transitive && _etypeForced != ExecType.SPARK && _etypeForced != ExecType.FED && //
+		if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE &&
+			transitive && _etypeForced != ExecType.SPARK && _etypeForced != ExecType.FED &&
 			getDataType().isMatrix() // Output is a matrix
 			&& op == OpOp2.DIV // Operation is division
 			&& dt1.isMatrix() // Left hand side is a Matrix
@@ -823,6 +825,7 @@
 			&& memOfInputIsLessThanBudget() //
 			&& getInput().get(0).getExecType() != ExecType.SPARK // Is not already a spark operation
 			&& doesNotContainNanAndInf(getInput().get(1)) // Guaranteed not to densify the operation
+			&& LineageCacheConfig.ReuseCacheType.isNone() // Inplace update corrupts the already cached input matrix block
 		) {
 			inplace = true;
 			_etype = ExecType.CP;
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 580ddcc..4ab7c33 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -215,7 +215,14 @@
 	 */
 	//TODO enabling it by default requires modifications in lineage-based reuse
 	public static boolean ALLOW_UNARY_UPDATE_IN_PLACE = false;
-	
+
+	/**
+	 * Enables update-in-place for binary operators if the first input
+	 * has no consumers. In this case we directly write the output
+	 * values back to the first input block.
+	 */
+	public static boolean ALLOW_BINARY_UPDATE_IN_PLACE = false;
+
 	/**
 	 * Replace eval second-order function calls with normal function call
 	 * if the function name is a known string (after constant propagation).
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
index 20119ce..cff0650 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
@@ -53,6 +53,10 @@
 			inplace = false;
 	}
 
+	public boolean isInPlace() {
+		return inplace;
+	}
+
 	@Override
 	public void processInstruction(ExecutionContext ec) {
 		// Read input matrices
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index a483b6c..0ce6cf3 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -27,6 +27,7 @@
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
@@ -228,6 +229,8 @@
 			|| (inst instanceof DataGenCPInstruction) && ((DataGenCPInstruction) inst).isMatrixCall());
 		boolean updateInplace = (inst instanceof MatrixIndexingCPInstruction)
 			&& ec.getMatrixObject(((ComputationCPInstruction)inst).input1).getUpdateType().isInPlace();
+		updateInplace = updateInplace || ((inst instanceof BinaryMatrixMatrixCPInstruction)
+			&& ((BinaryMatrixMatrixCPInstruction) inst).isInPlace());
 		boolean federatedOutput = false;
 		return insttype && rightop && !updateInplace && !federatedOutput;
 	}
diff --git a/src/test/java/org/apache/sysds/test/functions/updateinplace/BinaryUpdateInPlaceTest.java b/src/test/java/org/apache/sysds/test/functions/updateinplace/BinaryUpdateInPlaceTest.java
new file mode 100644
index 0000000..19bb2cd
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/updateinplace/BinaryUpdateInPlaceTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.updateinplace;
+
+	import java.util.ArrayList;
+	import java.util.HashMap;
+	import java.util.List;
+
+	import org.apache.sysds.common.Types;
+	import org.apache.sysds.hops.OptimizerUtils;
+	import org.apache.sysds.runtime.matrix.data.MatrixValue;
+	import org.apache.sysds.test.AutomatedTestBase;
+	import org.apache.sysds.test.TestConfiguration;
+	import org.apache.sysds.test.TestUtils;
+	import org.junit.Ignore;
+	import org.junit.Test;
+
+
+public class BinaryUpdateInPlaceTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "BinaryUpdateInplace";
+	private final static String TEST_DIR = "functions/updateinplace/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + BinaryUpdateInPlaceTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-3;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B",}));
+	}
+
+	@Ignore
+	@Test
+	public void testInPlace() {
+		runInPlaceTest(Types.ExecType.CP);
+	}
+
+
+	private void runInPlaceTest(Types.ExecType instType) {
+		Types.ExecMode platformOld = setExecMode(instType);
+		boolean oldFlag = OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE;
+
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			List<String> proArgs = new ArrayList<>();
+			proArgs.add("-args");
+			proArgs.add(output("R"));
+			programArgs = proArgs.toArray(new String[proArgs.size()]);
+
+			OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE = true;
+			runTest(true, false, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> R_inplace = readDMLMatrixFromOutputDir("R");
+			OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE = false;
+			runTest(true, false, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> R = readDMLMatrixFromOutputDir("R");
+
+			//compare matrices
+			TestUtils.compareMatrices(R_inplace,R,eps,"with-Inplace","no_Inplace");
+		}
+		catch(Exception e) {
+			e.printStackTrace();
+		}
+		finally {
+			rtplatform = platformOld;
+			OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE = oldFlag;
+		}
+	}
+}
+
diff --git a/src/test/scripts/functions/updateinplace/BinaryUpdateInplace.dml b/src/test/scripts/functions/updateinplace/BinaryUpdateInplace.dml
new file mode 100644
index 0000000..8283d0e
--- /dev/null
+++ b/src/test/scripts/functions/updateinplace/BinaryUpdateInplace.dml
@@ -0,0 +1,66 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+D = rand(rows=32, cols=100, min=0, max=20, seed=42)
+bs = 32;
+ep = 3;
+iter_ep = ceil(nrow(D)/bs);
+maxiter = ep * iter_ep;
+beg = 1;
+iter = 0;
+i = 1;
+R = matrix(0, rows=1, cols=maxiter+1);
+
+while (iter < maxiter) {
+  end = beg + bs - 1;
+  if (end>nrow(D))
+    end = nrow(D);
+  X = D[beg:end,]
+
+  #inlace binary after inplace indexing corrupts the dataset
+  R[1,iter+1] = sum(D);
+
+  #reusable OP across epochs
+  X = scale(X, FALSE, TRUE);
+  #pollute cache with not reusable OPs
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+  X = ((X + X) * i - X) / (i+1)
+
+  iter = iter + 1;
+  if (end == nrow(D))
+    beg = 1;
+  else
+    beg = end + 1;
+  i = i + 1;
+
+}
+#R = X;
+R[1,maxiter+1] = sum(X);
+write(R, $1, format="text");
+