[SYSTEMDS-419] New rewrite for compiler-assisted partial lineage reuse

This patch adds a new rewrite for compiler-assisted partial rewrites of
the lineage cache. Such rewrites can be applied during recompilation
with the scope of entire operation DAGs and thus avoid executing
operations that will later become part of a partial rewrite pattern.

Furthermore, this makes also a minor change of fixing the young
generation memory budget in order tests to avoid OOMs in the github
action tests.
diff --git a/pom.xml b/pom.xml
index 0f4e208..76a9965 100644
--- a/pom.xml
+++ b/pom.xml
@@ -52,7 +52,7 @@
 		<jcuda.version>10.2.0</jcuda.version>
 		<!-->Testing settings<!-->
 		<skipTests>true</skipTests>
-		<argLine>-Xms4g -Xmx4g</argLine>
+		<argLine>-Xms4g -Xmx4g -Xmn400m</argLine>
 	</properties>
 
 	<repositories>
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 1412ffc..bd260b9 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -671,7 +671,7 @@
 		return auop;
 	}
 	
-	public static AggBinaryOp createTSMM(Hop input, boolean left) {
+	public static AggBinaryOp createTsmm(Hop input, boolean left) {
 		Hop trans = createTranspose(input);
 		return createMatrixMultiply(
 			left ? trans : input, left ? input : trans);
@@ -1431,6 +1431,23 @@
 		return ret;
 	}
 	
+	public static Hop createPartialTsmmCbind(Hop X, Hop deltaX, Hop tsmmIn1) {
+		//partial rewrite to rewrite tsmm(cbind(in1, in2)) into form that can reuse tsmm(in1)
+		// cell bottomLeft = t(lastCol) %*% oldMatrix
+		ReorgOp tLastCol = HopRewriteUtils.createTranspose(deltaX);
+		AggBinaryOp bottomLeft = HopRewriteUtils.createMatrixMultiply(tLastCol, X);
+		// cell topRight = t(oldMatrix) %*% lastCol = t(bottomLeft)
+		ReorgOp topRight = HopRewriteUtils.createTranspose(bottomLeft);
+		// bottomRight = t(lastCol) %*% lastCol
+		AggBinaryOp bottomRight = HopRewriteUtils.createMatrixMultiply(tLastCol, deltaX);
+		// rowOne = cbind(lastRes, topRight)
+		BinaryOp rowOne = HopRewriteUtils.createBinary(tsmmIn1, topRight, OpOp2.CBIND);
+		// rowTwo = cbind(bottomLeft, bottomRight)
+		BinaryOp rowTwo = HopRewriteUtils.createBinary(bottomLeft, bottomRight, OpOp2.CBIND);
+		// rbind(rowOne, rowTwo)
+		return HopRewriteUtils.createBinary(rowOne, rowTwo, OpOp2.RBIND);
+	}
+	
 	//////////////////////////////////////
 	// utils for lookup tables
 	
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 1929315..191c35d 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -498,7 +498,7 @@
 			List<Hop> inputs = hi.getInput().get(1).getInput();
 			if( HopRewriteUtils.checkAvgRowsGteCols(inputs) ) {
 				Hop[] tsmms = inputs.stream()
-					.map(h -> HopRewriteUtils.createTSMM(h, true)).toArray(Hop[]::new);
+					.map(h -> HopRewriteUtils.createTsmm(h, true)).toArray(Hop[]::new);
 				hnew = HopRewriteUtils.createNary(OpOpN.PLUS, tsmms);
 				//cleanup parent references from rbind
 				//HopRewriteUtils.removeAllChildReferences(hi.getInput().get(1));
@@ -528,6 +528,19 @@
 				branch = 2;
 			}
 		}
