| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops.rewrite; |
| |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.HashSet; |
| import java.util.List; |
| |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.sysds.common.Types.AggOp; |
| import org.apache.sysds.common.Types.OpOp1; |
| import org.apache.sysds.common.Types.OpOp2; |
| import org.apache.sysds.common.Types.OpOp3; |
| import org.apache.sysds.common.Types.OpOpData; |
| import org.apache.sysds.conf.ConfigurationManager; |
| import org.apache.sysds.hops.AggBinaryOp; |
| import org.apache.sysds.hops.FunctionOp; |
| import org.apache.sysds.hops.Hop; |
| import org.apache.sysds.hops.OptimizerUtils; |
| import org.apache.sysds.lops.Compression.CompressConfig; |
| import org.apache.sysds.parser.DMLProgram; |
| import org.apache.sysds.parser.ForStatement; |
| import org.apache.sysds.parser.ForStatementBlock; |
| import org.apache.sysds.parser.FunctionStatement; |
| import org.apache.sysds.parser.FunctionStatementBlock; |
| import org.apache.sysds.parser.IfStatement; |
| import org.apache.sysds.parser.IfStatementBlock; |
| import org.apache.sysds.parser.StatementBlock; |
| import org.apache.sysds.parser.WhileStatement; |
| import org.apache.sysds.parser.WhileStatementBlock; |
| import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; |
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| |
/**
 * Rule: Compressed reblock. If the config compressed.linalg is enabled, we inject compression directives after the
 * persistent read of matrices that have at least one column and satisfy the size constraints checked in
 * {@link #satisfiesSizeConstraintsForCompression(Hop)}.
 *
 * In case of 'auto' compression, we apply compression if the data size is known to exceed aggregate cluster memory, the
 * matrix is used in loops, and all operations are supported over compressed matrices.
 */
