[SYSTEMDS-414] New rewrite for PCA -> lmDS pipeline

This patch contains a rewrite to reuse tsmm result in lmDS if
called after PCA incrementally for increasing number of columns.
diff --git a/dev/Tasks.txt b/dev/Tasks.txt
index 8c0da13..689949c 100644
--- a/dev/Tasks.txt
+++ b/dev/Tasks.txt
@@ -339,6 +339,7 @@
  * 411 Improved handling of multi-level cache duplicates              OK 
  * 412 Robust lineage tracing (non-recursive, parfor)                 OK
  * 413 Cache and reuse MultiReturnBuiltin instructions                OK
+ * 414 New rewrite for PCA --> lmDS pipeline                          OK
 
 SYSTEMDS-500 Documentation Webpage Reintroduction
  * 501 Make Documentation webpage framework                           OK
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
index c25f132..829ab43 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
@@ -56,6 +56,7 @@
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionParser;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
@@ -114,6 +115,8 @@
 		newInst = (newInst == null) ? rewriteAggregateCbind(curr, ec, lrwec) : newInst;
 		//A %*% B[,1:k] = (A %*% B)[,1:k];
 		newInst = (newInst == null) ? rewriteIndexingMatMul(curr, ec, lrwec) : newInst;
+		//PCA --> lmDS pipeline
+		newInst = (newInst == null) ? rewritePcaTsmm(curr, ec, lrwec) : newInst;
 		
 		if (newInst == null)
 			return false;
@@ -664,6 +667,65 @@
 			LineageCacheStatistics.incrementPRewrites();
 		return inst;
 	}
+
+	private static ArrayList<Instruction> rewritePcaTsmm(Instruction curr, ExecutionContext ec, ExecutionContext lrwec)
+	{
+		Map<String, MatrixBlock> inCache = new HashMap<>();
+		if (!isPcaTsmm(curr, ec, inCache))
+			return null;
+
+		// Create a transient read op over the last tsmm result
+		MatrixBlock cachedEntry = inCache.get("lastMatrix");
+		MatrixObject newmo = convMBtoMO(cachedEntry);
+		lrwec.setVariable("cachedEntry", newmo);
+		DataOp lastRes = HopRewriteUtils.createTransientRead("cachedEntry", cachedEntry);
+
+		// Create a transient read op over this tsmm's input 
+		MatrixObject mo = ec.getMatrixObject(((ComputationCPInstruction)curr).input1);
+		lrwec.setVariable("oldMatrix", mo);
+		DataOp newMatrix = HopRewriteUtils.createTransientRead("oldMatrix", mo);
+		
+		// Index out the added column from the projected matrix
+		MatrixBlock projected = inCache.get("projected");
+		MatrixObject projmo = convMBtoMO(projected);
+		lrwec.setVariable("projected", projmo);
+		DataOp projRes = HopRewriteUtils.createTransientRead("projected", projmo);
+		IndexingOp lastCol = HopRewriteUtils.createIndexingOp(projRes, new LiteralOp(1), new LiteralOp(projmo.getNumRows()), 
+				new LiteralOp(projmo.getNumColumns()), new LiteralOp(projmo.getNumColumns()));
+		
+		// Apply t(lastCol) on i/p matrix to get the result vectors.
+		ReorgOp tlastCol = HopRewriteUtils.createTranspose(lastCol);
+		AggBinaryOp newCol = HopRewriteUtils.createMatrixMultiply(tlastCol, newMatrix);
+		ReorgOp tnewCol = HopRewriteUtils.createTranspose(newCol);
+		
+		// Push the result row & column inside the cashed block as 2nd last row and col respectively.
+		IndexingOp topLeft = HopRewriteUtils.createIndexingOp(lastRes, new LiteralOp(1), new LiteralOp(newmo.getNumRows()-1), 
+			new LiteralOp(1), new LiteralOp(newmo.getNumColumns()-1));
+		IndexingOp topRight = HopRewriteUtils.createIndexingOp(lastRes, new LiteralOp(1), new LiteralOp(newmo.getNumRows()-1), 
+			new LiteralOp(newmo.getNumColumns()), new LiteralOp(newmo.getNumColumns()));
+		IndexingOp bottomLeft = HopRewriteUtils.createIndexingOp(lastRes, new LiteralOp(newmo.getNumRows()), 
+			new LiteralOp(newmo.getNumRows()), new LiteralOp(1), new LiteralOp(newmo.getNumColumns()-1));
+		IndexingOp bottomRight = HopRewriteUtils.createIndexingOp(lastRes, new LiteralOp(newmo.getNumRows()), 
+			new LiteralOp(newmo.getNumRows()), new LiteralOp(newmo.getNumColumns()), new LiteralOp(newmo.getNumColumns()));
+		IndexingOp topCol = HopRewriteUtils.createIndexingOp(tnewCol, new LiteralOp(1), new LiteralOp(mo.getNumColumns()-2), 
+			new LiteralOp(1), new LiteralOp(1));
+		IndexingOp bottomCol = HopRewriteUtils.createIndexingOp(tnewCol, new LiteralOp(mo.getNumColumns()), 
+			new LiteralOp(mo.getNumColumns()), new LiteralOp(1), new LiteralOp(1));
+		NaryOp rowOne = HopRewriteUtils.createNary(OpOpN.CBIND, topLeft, topCol, topRight);
+		NaryOp rowTwo = HopRewriteUtils.createNary(OpOpN.CBIND, bottomLeft, bottomCol, bottomRight);
+		NaryOp lrwHop = HopRewriteUtils.createNary(OpOpN.RBIND, rowOne, newCol, rowTwo);
+		DataOp lrwWrite = HopRewriteUtils.createTransientWrite(LR_VAR, lrwHop);
+		
+		// Generate runtime instructions
+		if (LOG.isDebugEnabled())
+			LOG.debug("LINEAGE REWRITE rewritePcaTsmm APPLIED");
+		ArrayList<Instruction> inst = genInst(lrwWrite, lrwec);
+		_disableReuse = true;
+
+		if (DMLScript.STATISTICS) 
+			LineageCacheStatistics.incrementPRewrites();
+		return inst;
+	}
 	
 	/*------------------------REWRITE APPLICABILITY CHECKS-------------------------*/
 
