[SYSTEMDS-414] New rewrite for PCA -> lmDS pipeline
This patch adds a lineage rewrite that reuses the cached tsmm result in lmDS
when lmDS is called after PCA repeatedly with an increasing number of columns.
diff --git a/dev/Tasks.txt b/dev/Tasks.txt
index 8c0da13..689949c 100644
--- a/dev/Tasks.txt
+++ b/dev/Tasks.txt
@@ -339,6 +339,7 @@
* 411 Improved handling of multi-level cache duplicates OK
* 412 Robust lineage tracing (non-recursive, parfor) OK
* 413 Cache and reuse MultiReturnBuiltin instructions OK
+ * 414 New rewrite for PCA --> lmDS pipeline OK
SYSTEMDS-500 Documentation Webpage Reintroduction
* 501 Make Documentation webpage framework OK
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
index c25f132..829ab43 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
@@ -56,6 +56,7 @@
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionParser;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
@@ -114,6 +115,8 @@
newInst = (newInst == null) ? rewriteAggregateCbind(curr, ec, lrwec) : newInst;
//A %*% B[,1:k] = (A %*% B)[,1:k];
newInst = (newInst == null) ? rewriteIndexingMatMul(curr, ec, lrwec) : newInst;
+ //PCA --> lmDS pipeline
+ newInst = (newInst == null) ? rewritePcaTsmm(curr, ec, lrwec) : newInst;
if (newInst == null)
return false;
@@ -664,6 +667,65 @@
LineageCacheStatistics.incrementPRewrites();
return inst;
}
+
+ private static ArrayList<Instruction> rewritePcaTsmm(Instruction curr, ExecutionContext ec, ExecutionContext lrwec)
+ { // Builds instructions that extend a cached, one-column-smaller tsmm result by one row and column; returns null if isPcaTsmm finds no match.
+ Map<String, MatrixBlock> inCache = new HashMap<>();
+ if (!isPcaTsmm(curr, ec, inCache))
+ return null;
+
+ // Create a transient read op over the cached tsmm result ("lastMatrix") of the previous, one-column-smaller input
+ MatrixBlock cachedEntry = inCache.get("lastMatrix");
+ MatrixObject newmo = convMBtoMO(cachedEntry);
+ lrwec.setVariable("cachedEntry", newmo);
+ DataOp lastRes = HopRewriteUtils.createTransientRead("cachedEntry", cachedEntry);
+
+ // Create a transient read op over this tsmm's current input (note: registered under the variable name "oldMatrix")
+ MatrixObject mo = ec.getMatrixObject(((ComputationCPInstruction)curr).input1);
+ lrwec.setVariable("oldMatrix", mo);
+ DataOp newMatrix = HopRewriteUtils.createTransientRead("oldMatrix", mo);
+
+ // Index out the last (newly added) column of the cached projected matrix
+ MatrixBlock projected = inCache.get("projected");
+ MatrixObject projmo = convMBtoMO(projected);
+ lrwec.setVariable("projected", projmo);
+ DataOp projRes = HopRewriteUtils.createTransientRead("projected", projmo);
+ IndexingOp lastCol = HopRewriteUtils.createIndexingOp(projRes, new LiteralOp(1), new LiteralOp(projmo.getNumRows()),
+ new LiteralOp(projmo.getNumColumns()), new LiteralOp(projmo.getNumColumns()));
+
+ // Multiply t(lastCol) with the input matrix to get the new result row (newCol); its transpose (tnewCol) supplies the new column.
+ ReorgOp tlastCol = HopRewriteUtils.createTranspose(lastCol);
+ AggBinaryOp newCol = HopRewriteUtils.createMatrixMultiply(tlastCol, newMatrix);
+ ReorgOp tnewCol = HopRewriteUtils.createTranspose(newCol);
+
+ // Insert the new row & column into the cached block as 2nd last row and col respectively.
+ IndexingOp topLeft = HopRewriteUtils.createIndexingOp(lastRes, new LiteralOp(1), new LiteralOp(newmo.getNumRows()-1),
+ new LiteralOp(1), new LiteralOp(newmo.getNumColumns()-1));
+ IndexingOp topRight = HopRewriteUtils.createIndexingOp(lastRes, new LiteralOp(1), new LiteralOp(newmo.getNumRows()-1),
+ new LiteralOp(newmo.getNumColumns()), new LiteralOp(newmo.getNumColumns()));
+ IndexingOp bottomLeft = HopRewriteUtils.createIndexingOp(lastRes, new LiteralOp(newmo.getNumRows()),
+ new LiteralOp(newmo.getNumRows()), new LiteralOp(1), new LiteralOp(newmo.getNumColumns()-1));
+ IndexingOp bottomRight = HopRewriteUtils.createIndexingOp(lastRes, new LiteralOp(newmo.getNumRows()),
+ new LiteralOp(newmo.getNumRows()), new LiteralOp(newmo.getNumColumns()), new LiteralOp(newmo.getNumColumns()));
+ IndexingOp topCol = HopRewriteUtils.createIndexingOp(tnewCol, new LiteralOp(1), new LiteralOp(mo.getNumColumns()-2),
+ new LiteralOp(1), new LiteralOp(1)); // all entries of the new column except its last two
+ IndexingOp bottomCol = HopRewriteUtils.createIndexingOp(tnewCol, new LiteralOp(mo.getNumColumns()),
+ new LiteralOp(mo.getNumColumns()), new LiteralOp(1), new LiteralOp(1)); // last entry of the new column
+ NaryOp rowOne = HopRewriteUtils.createNary(OpOpN.CBIND, topLeft, topCol, topRight);
+ NaryOp rowTwo = HopRewriteUtils.createNary(OpOpN.CBIND, bottomLeft, bottomCol, bottomRight);
+ NaryOp lrwHop = HopRewriteUtils.createNary(OpOpN.RBIND, rowOne, newCol, rowTwo); // newCol is placed between the upper rows and the last row
+ DataOp lrwWrite = HopRewriteUtils.createTransientWrite(LR_VAR, lrwHop);
+
+ // Generate runtime instructions
+ if (LOG.isDebugEnabled())
+ LOG.debug("LINEAGE REWRITE rewritePcaTsmm APPLIED");
+ ArrayList<Instruction> inst = genInst(lrwWrite, lrwec);
+ _disableReuse = true;
+
+ if (DMLScript.STATISTICS)
+ LineageCacheStatistics.incrementPRewrites();
+ return inst;
+ }
/*------------------------REWRITE APPLICABILITY CHECKS-------------------------*/
@@ -963,6 +1025,41 @@
return inCache.containsKey("indexSource") ? true : false;
}
+ private static boolean isPcaTsmm(Instruction curr, ExecutionContext ec, Map<String, MatrixBlock> inCache) {
+ if (!LineageCacheConfig.isReusable(curr, ec)) {
+ return false;
+ }
+ // Pattern checked below: tsmm(cbind(ba+*(A, rightIndex(...)), ones)); cached intermediates are collected in inCache.
+ LineageItem item = ((ComputationCPInstruction) curr).getLineageItem(ec).getValue();
+ if (curr.getOpcode().equalsIgnoreCase("tsmm")) {
+ LineageItem src1 = item.getInputs()[0];
+ if (src1.getOpcode().equalsIgnoreCase("cbind")) {
+ LineageItem src21 = src1.getInputs()[0];
+ LineageItem src22 = src1.getInputs()[1]; //ones
+ if (src21.getOpcode().equalsIgnoreCase("ba+*")) {
+ if (LineageCache.probe(src21))
+ inCache.put("projected", LineageCache.getMatrix(src21));
+ // Rebuild the lineage of the same pipeline with the last rightIndex input (named cu; presumably the column upper bound) reduced by one
+ LineageItem src31 = src21.getInputs()[1];
+ LineageItem src32 = src21.getInputs()[0];
+ if (src31.getOpcode().equalsIgnoreCase("rightIndex")) {
+ LineageItem cu = src31.getInputs()[4];
+ //TODO: delta with more than one column
+ LineageItem old_cu = reduceColByOne(cu);
+ LineageItem old_RI = new LineageItem("rightIndex", new LineageItem[] {src31.getInputs()[0],
+ src31.getInputs()[1], src31.getInputs()[2], src31.getInputs()[3], old_cu});
+ LineageItem old_ba = new LineageItem("ba+*", new LineageItem[] {src32, old_RI});
+ LineageItem old_cbind = new LineageItem("cbind", new LineageItem[] {old_ba, src22});
+ LineageItem old_tsmm = new LineageItem("tsmm", new LineageItem[] {old_cbind});
+ if (LineageCache.probe(old_tsmm))
+ inCache.put("lastMatrix", LineageCache.getMatrix(old_tsmm));
+ }
+ }
+ }
+ }
+ return inCache.containsKey("projected") && inCache.containsKey("lastMatrix"); // both intermediates must be in inCache
+ }
+
/*----------------------INSTRUCTIONS GENERATION & EXECUTION-----------------------*/
private static ArrayList<Instruction> genInst(Hop hops, ExecutionContext ec) {
@@ -1006,6 +1103,15 @@
mo.release();
return mo;
}
+
+ private static LineageItem reduceColByOne(LineageItem cu) {
+ // Returns a new lineage item built from cu's data string with its leading numeric value reduced by one;
+ // all remaining operand parts are carried over unchanged.
+ String[] parts = cu.getData().split(Instruction.VALUETYPE_PREFIX); //xx·SCALAR·INT64·true
+ // Parse as long: the value is tagged INT64 (see format above), so a 32-bit parse could overflow.
+ parts[0] = String.valueOf(Long.parseLong(parts[0]) - 1);
+ return new LineageItem(InstructionUtils.concatOperandParts(parts));
+ }
private static ExecutionContext getExecutionContext() {
if( _lrEC == null )
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
index 74b055e..c7a32b0 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
@@ -39,7 +39,7 @@
protected static final String TEST_DIR = "functions/lineage/";
protected static final String TEST_NAME = "LineageReuseAlg";
- protected static final int TEST_VARIANTS = 5;
+ protected static final int TEST_VARIANTS = 6;
protected String TEST_CLASS_DIR = TEST_DIR + LineageReuseAlg.class.getSimpleName() + "/";
@Override
@@ -73,6 +73,11 @@
public void testGridSearchL2svmHybrid() {
testLineageTrace(TEST_NAME+"5", ReuseCacheType.REUSE_HYBRID);
}
+
+ @Test
+ public void testPCA_LM_pipeline() {
+ testLineageTrace(TEST_NAME+"6", ReuseCacheType.REUSE_HYBRID); // variant 6: PCA -> lm pipeline (presumably script LineageReuseAlg6.dml)
+ }
@Test
public void testStepLMFull() {
diff --git a/src/test/scripts/functions/lineage/LineageReuseAlg6.dml b/src/test/scripts/functions/lineage/LineageReuseAlg6.dml
new file mode 100644
index 0000000..6d0c14d
--- /dev/null
+++ b/src/test/scripts/functions/lineage/LineageReuseAlg6.dml
@@ -0,0 +1,97 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+#PCA -> LM pipeline
+
+checkR2 = function(Matrix[double] X, Matrix[double] y, Matrix[double] y_p,
+ Matrix[double] beta, Integer icpt) return (Double R2_ad)
+{ # returns the adjusted R^2 of predictions y_p against targets y (NaN if n <= m_ext)
+ n = nrow(X);
+ m = ncol(X);
+ m_ext = m;
+ if (icpt == 1|icpt == 2)
+ m_ext = m+1; #due to extra column ones
+ avg_tot = sum(y)/n; # mean of y
+ ss_tot = sum(y^2);
+ ss_avg_tot = ss_tot - n*avg_tot^2; # sum of squares of y around its mean
+ y_res = y - y_p;
+ avg_res = sum(y - y_p)/n;
+ ss_res = sum((y - y_p)^2); # residual sum of squares
+ R2 = 1 - ss_res/ss_avg_tot; # plain R^2 (computed but not returned)
+ dispersion = ifelse(n>m_ext, ss_res/(n-m_ext), NaN); # residual sum of squares per degree of freedom
+ R2_ad = ifelse(n>m_ext, 1-dispersion/(ss_avg_tot/(n-1)), NaN);
+}
+
+PCA = function(Matrix[Double] A, Integer K = ncol(A), Integer center = 1, Integer scale = 1,
+ Integer projectData = 1) return(Matrix[Double] newA)
+{ # projects the (optionally centered/scaled) data A onto its K dominant eigenvectors
+ N = nrow(A);
+ D = ncol(A);
+
+ # perform z-scoring (centering and scaling)
+ A = scale(A, center==1, scale==1);
+
+ # covariance matrix
+ mu = colSums(A)/N;
+ C = (t(A) %*% A)/(N-1) - (N/(N-1))*t(mu) %*% mu;
+
+ # compute eigenvectors and eigenvalues
+ [evalues, evectors] = eigen(C);
+
+ decreasing_Idx = order(target=evalues,by=1,decreasing=TRUE,index.return=TRUE);
+ diagmat = table(seq(1,D),decreasing_Idx);
+ # sort eigenvalues in decreasing order
+ evalues = diagmat %*% evalues;
+ # sort eigenvectors column-wise in the order of decreasing eigenvalues
+ evectors = evectors %*% diagmat;
+
+
+ # select the K dominant eigenvectors
+ nvec = ncol(evectors);
+
+ eval_dominant = evalues[1:K, 1];
+ evec_dominant = evectors[,1:K];
+
+ # the square root of the dominant eigenvalues (not used further in this function)
+ eval_stdev_dominant = sqrt(eval_dominant);
+
+ if (projectData == 1){ # NOTE: newA is only assigned when projectData == 1
+ # Construct new data set by treating computed dominant eigenvectors as the basis vectors
+ newA = A %*% evec_dominant;
+ }
+}
+
+M = 1000;
+A = rand(rows=M, cols=100, seed=42);
+y = rand(rows=M, cols=1, seed=1);
+R = matrix(0, rows=1, cols=20); # result row; the loop below writes only columns 1..10
+
+Kc = floor(ncol(A) * 0.8); # base number of components: 80% of the columns of A
+
+for (i in 1:10) { # K grows by one per iteration (Kc+1 .. Kc+10)
+ newA1 = PCA(A=A, K=Kc+i);
+ beta1 = lm(X=newA1, y=y, icpt=1, reg=0.0001, verbose=FALSE);
+ y_predict1 = lmpredict(X=newA1, w=beta1, icpt=1);
+ R2_ad1 = checkR2(newA1, y, y_predict1, beta1, 1);
+ R[,i] = R2_ad1; # adjusted R^2 of iteration i
+}
+
+write(R, $1, format="text");