| public class RewriteCompressedReblock extends StatementBlockRewriteRule { |
| private static final Log LOG = LogFactory.getLog(RewriteCompressedReblock.class.getName()); |
| |
| private static final String TMP_PREFIX = "__cmtx"; |
| |
| @Override |
| public boolean createsSplitDag() { |
| return false; |
| } |
| |
| @Override |
| public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus sate) { |
| // check for inapplicable statement blocks |
| if(!HopRewriteUtils.isLastLevelStatementBlock(sb) || sb.getHops() == null) |
| return Arrays.asList(sb); |
| |
| // parse compression config |
| final CompressConfig compress = ConfigurationManager.getCompressConfig(); |
| |
| // perform compressed reblock rewrite |
| if(compress.isEnabled()) { |
| Hop.resetVisitStatus(sb.getHops()); |
| for(Hop h : sb.getHops()) |
| injectCompressionDirective(h, compress, sb.getDMLProg()); |
| Hop.resetVisitStatus(sb.getHops()); |
| } |
| return Arrays.asList(sb); |
| } |
| |
| @Override |
| public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate) { |
| return sbs; |
| } |
| |
| private static void injectCompressionDirective(Hop hop, CompressConfig compress, DMLProgram prog) { |
| if(hop.isVisited() || hop.requiresCompression() || hop.hasCompressedInput()) |
| return; |
| |
| // recursively process children |
| for(Hop hi : hop.getInput()) |
| injectCompressionDirective(hi, compress, prog); |
| |
| // check for compression conditions |
| switch(compress) { |
| case TRUE: |
| if(satisfiesCompressionCondition(hop)) |
| hop.setRequiresCompression(); |
| break; |
| case AUTO: |
| if(OptimizerUtils.isSparkExecutionMode() && satisfiesAutoCompressionCondition(hop, prog)) |
| hop.setRequiresCompression(); |
| break; |
| case COST: |
| if(satisfiesCostCompressionCondition(hop, prog)) |
| hop.setRequiresCompression(); |
| break; |
| default: |
| break; |
| } |
| |
| if(satisfiesDeCompressionCondition(hop)) { |
| hop.setRequiresDeCompression(); |
| } |
| |
| hop.setVisited(); |
| } |
| |
| public static boolean satisfiesSizeConstraintsForCompression(Hop hop) { |
| if(hop.getDim2() >= 1) { |
| final long x = hop.getDim1(); |
| final long y = hop.getDim2(); |
| return |
| // If the Cube of the number of rows is greater than multiplying the number of columns by 1024. |
| y << 10 <= x * x |
| // is very sparse and at least 100 rows. |
| || (hop.getSparsity() < 0.0001 && y > 100); |
| } |
| return false; |
| } |
| |
| public static boolean satisfiesCompressionCondition(Hop hop) { |
| boolean satisfies = false; |
| if(satisfiesSizeConstraintsForCompression(hop)) |
| satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD); |
| |
| return satisfies; |
| } |
| |
| public static boolean satisfiesAggressiveCompressionCondition(Hop hop) { |
| //size-independent conditions (robust against unknowns) |
| boolean satisfies = HopRewriteUtils.isTernary(hop, OpOp3.CTABLE) //matrix (no vector) ctable |
| && hop.getInput(0).getDataType().isMatrix() && hop.getInput(1).getDataType().isMatrix(); |
| //size-dependent conditions |
| if(satisfiesSizeConstraintsForCompression(hop)) { |
| satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD); |
| satisfies |= HopRewriteUtils.isUnary(hop, OpOp1.ROUND, OpOp1.FLOOR, OpOp1.NOT, OpOp1.CEIL); |
| satisfies |= HopRewriteUtils.isBinary(hop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS, |
| OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.AND, OpOp2.OR, OpOp2.MODULUS); |
| satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE); |
| } |
| if(LOG.isDebugEnabled() && satisfies) |
| LOG.debug("Operation Satisfies: " + hop); |
| return satisfies; |
| } |
| |
| private static boolean satisfiesDeCompressionCondition(Hop hop) { |
| // TODO decompression Condition |
| return false; |
| } |
| |
| private static boolean outOfCore(Hop hop) { |
| double matrixPSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(hop); |
| double cacheSize = SparkExecutionContext.getDataMemoryBudget(true, true); |
| return matrixPSize > cacheSize; |
| } |
| |
| private static boolean ultraSparse(Hop hop) { |
| double sparsity = OptimizerUtils.getSparsity(hop); |
| return sparsity < MatrixBlock.ULTRA_SPARSITY_TURN_POINT; |
| } |
| |
| private static boolean satisfiesAutoCompressionCondition(Hop hop, DMLProgram prog) { |
| // check for basic compression condition |
| if(!(satisfiesCompressionCondition(hop) && hop.getMemEstimate() >= OptimizerUtils.getLocalMemBudget())) |
| return false; |
| |
| // determine if all operations are supported over compressed matrices, |
| // but conditionally only if all other conditions are met |
| if(hop.dimsKnown(true) && outOfCore(hop) && !ultraSparse(hop)) { |
| return analyseProgram(hop, prog).isValidAutoCompression(); |
| } |
| |
| return false; |
| } |
| |
| private static boolean satisfiesCostCompressionCondition(Hop hop, DMLProgram prog) { |
| boolean satisfies = true; |
| satisfies &= satisfiesAggressiveCompressionCondition(hop); |
| satisfies &= hop.dimsKnown(false); |
| satisfies &= analyseProgram(hop, prog).isValidAggressiveCompression(); |
| return satisfies; |
| |
| } |
| |
| private static ProbeStatus analyseProgram(Hop hop, DMLProgram prog) { |
| ProbeStatus status = new ProbeStatus(hop.getHopID(), prog); |
| for(StatementBlock sb : prog.getStatementBlocks()) |
| status.rAnalyzeProgram(sb); |
| return status; |
| } |
| |
| private static class ProbeStatus { |
| private final long startHopID; |
| private final DMLProgram prog; |
| |
| private int numberCompressedOpsExecuted = 0; |
| private int numberDecompressedOpsExecuted = 0; |
| private int inefficientSupportedOpsExecuted = 0; |
| // private int superEfficientSupportedOpsExecuted = 0; |
| |
| private boolean foundStart = false; |
| private boolean usedInLoop = false; |
| private boolean condUpdate = false; |
| private boolean nonApplicable = false; |
| |
| private HashSet<String> procFn = new HashSet<>(); |
| private HashSet<String> compMtx = new HashSet<>(); |
| |
/**
 * Creates a fresh probe status with all flags cleared and all counters at zero.
 *
 * @param hopID the ID of the hop at which the compressed data flow starts
 * @param p     the program that is analysed
 */
private ProbeStatus(long hopID, DMLProgram p) {
	this.startHopID = hopID;
	this.prog = p;
}
| |
/**
 * Creates a probe status for a nested analysis from an existing one.
 *
 * The start hop ID, the program, the boolean flags and the set of already processed functions are taken over.
 * The counters and the set of compressed matrix names start out empty.
 *
 * @param status the probe status to derive from
 */
private ProbeStatus(ProbeStatus status) {
	this(status.startHopID, status.prog);
	this.foundStart = status.foundStart;
	this.usedInLoop = status.usedInLoop;
	this.condUpdate = status.condUpdate;
	this.nonApplicable = status.nonApplicable;
	this.procFn.addAll(status.procFn);
}
| |
| private void rAnalyzeProgram(StatementBlock sb) { |
| if(sb instanceof FunctionStatementBlock) { |
| FunctionStatementBlock fsb = (FunctionStatementBlock) sb; |
| FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); |
| for(StatementBlock csb : fstmt.getBody()) |
| rAnalyzeProgram(csb); |
| } |
| else if(sb instanceof WhileStatementBlock) { |
| WhileStatementBlock wsb = (WhileStatementBlock) sb; |
| WhileStatement wstmt = (WhileStatement) wsb.getStatement(0); |
| for(StatementBlock csb : wstmt.getBody()) |
| rAnalyzeProgram(csb); |
| if(wsb.variablesRead().containsAnyName(compMtx)) |
| usedInLoop = true; |
| } |
| else if(sb instanceof IfStatementBlock) { |
| IfStatementBlock isb = (IfStatementBlock) sb; |
| IfStatement istmt = (IfStatement) isb.getStatement(0); |
| for(StatementBlock csb : istmt.getIfBody()) |
| rAnalyzeProgram(csb); |
| for(StatementBlock csb : istmt.getElseBody()) |
| rAnalyzeProgram(csb); |
| if(isb.variablesUpdated().containsAnyName(compMtx)) |
| condUpdate = true; |
| } |
| else if(sb instanceof ForStatementBlock) { // incl parfor |
| ForStatementBlock fsb = (ForStatementBlock) sb; |
| ForStatement fstmt = (ForStatement) fsb.getStatement(0); |
| for(StatementBlock csb : fstmt.getBody()) |
| rAnalyzeProgram(csb); |
| if(fsb.variablesRead().containsAnyName(compMtx)) |
| usedInLoop = true; |
| } |
| else if(sb.getHops() != null) { // generic (last-level) |
| ArrayList<Hop> roots = sb.getHops(); |
| Hop.resetVisitStatus(roots); |
| // process entire HOP DAG starting from the roots |
| for(Hop root : roots) |
| rAnalyzeHopDag(root); |
| // remove temporary variables |
| compMtx.removeIf(n -> n.startsWith(TMP_PREFIX)); |
| Hop.resetVisitStatus(roots); |
| } |
| } |
| |
| private void rAnalyzeHopDag(Hop current) { |
| if(current.isVisited()) |
| return; |
| |
| // process children recursively |
| for(Hop input : current.getInput()) |
| rAnalyzeHopDag(input); |
| |
| // handle source persistent read |
| if(current.getHopID() == startHopID) { |
| compMtx.add(getTmpName(current)); |
| foundStart = true; |
| } |
| |
| // 1) handle transient reads and writes (name mapping) |
| if(HopRewriteUtils.isData(current, OpOpData.TRANSIENTWRITE) && |
| compMtx.contains(getTmpName(current.getInput().get(0)))) |
| compMtx.add(current.getName()); |
| else if(HopRewriteUtils.isData(current, OpOpData.TRANSIENTREAD) && compMtx.contains(current.getName())) |
| compMtx.add(getTmpName(current)); |
| // handle individual hops |
| else if(hasCompressedInput(current)) { |
| if(current instanceof FunctionOp) |
| handleFunctionOps(current); |
| else |
| handleApplicableOps(current); |
| } |
| current.setVisited(); |
| } |
| |
| private boolean hasCompressedInput(Hop hop) { |
| if(compMtx.isEmpty()) |
| return false; |
| for(Hop input : hop.getInput()) |
| if(compMtx.contains(getTmpName(input))) |
| return true; |
| return false; |
| } |
| |
| private static String getTmpName(Hop hop) { |
| return TMP_PREFIX + hop.getHopID(); |
| } |
| |
| private boolean isCompressed(Hop hop) { |
| return compMtx.contains(getTmpName(hop)); |
| } |
| |
| private void handleFunctionOps(Hop current) { |
| // TODO handle of functions in a more fine-grained manner |
| // to cover special cases multiple calls where compressed |
| // inputs might occur for different input parameters |
| |
| FunctionOp fop = (FunctionOp) current; |
| String fkey = fop.getFunctionKey(); |
| if(!procFn.contains(fkey)) { |
| // memoization to avoid redundant analysis and recursive calls |
| procFn.add(fkey); |
| // map inputs to function inputs |
| FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey); |
| FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); |
| ProbeStatus status2 = new ProbeStatus(this); |
| for(int i = 0; i < fop.getInput().size(); i++) |
| if(compMtx.contains(getTmpName(fop.getInput().get(i)))) |
| status2.compMtx.add(fstmt.getInputParams().get(i).getName()); |
| // analyze function and merge meta info |
| status2.rAnalyzeProgram(fsb); |
| foundStart |= status2.foundStart; |
| usedInLoop |= status2.usedInLoop; |
| condUpdate |= status2.condUpdate; |
| nonApplicable |= status2.nonApplicable; |
| numberCompressedOpsExecuted += status2.numberCompressedOpsExecuted; |
| numberDecompressedOpsExecuted += status2.numberDecompressedOpsExecuted; |
| // map function outputs to outputs |
| String[] outputs = fop.getOutputVariableNames(); |
| for(int i = 0; i < outputs.length; i++) |
| if(status2.compMtx.contains(fstmt.getOutputParams().get(i).getName())) |
| compMtx.add(outputs[i]); |
| } |
| } |
| |
| private void handleApplicableOps(Hop current) { |
| // Valid with uncompressed outputs |
| boolean compUCOut = false; |
| // // tsmm |
| // compUCOut |= (current instanceof AggBinaryOp && current.getDim2() <= current.getBlocksize() && |
| // ((AggBinaryOp) current).checkTransposeSelf() == MMTSJType.LEFT); |
| |
| // // mvmm |
| // compUCOut |= (current instanceof AggBinaryOp && (current.getDim1() == 1 || current.getDim2() == 1)); |
| // compUCOut |= (HopRewriteUtils.isTransposeOperation(current) && current.getParent().size() == 1 && |
| // current.getParent().get(0) instanceof AggBinaryOp && |
| // (current.getParent().get(0).getDim1() == 1 || current.getParent().get(0).getDim2() == 1)); |
| |
| compUCOut |= (current instanceof AggBinaryOp); |
| compUCOut |= HopRewriteUtils.isBinaryMatrixColVectorOperation(current); |
| |
| boolean isAggregate = HopRewriteUtils |
| .isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN); |
| |
| // If the aggregation function is done row wise. |
| if(isAggregate && current.getDim2() < 2 && current.getDim1() >= 1000) |
| inefficientSupportedOpsExecuted++; |
| |
| compUCOut |= isAggregate; |
| |
| // Valid compressed |
| boolean compCOut = false; |
| |
| // Compressed Output if the operation is Binary scalar |
| compCOut |= HopRewriteUtils.isBinaryMatrixScalarOperation(current); |
| compCOut |= HopRewriteUtils.isBinaryMatrixRowVectorOperation(current); |
| |
| // Compressed Output possible through overlapping matrix.if the operation is right Matrix Multiply |
| compCOut |= (current instanceof AggBinaryOp) && isCompressed(current.getInput().get(0)); |
| compUCOut = compCOut ? false : compUCOut; |
| |
| // Compressed Output if the operation is column bind. |
| compCOut |= HopRewriteUtils.isBinary(current, OpOp2.CBIND); |
| |
| boolean metaOp = HopRewriteUtils.isUnary(current, OpOp1.NROW, OpOp1.NCOL); |
| boolean ctableOp = HopRewriteUtils.isTernary(current, OpOp3.CTABLE); |
| |
| if(ctableOp) { |
| numberCompressedOpsExecuted += 4; |
| compCOut = true; |
| } |
| |
| boolean applicable = compUCOut || compCOut || metaOp; |
| |
| if(applicable) |
| numberCompressedOpsExecuted++; |
| else { |
| LOG.warn("Decompession op: " + current); |
| numberDecompressedOpsExecuted++; |
| } |
| |
| nonApplicable |= !(applicable); |
| |
| if(compCOut) |
| compMtx.add(getTmpName(current)); |
| } |
| |
| private boolean isValidAutoCompression() { |
| return foundStart && usedInLoop && !condUpdate && !nonApplicable; |
| } |
| |
| private boolean isValidAggressiveCompression() { |
| if(LOG.isDebugEnabled()) |
| LOG.debug(this.toString()); |
| return (inefficientSupportedOpsExecuted < numberCompressedOpsExecuted) && |
| (usedInLoop || numberCompressedOpsExecuted > 3) && numberDecompressedOpsExecuted < 1; |
| } |
| |
| @Override |
| public String toString() { |
| StringBuilder sb = new StringBuilder(); |
| sb.append("Compressed ProbeStatus : hopID =" + startHopID); |
| sb.append("\n CLA Ops : " + numberCompressedOpsExecuted); |
| sb.append("\n Decompress Ops : " + numberDecompressedOpsExecuted); |
| sb.append("\n Inefficient Ops : " + inefficientSupportedOpsExecuted); |
| sb.append("\n foundStart " + foundStart + " , inLoop :" + usedInLoop + " , condUpdate : " + condUpdate |
| + " , nonApplicable : " + nonApplicable); |
| sb.append("\n compressed Matrix: " + compMtx); |
| sb.append("\n Prog Fn " + procFn); |
| return sb.toString(); |
| } |
| |
| } |
| } |