@@ -963,6 +1025,41 @@
 		return inCache.containsKey("indexSource") ? true : false;
 	}
 
+	private static boolean isPcaTsmm(Instruction curr, ExecutionContext ec, Map<String, MatrixBlock> inCache) {
+		if (!LineageCacheConfig.isReusable(curr, ec)) {
+			return false;
+		}
+		
+		LineageItem item = ((ComputationCPInstruction) curr).getLineageItem(ec).getValue();
+		if (curr.getOpcode().equalsIgnoreCase("tsmm")) {
+			LineageItem src1 = item.getInputs()[0];
+			if (src1.getOpcode().equalsIgnoreCase("cbind")) {
+				LineageItem src21 = src1.getInputs()[0];
+				LineageItem src22 = src1.getInputs()[1]; //ones
+				if (src21.getOpcode().equalsIgnoreCase("ba+*")) {
+					if (LineageCache.probe(src21))
+						inCache.put("projected", LineageCache.getMatrix(src21));
+				
+					LineageItem src31 = src21.getInputs()[1];
+					LineageItem src32 = src21.getInputs()[0];
+					if (src31.getOpcode().equalsIgnoreCase("rightIndex")) {
+						LineageItem cu = src31.getInputs()[4];
+						//TODO: delta with more than one column
+						LineageItem old_cu = reduceColByOne(cu);
+						LineageItem old_RI = new LineageItem("rightIndex", new LineageItem[] {src31.getInputs()[0], 
+								src31.getInputs()[1], src31.getInputs()[2], src31.getInputs()[3], old_cu});
+						LineageItem old_ba = new LineageItem("ba+*", new LineageItem[] {src32, old_RI});
+						LineageItem old_cbind = new LineageItem("cbind", new LineageItem[] {old_ba, src22});
+						LineageItem old_tsmm = new LineageItem("tsmm", new LineageItem[] {old_cbind});
+						if (LineageCache.probe(old_tsmm))
+							inCache.put("lastMatrix", LineageCache.getMatrix(old_tsmm));
+					}
+				}
+			}
+		}
+		return inCache.containsKey("projected") && inCache.containsKey("lastMatrix");
+	}
+
 	/*----------------------INSTRUCTIONS GENERATION & EXECUTION-----------------------*/
 
 	private static ArrayList<Instruction> genInst(Hop hops, ExecutionContext ec) {
@@ -1006,6 +1103,15 @@
 		mo.release();
 		return mo;
 	}