+		//pattern 3: X = t(cbind(A, B)) %*% cbind(A, B), w/ one cbind consumer (twice in tsmm)
+		else if( HopRewriteUtils.isTsmm(hi) && hi.getInput().get(1).getParent().size()==2
+			&& HopRewriteUtils.isTransposeOperation(hi.getInput().get(0))
+			&& HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.CBIND) )
+		{
+			Hop input1 = hi.getInput().get(1).getInput().get(0);
+			Hop input2 = hi.getInput().get(1).getInput().get(1);
+			if( input1.getDim1() > input1.getDim2() && input2.getDim2() == 1 ) {
+				hnew = HopRewriteUtils.createPartialTsmmCbind(
+					input1, input2, HopRewriteUtils.createTsmm(input1, true));
+				branch = 3;
+			}
+		}
 		
 		//modify dag if one of the above rules applied
 		if( hnew != null ){ 
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index f51a8b3..5ab97bf 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1134,7 +1134,7 @@
 	
 	private static Hop simplifyTransposedAppend( Hop parent, Hop hi, int pos )
 	{
-		//e.g., t(cbind(t(A),t(B))) --> rbind(A,B), t(rbind(t(A),t(B))) --> cbind(A,B)		
+		//e.g., t(cbind(t(A),t(B))) --> rbind(A,B), t(rbind(t(A),t(B))) --> cbind(A,B)
 		if(   HopRewriteUtils.isTransposeOperation(hi)  //t() rooted
 		   && hi.getInput().get(0) instanceof BinaryOp
 		   && (((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.CBIND    //append (cbind/rbind)
@@ -1156,7 +1156,7 @@
 				HopRewriteUtils.replaceChildReference(parent, hi, bopnew, pos);
 				
 				hi = bopnew;
-				LOG.debug("Applied simplifyTransposedAppend (line "+hi.getBeginLine()+").");				
+				LOG.debug("Applied simplifyTransposedAppend (line "+hi.getBeginLine()+").");
 			}
 		}
 		
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
index a884f9f..c1a7cd6 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
@@ -170,19 +170,7 @@
 			setupTReadCachedInput("deltaX", inCache, lrwec) :
 			HopRewriteUtils.createIndexingOp(newMatrix, 1L, mo.getNumRows(), mo.getNumColumns(), mo.getNumColumns());
 		
-		// cell bottomLeft = t(lastCol) %*% oldMatrix
-		ReorgOp tLastCol = HopRewriteUtils.createTranspose(lastCol);
-		AggBinaryOp bottomLeft = HopRewriteUtils.createMatrixMultiply(tLastCol, oldMatrix);
-		// cell topRight = t(oldMatrix) %*% lastCol = t(bottomLeft)
-		ReorgOp topRight = HopRewriteUtils.createTranspose(bottomLeft);
-		// bottomRight = t(lastCol) %*% lastCol
-		AggBinaryOp bottomRight = HopRewriteUtils.createMatrixMultiply(tLastCol, lastCol);
-		// rowOne = cbind(lastRes, topRight)
-		BinaryOp rowOne = HopRewriteUtils.createBinary(lastRes, topRight, OpOp2.CBIND);
-		// rowTwo = cbind(bottomLeft, bottomRight)
-		BinaryOp rowTwo = HopRewriteUtils.createBinary(bottomLeft, bottomRight, OpOp2.CBIND);
-		// rbind(rowOne, rowTwo)
-		BinaryOp lrwHop= HopRewriteUtils.createBinary(rowOne, rowTwo, OpOp2.RBIND);
+		Hop lrwHop = HopRewriteUtils.createPartialTsmmCbind(oldMatrix, lastCol, lastRes);
 		DataOp lrwWrite = HopRewriteUtils.createTransientWrite(LR_VAR, lrwHop);
 
 		// generate runtime instructions
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/PartialReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/PartialReuseTest.java
index 4533fab..5775b64 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/PartialReuseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/PartialReuseTest.java
@@ -31,6 +31,7 @@
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -54,11 +55,11 @@
 		testLineageTraceReuse(TEST_NAME1, ExecMode.SINGLE_NODE);
 	}
 	
-//	@Test
-//	public void testLineageTrace1Hybrid() {
-//		//test partial reuse in Hybrid (i.e., w/ reuse-aware recompilation)
-//		testLineageTraceReuse(TEST_NAME1, ExecMode.HYBRID);
-//	}
+	@Test
+	public void testLineageTrace1Hybrid() {
+		//test partial reuse in Hybrid (i.e., w/ reuse-aware recompilation)
+		testLineageTraceReuse(TEST_NAME1, ExecMode.HYBRID);
+	}
 
 	
 	public void testLineageTraceReuse(String testname, ExecMode et) {
@@ -98,6 +99,9 @@
 			
 			//check no evictions (previously buffer pool leak)
 			Assert.assertEquals(0, CacheStatistics.getFSWrites());
+			//if compiler assisted reuse check for the introduced appends (3x per iteration)
+			if( et == ExecMode.HYBRID )
+				Assert.assertEquals(900, Statistics.getCPHeavyHitterCount("append"));
 		}
 		finally {
 			resetExecMode(execModeOld);
diff --git a/src/test/scripts/functions/lineage/PartialReuse1.dml b/src/test/scripts/functions/lineage/PartialReuse1.dml
index a1d8d35..61e3fd2 100644
--- a/src/test/scripts/functions/lineage/PartialReuse1.dml
+++ b/src/test/scripts/functions/lineage/PartialReuse1.dml
@@ -22,12 +22,11 @@
 # Increase k for better performance gains
 
 X = rand(rows=20000, cols=300, seed=42);
-sum = 0;
-tmp = matrix(0, rows=nrow(X), cols=0);
+X2 = rand(rows=20000, cols=ncol(X)/2, seed=42);
 R = matrix(0, 1, ncol(X));
 
 for (i in 1:ncol(X)) {
-  tmp = cbind(tmp, X[,i]);
+  tmp = cbind(X2, X[,i]);
   Res1 = t(tmp) %*% tmp;
   while(FALSE) {}
   R[1,i] = sum(Res1);