[SYSTEMDS-419] Performance lineage-based reuse (partial rewrites)
This patch makes a minor performance improvement to the important
partial rewrite of tsmm(cbind(X,v)) into tsmm(X) plus a compensation plan:
if X is still available in the lineage cache, it is read directly from the
cache instead of being extracted via cbind(X,v)[,1:(n-1)]. This avoids
unnecessary allocations and copies.
diff --git a/dev/Tasks.txt b/dev/Tasks.txt
index 1b2e94c..3bf07c7 100644
--- a/dev/Tasks.txt
+++ b/dev/Tasks.txt
@@ -345,6 +345,7 @@
* 416 Lineage deduplication while, nested if, loop sequences OK
* 417 New rewrite for partial reuse in StepLM OK
* 418 Performance lineage tracing and reuse probing small data OK
+ * 419 Performance and robustness partial rewrites
SYSTEMDS-420 Compiler Improvements
* 421 Fix invalid IPA scalar propagation into functions OK
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 30f66f4..1412ffc 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -704,6 +704,10 @@
return createIndexingOp(input, row, row, col, col);
}
+ public static IndexingOp createIndexingOp(Hop input, long rl, long ru, long cl, long cu) {
+ return createIndexingOp(input, new LiteralOp(rl), new LiteralOp(ru), new LiteralOp(cl), new LiteralOp(cu));
+ }
+
public static IndexingOp createIndexingOp(Hop input, Hop rl, Hop ru, Hop cl, Hop cu) {
IndexingOp ix = new IndexingOp("tmp", DataType.MATRIX, ValueType.FP64, input, rl, ru, cl, cu, rl==ru, cl==cu);
ix.setBlocksize(input.getBlocksize());
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
index da1b8fc..32dedcf 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
@@ -158,25 +158,23 @@
MatrixObject mo = ec.getMatrixObject(((ComputationCPInstruction)curr).input1);
lrwec.setVariable("oldMatrix", mo);
DataOp newMatrix = HopRewriteUtils.createTransientRead("oldMatrix", mo);
- IndexingOp oldMatrix = HopRewriteUtils.createIndexingOp(newMatrix, new LiteralOp(1),
- new LiteralOp(mo.getNumRows()), new LiteralOp(1), new LiteralOp(mo.getNumColumns()-1));
- Hop lastCol;
+
+ // Use X from cache, or create rightIndex
+ Hop oldMatrix = inCache.containsKey("X") ?
+ setupTReadCachedInput("X", inCache, lrwec) :
+ HopRewriteUtils.createIndexingOp(newMatrix, 1L, mo.getNumRows(), 1L, mo.getNumColumns()-1);
+
// Use deltaX from cache, or create rightIndex
- if (inCache.containsKey("deltaX")) {
- MatrixBlock cachedRI = inCache.get("deltaX");
- lrwec.setVariable("deltaX", convMBtoMO(cachedRI));
- lastCol = HopRewriteUtils.createTransientRead("deltaX", cachedRI);
- }
- else
- lastCol = HopRewriteUtils.createIndexingOp(newMatrix, new LiteralOp(1), new LiteralOp(mo.getNumRows()),
- new LiteralOp(mo.getNumColumns()), new LiteralOp(mo.getNumColumns()));
- // cell topRight = t(oldMatrix) %*% lastCol
- ReorgOp tOldM = HopRewriteUtils.createTranspose(oldMatrix);
- AggBinaryOp topRight = HopRewriteUtils.createMatrixMultiply(tOldM, lastCol);
- // cell bottomLeft = t(lastCol) %*% oldMatrix = t(topRight)
- ReorgOp bottomLeft = HopRewriteUtils.createTranspose(topRight);
- // bottomRight = t(lastCol) %*% lastCol
+ Hop lastCol = inCache.containsKey("deltaX") ?
+ setupTReadCachedInput("deltaX", inCache, lrwec) :
+ HopRewriteUtils.createIndexingOp(newMatrix, 1L, mo.getNumRows(), mo.getNumColumns(), mo.getNumColumns());
+
+ // cell bottomLeft = t(lastCol) %*% oldMatrix
ReorgOp tLastCol = HopRewriteUtils.createTranspose(lastCol);
+ AggBinaryOp bottomLeft = HopRewriteUtils.createMatrixMultiply(tLastCol, oldMatrix);
+ // cell topRight = t(oldMatrix) %*% lastCol = t(bottomLeft)
+ ReorgOp topRight = HopRewriteUtils.createTranspose(bottomLeft);
+ // bottomRight = t(lastCol) %*% lastCol
AggBinaryOp bottomRight = HopRewriteUtils.createMatrixMultiply(tLastCol, lastCol);
// rowOne = cbind(lastRes, topRight)
BinaryOp rowOne = HopRewriteUtils.createBinary(lastRes, topRight, OpOp2.CBIND);
@@ -810,12 +808,14 @@
if (curr.getOpcode().equalsIgnoreCase("tsmm")) {
LineageItem source = item.getInputs()[0];
if (source.getOpcode().equalsIgnoreCase("cbind")) {
- //for (LineageItem input : source.getInputs()) {
// create tsmm lineage on top of the input of last append
LineageItem input1 = source.getInputs()[0];
LineageItem tmp = new LineageItem(curr.getOpcode(), new LineageItem[] {input1});
if (LineageCache.probe(tmp))
inCache.put("lastMatrix", LineageCache.getMatrix(tmp));
+ // look for the old matrix in cache
+ if( LineageCache.probe(input1) )
+ inCache.put("X", LineageCache.getMatrix(input1));
// look for the appended column in cache
if (LineageCache.probe(source.getInputs()[1]))
inCache.put("deltaX", LineageCache.getMatrix(source.getInputs()[1]));
@@ -846,6 +846,8 @@
// create tsmm lineage on top of the input of last append
LineageItem input1 = source.getInputs()[0];
LineageItem tmp = new LineageItem(curr.getOpcode(), new LineageItem[] {input1});
+ if( LineageCache.probe(input1) )
+ inCache.put("X", LineageCache.getMatrix(input1));
if (LineageCache.probe(tmp))
inCache.put("lastMatrix", LineageCache.getMatrix(tmp));
}
@@ -1178,6 +1180,12 @@
return newInst;
}
+ private static DataOp setupTReadCachedInput(String name, Map<String, MatrixBlock> inCache, ExecutionContext ec) {
+ MatrixBlock cachedRI = inCache.get(name);
+ ec.setVariable(name, convMBtoMO(cachedRI));
+ return HopRewriteUtils.createTransientRead(name, cachedRI);
+ }
+
private static void executeInst (ArrayList<Instruction> newInst, ExecutionContext lrwec)
{
// Disable explain not to print unnecessary logs