+	
+	private static LineageItem reduceColByOne(LineageItem cu) {
+		String data = cu.getData();  //xx·SCALAR·INT64·true
+		String[] parts = data.split(Instruction.VALUETYPE_PREFIX);
+		int cuNum = Integer.valueOf(parts[0]);
+		parts[0] = String.valueOf(cuNum-1);
+		String old_data = InstructionUtils.concatOperandParts(parts);
+		return(new LineageItem(old_data));
+	}
 
 	private static ExecutionContext getExecutionContext() {
 		if( _lrEC == null )
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
index 74b055e..c7a32b0 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
@@ -39,7 +39,7 @@
 	
 	protected static final String TEST_DIR = "functions/lineage/";
 	protected static final String TEST_NAME = "LineageReuseAlg";
-	protected static final int TEST_VARIANTS = 5;
+	protected static final int TEST_VARIANTS = 6;
 	protected String TEST_CLASS_DIR = TEST_DIR + LineageReuseAlg.class.getSimpleName() + "/";
 	
 	@Override
@@ -73,6 +73,11 @@
 	public void testGridSearchL2svmHybrid() {
 		testLineageTrace(TEST_NAME+"5", ReuseCacheType.REUSE_HYBRID);
 	}
+
+	@Test
+	public void testPCA_LM_pipeline() {
+		testLineageTrace(TEST_NAME+"6", ReuseCacheType.REUSE_HYBRID);
+	}
 	
 	@Test
 	public void testStepLMFull() {
diff --git a/src/test/scripts/functions/lineage/LineageReuseAlg6.dml b/src/test/scripts/functions/lineage/LineageReuseAlg6.dml
new file mode 100644
index 0000000..6d0c14d
--- /dev/null
+++ b/src/test/scripts/functions/lineage/LineageReuseAlg6.dml
@@ -0,0 +1,97 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+#PCA -> LM pipeline
+
+checkR2 = function(Matrix[double] X, Matrix[double] y, Matrix[double] y_p,
+          Matrix[double] beta, Integer icpt) return (Double R2_ad)
+{
+  n = nrow(X);
+  m = ncol(X);
+  m_ext = m;
+  if (icpt == 1|icpt == 2)
+      m_ext = m+1; #due to extra column ones
+  avg_tot = sum(y)/n;
+  ss_tot = sum(y^2);
+  ss_avg_tot = ss_tot - n*avg_tot^2;
+  y_res = y - y_p;
+  avg_res = sum(y - y_p)/n;
+  ss_res = sum((y - y_p)^2);
+  R2 = 1 - ss_res/ss_avg_tot;
+  dispersion = ifelse(n>m_ext, ss_res/(n-m_ext), NaN);
+  R2_ad = ifelse(n>m_ext, 1-dispersion/(ss_avg_tot/(n-1)), NaN);
+}
+
+PCA = function(Matrix[Double] A, Integer K = ncol(A), Integer center = 1, Integer scale = 1,
+    Integer projectData = 1) return(Matrix[Double] newA)
+{
+  N = nrow(A);
+  D = ncol(A);
+
+  # perform z-scoring (centering and scaling)
+  A = scale(A, center==1, scale==1);
+
+  # co-variance matrix
+  mu = colSums(A)/N;
+  C = (t(A) %*% A)/(N-1) - (N/(N-1))*t(mu) %*% mu;
+
+  # compute eigen vectors and values
+  [evalues, evectors] = eigen(C);
+
+  decreasing_Idx = order(target=evalues,by=1,decreasing=TRUE,index.return=TRUE);
+  diagmat = table(seq(1,D),decreasing_Idx);
+  # sorts eigenvalues by decreasing order
+  evalues = diagmat %*% evalues;
+  # sorts eigenvectors column-wise in the order of decreasing eigenvalues
+  evectors = evectors %*% diagmat;
+
+
+  # select K dominant eigen vectors
+  nvec = ncol(evectors);
+
+  eval_dominant = evalues[1:K, 1];
+  evec_dominant = evectors[,1:K];
+
+  # the square root of eigenvalues
+  eval_stdev_dominant = sqrt(eval_dominant);
+
+  if (projectData == 1){
+    # Construct new data set by treating computed dominant eigenvectors as the basis vectors
+    newA = A %*% evec_dominant;
+  }
+}
+
+M = 1000;
+A = rand(rows=M, cols=100, seed=42);
+y = rand(rows=M, cols=1, seed=1);
+R = matrix(0, rows=1, cols=20);
+
+Kc = floor(ncol(A) * 0.8);
+
+for (i in 1:10) {
+  newA1 = PCA(A=A, K=Kc+i);
+  beta1 = lm(X=newA1, y=y, icpt=1, reg=0.0001, verbose=FALSE);
+  y_predict1 = lmpredict(X=newA1, w=beta1, icpt=1);
+  R2_ad1 = checkR2(newA1, y, y_predict1, beta1, 1);
+  R[,i] = R2_ad1;
+}
+
+write(R, $1, format="text");