| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysml.hops.ipa; |
| |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Map.Entry; |
| import java.util.Set; |
| |
| import org.apache.commons.collections.CollectionUtils; |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.log4j.Level; |
| import org.apache.log4j.Logger; |
| import org.apache.sysml.conf.ConfigurationManager; |
| import org.apache.sysml.hops.AggUnaryOp; |
| import org.apache.sysml.hops.BinaryOp; |
| import org.apache.sysml.hops.DataGenOp; |
| import org.apache.sysml.hops.DataOp; |
| import org.apache.sysml.hops.FunctionOp; |
| import org.apache.sysml.hops.FunctionOp.FunctionType; |
| import org.apache.sysml.hops.Hop; |
| import org.apache.sysml.hops.Hop.DataGenMethod; |
| import org.apache.sysml.hops.Hop.DataOpTypes; |
| import org.apache.sysml.hops.Hop.OpOp1; |
| import org.apache.sysml.hops.Hop.OpOp2; |
| import org.apache.sysml.hops.HopsException; |
| import org.apache.sysml.hops.OptimizerUtils; |
| import org.apache.sysml.hops.Hop.VisitStatus; |
| import org.apache.sysml.hops.LiteralOp; |
| import org.apache.sysml.hops.UnaryOp; |
| import org.apache.sysml.hops.rewrite.HopRewriteUtils; |
| import org.apache.sysml.hops.recompile.Recompiler; |
| import org.apache.sysml.parser.DMLProgram; |
| import org.apache.sysml.parser.DMLTranslator; |
| import org.apache.sysml.parser.DataIdentifier; |
| import org.apache.sysml.parser.Expression.DataType; |
| import org.apache.sysml.parser.Expression.ValueType; |
| import org.apache.sysml.parser.ExternalFunctionStatement; |
| import org.apache.sysml.parser.ForStatement; |
| import org.apache.sysml.parser.ForStatementBlock; |
| import org.apache.sysml.parser.FunctionStatement; |
| import org.apache.sysml.parser.FunctionStatementBlock; |
| import org.apache.sysml.parser.IfStatement; |
| import org.apache.sysml.parser.IfStatementBlock; |
| import org.apache.sysml.parser.LanguageException; |
| import org.apache.sysml.parser.ParseException; |
| import org.apache.sysml.parser.StatementBlock; |
| import org.apache.sysml.parser.WhileStatement; |
| import org.apache.sysml.parser.WhileStatementBlock; |
| import org.apache.sysml.runtime.controlprogram.LocalVariableMap; |
| import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; |
| import org.apache.sysml.runtime.instructions.cp.BooleanObject; |
| import org.apache.sysml.runtime.instructions.cp.Data; |
| import org.apache.sysml.runtime.instructions.cp.DoubleObject; |
| import org.apache.sysml.runtime.instructions.cp.IntObject; |
| import org.apache.sysml.runtime.instructions.cp.ScalarObject; |
| import org.apache.sysml.runtime.instructions.cp.StringObject; |
| import org.apache.sysml.runtime.matrix.MatrixCharacteristics; |
| import org.apache.sysml.runtime.matrix.MatrixFormatMetaData; |
| import org.apache.sysml.udf.lib.DeNaNWrapper; |
| import org.apache.sysml.udf.lib.DeNegInfinityWrapper; |
| import org.apache.sysml.udf.lib.DynamicReadMatrixCP; |
| import org.apache.sysml.udf.lib.DynamicReadMatrixRcCP; |
| import org.apache.sysml.udf.lib.OrderWrapper; |
| |
| /** |
| * This Inter Procedural Analysis (IPA) serves two major purposes: |
| * 1) Inter-Procedure Analysis: propagate statistics from calling program into |
| * functions and back into main program. This is done recursively for nested |
| * function invocations. |
| * 2) Intra-Procedural Analysis: propagate statistics across hop dags of subsequent |
| * statement blocks in order to allow chained function calls and reasoning about |
| * changing sparsity etc (that requires the rewritten hops dag as input). This |
| * also includes control-flow aware propagation of size and sparsity. Furthermore, |
| * it also serves as a second constant propagation pass. |
| * |
| * In general, the basic concepts of IPA are as follows and all places that deal with |
| * statistic propagation should adhere to that: |
| * * Rule 1: Exact size propagation: Since the dimension information is sometimes used |
| * for specific lops construction (e.g., in append) and rewrites, we cannot propagate worst-case |
| * estimates but only exact information; otherwise size must be unknown. |
| * * Rule 2: Dimension information and sparsity are handled separately, i.e., if an updated |
| * variable has changing sparsity but constant dimensions, its dimensions are known but |
| * sparsity unknown. |
| * |
| * More specifically, those two rules are currently realized as follows: |
| * * Statistics propagation is applied for DML-bodied functions that are invoked exactly once. |
| * This ensures that we can safely propagate exact information into this function. |
| * If ALLOW_MULTIPLE_FUNCTION_CALLS is enabled we treat multiple calls with the same sizes |
| * as one call and hence, propagate those statistics into the function as well. |
| * * Output size inference happens for DML-bodied functions that are invoked exactly once |
| * and for external functions that are known in advance (see UDFs in org.apache.sysml.udf). |
| * * Size propagation across DAGs requires control flow awareness: |
| * - Generic statement blocks: updated variables -> old stats in; new stats out |
| * - While/for statement blocks: updated variables -> old stats in/out if loop insensitive; otherwise unknown |
| * - If statement blocks: updated variables -> old stats in; new stats out if branch-insensitive |
| * |
| * |
| */ |
| @SuppressWarnings("deprecation") |
| public class InterProceduralAnalysis |
| { |
| private static final boolean LDEBUG = false; //internal local debug level |
| private static final Log LOG = LogFactory.getLog(InterProceduralAnalysis.class.getName()); |
| |
| //internal configuration parameters |
| private static final boolean INTRA_PROCEDURAL_ANALYSIS = true; //propagate statistics across statement blocks (main/functions) |
| private static final boolean PROPAGATE_KNOWN_UDF_STATISTICS = true; //propagate statistics for known external functions |
| private static final boolean ALLOW_MULTIPLE_FUNCTION_CALLS = true; //propagate consistent statistics from multiple calls |
| private static final boolean REMOVE_UNUSED_FUNCTIONS = true; //remove unused functions (inlined or never called) |
| private static final boolean FLAG_FUNCTION_RECOMPILE_ONCE = true; //flag functions which require recompilation inside a loop for full function recompile |
| private static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true; //remove unnecessary checkpoints (unconditionally overwritten intermediates) |
| private static final boolean REMOVE_CONSTANT_BINARY_OPS = true; //remove constant binary operations (e.g., X*ones, where ones=matrix(1,...)) |
| private static final boolean PROPAGATE_SCALAR_VARS_INTO_FUN = true; //propagate scalar variables into functions that are called once |
| public static boolean UNARY_DIMS_PRESERVING_FUNS = true; //determine and exploit unary dimension preserving functions |
| |
static {
	//internal debugging only: enable DEBUG output for this class
	if( LDEBUG ) {
		Logger ipaLogger = Logger.getLogger("org.apache.sysml.hops.ipa.InterProceduralAnalysis");
		ipaLogger.setLevel(Level.DEBUG);
	}
}
| |
| public InterProceduralAnalysis() { |
| //do nothing |
| } |
| |
| /** |
| * Public interface to perform IPA over a given DML program. |
| * |
| * @param dmlt |
| * @param dmlp |
| * @throws HopsException |
| * @throws ParseException |
| * @throws LanguageException |
| */ |
| @SuppressWarnings("unchecked") |
| public void analyzeProgram( DMLProgram dmlp ) |
| throws HopsException, ParseException, LanguageException |
| { |
| //step 1: get candidates for statistics propagation into functions (if required) |
| Map<String, Integer> fcandCounts = new HashMap<String, Integer>(); |
| Map<String, FunctionOp> fcandHops = new HashMap<String, FunctionOp>(); |
| Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, Set<Long>>(); |
| Set<String> allFCandKeys = new HashSet<String>(); |
| if( !dmlp.getFunctionStatementBlocks().isEmpty() ) { |
| for ( StatementBlock sb : dmlp.getStatementBlocks() ) //get candidates (over entire program) |
| getFunctionCandidatesForStatisticPropagation( sb, fcandCounts, fcandHops ); |
| allFCandKeys.addAll(fcandCounts.keySet()); //cp before pruning |
| pruneFunctionCandidatesForStatisticPropagation( fcandCounts, fcandHops ); |
| determineFunctionCandidatesNNZPropagation( fcandHops, fcandSafeNNZ ); |
| DMLTranslator.resetHopsDAGVisitStatus( dmlp ); |
| } |
| |
| //step 2: get unary dimension-preserving non-candidate functions |
| Collection<String> unaryFcandTmp = CollectionUtils.subtract(allFCandKeys, fcandCounts.keySet()); |
| HashSet<String> unaryFcands = new HashSet<String>(); |
| if( !unaryFcandTmp.isEmpty() && UNARY_DIMS_PRESERVING_FUNS ) { |
| for( String tmp : unaryFcandTmp ) |
| if( isUnarySizePreservingFunction(dmlp.getFunctionStatementBlock(tmp)) ) |
| unaryFcands.add(tmp); |
| } |
| |
| //step 3: propagate statistics and scalars into functions and across DAGs |
| if( !fcandCounts.isEmpty() || INTRA_PROCEDURAL_ANALYSIS ) { |
| //(callVars used to chain outputs/inputs of multiple functions calls) |
| LocalVariableMap callVars = new LocalVariableMap(); |
| for ( StatementBlock sb : dmlp.getStatementBlocks() ) //propagate stats into candidates |
| propagateStatisticsAcrossBlock( sb, fcandCounts, callVars, fcandSafeNNZ, unaryFcands, new HashSet<String>() ); |
| } |
| |
| //step 4: remove unused functions (e.g., inlined or never called) |
| if( REMOVE_UNUSED_FUNCTIONS ) { |
| removeUnusedFunctions( dmlp, allFCandKeys ); |
| } |
| |
| //step 5: flag functions with loops for 'recompile-on-entry' |
| if( FLAG_FUNCTION_RECOMPILE_ONCE ) { |
| flagFunctionsForRecompileOnce( dmlp ); |
| } |
| |
| //step 6: set global data flow properties |
| if( REMOVE_UNNECESSARY_CHECKPOINTS |
| && OptimizerUtils.isSparkExecutionMode() ) |
| { |
| //remove unnecessary checkpoint before update |
| removeCheckpointBeforeUpdate(dmlp); |
| |
| //move necessary checkpoint after update |
| moveCheckpointAfterUpdate(dmlp); |
| |
| //remove unnecessary checkpoint read-{write|uagg} |
| removeCheckpointReadWrite(dmlp); |
| } |
| |
| //step 7: remove constant binary ops |
| if( REMOVE_CONSTANT_BINARY_OPS ) { |
| removeConstantBinaryOps(dmlp); |
| } |
| |
| //TODO evaluate potential of SECOND_CHANCE |
| //(consistent call stats after first IPA pass and hence additional potential) |
| } |
| |
| /** |
| * |
| * @param sb |
| * @return |
| * @throws ParseException |
| * @throws HopsException |
| */ |
| public Set<String> analyzeSubProgram( StatementBlock sb ) |
| throws HopsException, ParseException |
| { |
| DMLTranslator.resetHopsDAGVisitStatus(sb); |
| |
| //step 1: get candidates for statistics propagation into functions (if required) |
| Map<String, Integer> fcandCounts = new HashMap<String, Integer>(); |
| Map<String, FunctionOp> fcandHops = new HashMap<String, FunctionOp>(); |
| Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, Set<Long>>(); |
| Set<String> allFCandKeys = new HashSet<String>(); |
| getFunctionCandidatesForStatisticPropagation( sb, fcandCounts, fcandHops ); |
| allFCandKeys.addAll(fcandCounts.keySet()); //cp before pruning |
| pruneFunctionCandidatesForStatisticPropagation( fcandCounts, fcandHops ); |
| determineFunctionCandidatesNNZPropagation( fcandHops, fcandSafeNNZ ); |
| DMLTranslator.resetHopsDAGVisitStatus( sb ); |
| |
| if( !fcandCounts.isEmpty() ) { |
| //step 2: propagate statistics into functions and across DAGs |
| //(callVars used to chain outputs/inputs of multiple functions calls) |
| LocalVariableMap callVars = new LocalVariableMap(); |
| propagateStatisticsAcrossBlock( sb, fcandCounts, callVars, fcandSafeNNZ, new HashSet<String>(), new HashSet<String>() ); |
| } |
| |
| return fcandCounts.keySet(); |
| } |
| |
| |
| ///////////////////////////// |
| // GET FUNCTION CANDIDATES |
| ////// |
| |
| /** |
| * |
| * @param sb |
| * @param fcand |
| * @throws HopsException |
| * @throws ParseException |
| */ |
| private void getFunctionCandidatesForStatisticPropagation( StatementBlock sb, Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops ) |
| throws HopsException, ParseException |
| { |
| if (sb instanceof FunctionStatementBlock) |
| { |
| FunctionStatementBlock fsb = (FunctionStatementBlock)sb; |
| FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); |
| for (StatementBlock sbi : fstmt.getBody()) |
| getFunctionCandidatesForStatisticPropagation(sbi, fcandCounts, fcandHops); |
| } |
| else if (sb instanceof WhileStatementBlock) |
| { |
| WhileStatementBlock wsb = (WhileStatementBlock) sb; |
| WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); |
| for (StatementBlock sbi : wstmt.getBody()) |
| getFunctionCandidatesForStatisticPropagation(sbi, fcandCounts, fcandHops); |
| } |
| else if (sb instanceof IfStatementBlock) |
| { |
| IfStatementBlock isb = (IfStatementBlock) sb; |
| IfStatement istmt = (IfStatement)isb.getStatement(0); |
| for (StatementBlock sbi : istmt.getIfBody()) |
| getFunctionCandidatesForStatisticPropagation(sbi, fcandCounts, fcandHops); |
| for (StatementBlock sbi : istmt.getElseBody()) |
| getFunctionCandidatesForStatisticPropagation(sbi, fcandCounts, fcandHops); |
| } |
| else if (sb instanceof ForStatementBlock) //incl parfor |
| { |
| ForStatementBlock fsb = (ForStatementBlock) sb; |
| ForStatement fstmt = (ForStatement)fsb.getStatement(0); |
| for (StatementBlock sbi : fstmt.getBody()) |
| getFunctionCandidatesForStatisticPropagation(sbi, fcandCounts, fcandHops); |
| } |
| else //generic (last-level) |
| { |
| ArrayList<Hop> roots = sb.get_hops(); |
| if( roots != null ) //empty statement blocks |
| for( Hop root : roots ) |
| getFunctionCandidatesForStatisticPropagation(sb.getDMLProg(), root, fcandCounts, fcandHops); |
| } |
| } |
| |
| /** |
| * |
| * @param prog |
| * @param hop |
| * @param fcand |
| * @throws HopsException |
| * @throws ParseException |
| */ |
| private void getFunctionCandidatesForStatisticPropagation(DMLProgram prog, Hop hop, Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops ) |
| throws HopsException, ParseException |
| { |
| if( hop.getVisited() == VisitStatus.DONE ) |
| return; |
| |
| if( hop instanceof FunctionOp && !((FunctionOp)hop).getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE) ) |
| { |
| //maintain counters and investigate functions if not seen so far |
| FunctionOp fop = (FunctionOp) hop; |
| String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); |
| |
| if( fcandCounts.containsKey(fkey) ) { |
| if( ALLOW_MULTIPLE_FUNCTION_CALLS ) |
| { |
| //compare input matrix characteristics for both function calls |
| //(if unknown or difference: maintain counter - this function is no candidate) |
| boolean consistent = true; |
| FunctionOp efop = fcandHops.get(fkey); |
| int numInputs = efop.getInput().size(); |
| for( int i=0; i<numInputs; i++ ) |
| { |
| Hop h1 = efop.getInput().get(i); |
| Hop h2 = fop.getInput().get(i); |
| //check matrix and scalar sizes (if known dims, nnz known/unknown, |
| // safeness of nnz propagation, determined later per input) |
| consistent &= (h1.dimsKnown() && h2.dimsKnown() |
| && h1.getDim1()==h2.getDim1() |
| && h1.getDim2()==h2.getDim2() |
| && h1.getNnz()==h2.getNnz() ); |
| //check literal values (equi value) |
| if( h1 instanceof LiteralOp ){ |
| consistent &= (h2 instanceof LiteralOp |
| && HopRewriteUtils.isEqualValue((LiteralOp)h1, (LiteralOp)h2)); |
| } |
| |
| |
| } |
| |
| if( !consistent ) //if differences, do not propagate |
| fcandCounts.put(fkey, fcandCounts.get(fkey)+1); |
| } |
| else |
| { |
| //maintain counter (this function is no candidate) |
| fcandCounts.put(fkey, fcandCounts.get(fkey)+1); |
| } |
| } |
| else { //first appearance |
| fcandCounts.put(fkey, 1); //create a new count entry |
| fcandHops.put(fkey, fop); //keep the function call hop |
| FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName()); |
| getFunctionCandidatesForStatisticPropagation(fsb, fcandCounts, fcandHops); |
| } |
| } |
| |
| for( Hop c : hop.getInput() ) |
| getFunctionCandidatesForStatisticPropagation(prog, c, fcandCounts, fcandHops); |
| |
| hop.setVisited(VisitStatus.DONE); |
| } |
| |
| /** |
| * |
| * @param fcand |
| */ |
| private void pruneFunctionCandidatesForStatisticPropagation(Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops) |
| { |
| //debug input |
| if( LOG.isDebugEnabled() ) |
| for( Entry<String,Integer> e : fcandCounts.entrySet() ) |
| { |
| String key = e.getKey(); |
| Integer count = e.getValue(); |
| LOG.debug("IPA: FUNC statistic propagation candidate: "+key+", callCount="+count); |
| } |
| |
| //materialize key set |
| Set<String> tmp = new HashSet<String>(fcandCounts.keySet()); |
| |
| //check and prune candidate list |
| for( String key : tmp ) |
| { |
| Integer cnt = fcandCounts.get(key); |
| if( cnt != null && cnt > 1 ) //if multiple refs |
| fcandCounts.remove(key); |
| } |
| |
| //debug output |
| if( LOG.isDebugEnabled() ) |
| for( String key : fcandCounts.keySet() ) |
| { |
| LOG.debug("IPA: FUNC statistic propagation candidate (after pruning): "+key); |
| } |
| } |
| |
| /** |
| * |
| * @param fsb |
| * @return |
| * @throws HopsException |
| * @throws ParseException |
| */ |
| private boolean isUnarySizePreservingFunction(FunctionStatementBlock fsb) |
| throws HopsException, ParseException |
| { |
| FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); |
| |
| //check unary functions over matrices |
| boolean ret = (fstmt.getInputParams().size() == 1 |
| && fstmt.getInputParams().get(0).getDataType()==DataType.MATRIX |
| && fstmt.getOutputParams().size() == 1 |
| && fstmt.getOutputParams().get(0).getDataType()==DataType.MATRIX); |
| |
| //check size-preserving characteristic |
| if( ret ) { |
| HashMap<String, Integer> tmp1 = new HashMap<String,Integer>(); |
| HashMap<String, Set<Long>> tmp2 = new HashMap<String, Set<Long>>(); |
| HashSet<String> tmp3 = new HashSet<String>(); |
| HashSet<String> tmp4 = new HashSet<String>(); |
| LocalVariableMap callVars = new LocalVariableMap(); |
| |
| //populate input |
| MatrixObject mo = createOutputMatrix(7777, 3333, -1); |
| callVars.put(fstmt.getInputParams().get(0).getName(), mo); |
| |
| //propagate statistics |
| for (StatementBlock sbi : fstmt.getBody()) |
| propagateStatisticsAcrossBlock(sbi, tmp1, callVars, tmp2, tmp3, tmp4); |
| |
| //compare output |
| MatrixObject mo2 = (MatrixObject)callVars.get(fstmt.getOutputParams().get(0).getName()); |
| ret &= mo.getNumRows() == mo2.getNumRows() && mo.getNumColumns() == mo2.getNumColumns(); |
| |
| //reset function |
| mo.getMatrixCharacteristics().setDimension(-1, -1); |
| for (StatementBlock sbi : fstmt.getBody()) |
| propagateStatisticsAcrossBlock(sbi, tmp1, callVars, tmp2, tmp3, tmp4); |
| } |
| |
| return ret; |
| } |
| |
| ///////////////////////////// |
| // DETERMINE NNZ PROPAGATE SAFENESS |
| ////// |
| |
| /** |
| * Populates fcandSafeNNZ with all <functionKey,hopID> pairs where it is safe to |
| * propagate nnz into the function. |
| * |
| * @param fcandHops |
| * @param fcandSafeNNZ |
| */ |
| private void determineFunctionCandidatesNNZPropagation(Map<String, FunctionOp> fcandHops, Map<String, Set<Long>> fcandSafeNNZ) |
| { |
| //for all function candidates |
| for( Entry<String, FunctionOp> e : fcandHops.entrySet() ) |
| { |
| String fKey = e.getKey(); |
| FunctionOp fop = e.getValue(); |
| HashSet<Long> tmp = new HashSet<Long>(); |
| |
| //for all inputs of this function call |
| for( Hop input : fop.getInput() ) |
| { |
| //if nnz known it is safe to propagate those nnz because for multiple calls |
| //we checked of equivalence and hence all calls have the same nnz |
| if( input.getNnz()>=0 ) |
| tmp.add(input.getHopID()); |
| } |
| |
| fcandSafeNNZ.put(fKey, tmp); |
| } |
| } |
| |
| ///////////////////////////// |
| // INTRA-PROCEDURE ANALYSIS |
| ////// |
| |
| /** |
| * |
| * @param sb |
| * @param fcand |
| * @throws HopsException |
| * @throws ParseException |
| * @throws CloneNotSupportedException |
| */ |
| private void propagateStatisticsAcrossBlock( StatementBlock sb, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) |
| throws HopsException, ParseException |
| { |
| if (sb instanceof FunctionStatementBlock) |
| { |
| FunctionStatementBlock fsb = (FunctionStatementBlock)sb; |
| FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); |
| for (StatementBlock sbi : fstmt.getBody()) |
| propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); |
| } |
| else if (sb instanceof WhileStatementBlock) |
| { |
| WhileStatementBlock wsb = (WhileStatementBlock) sb; |
| WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); |
| //old stats into predicate |
| propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars); |
| //remove updated constant scalars |
| Recompiler.removeUpdatedScalars(callVars, wsb); |
| //check and propagate stats into body |
| LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone(); |
| for (StatementBlock sbi : wstmt.getBody()) |
| propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); |
| if( Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, wsb) ){ //second pass if required |
| propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars); |
| for (StatementBlock sbi : wstmt.getBody()) |
| propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); |
| } |
| //remove updated constant scalars |
| Recompiler.removeUpdatedScalars(callVars, sb); |
| } |
| else if (sb instanceof IfStatementBlock) |
| { |
| IfStatementBlock isb = (IfStatementBlock) sb; |
| IfStatement istmt = (IfStatement)isb.getStatement(0); |
| //old stats into predicate |
| propagateStatisticsAcrossPredicateDAG(isb.getPredicateHops(), callVars); |
| //check and propagate stats into body |
| LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone(); |
| LocalVariableMap callVarsElse = (LocalVariableMap) callVars.clone(); |
| for (StatementBlock sbi : istmt.getIfBody()) |
| propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); |
| for (StatementBlock sbi : istmt.getElseBody()) |
| propagateStatisticsAcrossBlock(sbi, fcand, callVarsElse, fcandSafeNNZ, unaryFcands, fnStack); |
| callVars = Recompiler.reconcileUpdatedCallVarsIf(oldCallVars, callVars, callVarsElse, isb); |
| //remove updated constant scalars |
| Recompiler.removeUpdatedScalars(callVars, sb); |
| } |
| else if (sb instanceof ForStatementBlock) //incl parfor |
| { |
| ForStatementBlock fsb = (ForStatementBlock) sb; |
| ForStatement fstmt = (ForStatement)fsb.getStatement(0); |
| //old stats into predicate |
| propagateStatisticsAcrossPredicateDAG(fsb.getFromHops(), callVars); |
| propagateStatisticsAcrossPredicateDAG(fsb.getToHops(), callVars); |
| propagateStatisticsAcrossPredicateDAG(fsb.getIncrementHops(), callVars); |
| //remove updated constant scalars |
| Recompiler.removeUpdatedScalars(callVars, fsb); |
| //check and propagate stats into body |
| LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone(); |
| for (StatementBlock sbi : fstmt.getBody()) |
| propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); |
| if( Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, fsb) ) |
| for (StatementBlock sbi : fstmt.getBody()) |
| propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); |
| //remove updated constant scalars |
| Recompiler.removeUpdatedScalars(callVars, sb); |
| } |
| else //generic (last-level) |
| { |
| //remove updated constant scalars |
| Recompiler.removeUpdatedScalars(callVars, sb); |
| //old stats in, new stats out if updated |
| ArrayList<Hop> roots = sb.get_hops(); |
| DMLProgram prog = sb.getDMLProg(); |
| //refresh stats across dag |
| Hop.resetVisitStatus(roots); |
| propagateStatisticsAcrossDAG(roots, callVars); |
| //propagate stats into function calls |
| Hop.resetVisitStatus(roots); |
| propagateStatisticsIntoFunctions(prog, roots, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); |
| } |
| } |
| |
| |
| /** |
| * |
| * @param root |
| * @param vars |
| * @throws HopsException |
| */ |
| private void propagateStatisticsAcrossPredicateDAG( Hop root, LocalVariableMap vars ) |
| throws HopsException |
| { |
| if( root == null ) |
| return; |
| |
| //reset visit status because potentially called multiple times |
| root.resetVisitStatus(); |
| |
| try |
| { |
| Recompiler.rUpdateStatistics( root, vars ); |
| |
| //note: for predicates no output statistics |
| //Recompiler.extractDAGOutputStatistics(root, vars); |
| } |
| catch(Exception ex) |
| { |
| throw new HopsException("Failed to update Hop DAG statistics.", ex); |
| } |
| } |
| |
| |
| /** |
| * |
| * @param roots |
| * @param vars |
| * @throws HopsException |
| */ |
| private void propagateStatisticsAcrossDAG( ArrayList<Hop> roots, LocalVariableMap vars ) |
| throws HopsException |
| { |
| if( roots == null ) |
| return; |
| |
| try |
| { |
| //update DAG statistics from leafs to roots |
| for( Hop hop : roots ) |
| Recompiler.rUpdateStatistics( hop, vars ); |
| |
| //extract statistics from roots |
| Recompiler.extractDAGOutputStatistics(roots, vars, true); |
| } |
| catch( Exception ex ) |
| { |
| throw new HopsException("Failed to update Hop DAG statistics.", ex); |
| } |
| } |
| |
| |
| ///////////////////////////// |
| // INTER-PROCEDURE ANALYSIS |
| ////// |
| |
| |
| /** |
| * |
| * @param prog |
| * @param hop |
| * @param fcand |
| * @param callVars |
| * @throws HopsException |
| * @throws ParseException |
| */ |
| private void propagateStatisticsIntoFunctions(DMLProgram prog, ArrayList<Hop> roots, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) |
| throws HopsException, ParseException |
| { |
| for( Hop root : roots ) |
| propagateStatisticsIntoFunctions(prog, root, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); |
| } |
| |
| |
| /** |
| * |
| * @param prog |
| * @param hop |
| * @param fcand |
| * @throws HopsException |
| * @throws ParseException |
| */ |
| private void propagateStatisticsIntoFunctions(DMLProgram prog, Hop hop, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) |
| throws HopsException, ParseException |
| { |
| if( hop.getVisited() == VisitStatus.DONE ) |
| return; |
| |
| for( Hop c : hop.getInput() ) |
| propagateStatisticsIntoFunctions(prog, c, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); |
| |
| if( hop instanceof FunctionOp ) |
| { |
| //maintain counters and investigate functions if not seen so far |
| FunctionOp fop = (FunctionOp) hop; |
| String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); |
| |
| if( fop.getFunctionType() == FunctionType.DML ) |
| { |
| FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName()); |
| FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); |
| |
| if( fcand.containsKey(fkey) && |
| !fnStack.contains(fkey) ) //prevent recursion |
| { |
| //maintain function call stack |
| fnStack.add(fkey); |
| |
| //create mapping and populate symbol table for refresh |
| LocalVariableMap tmpVars = new LocalVariableMap(); |
| populateLocalVariableMapForFunctionCall( fstmt, fop, |
| callVars, tmpVars, fcandSafeNNZ.get(fkey), fcand.get(fkey) ); |
| |
| //recursively propagate statistics |
| propagateStatisticsAcrossBlock(fsb, fcand, tmpVars, fcandSafeNNZ, unaryFcands, fnStack); |
| |
| //extract vars from symbol table, re-map and refresh main program |
| extractFunctionCallReturnStatistics(fstmt, fop, tmpVars, callVars, true); |
| |
| //maintain function call stack |
| fnStack.remove(fkey); |
| } |
| else if( unaryFcands.contains(fkey) ) { |
| extractFunctionCallEquivalentReturnStatistics(fstmt, fop, callVars); |
| } |
| else { |
| extractFunctionCallUnknownReturnStatistics(fstmt, fop, callVars); |
| } |
| } |
| else if ( fop.getFunctionType() == FunctionType.EXTERNAL_FILE |
| || fop.getFunctionType() == FunctionType.EXTERNAL_MEM ) |
| { |
| //infer output size for known external functions |
| FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName()); |
| ExternalFunctionStatement fstmt = (ExternalFunctionStatement) fsb.getStatement(0); |
| if( PROPAGATE_KNOWN_UDF_STATISTICS ) |
| extractExternalFunctionCallReturnStatistics(fstmt, fop, callVars); |
| else |
| extractFunctionCallUnknownReturnStatistics(fstmt, fop, callVars); |
| } |
| } |
| |
| hop.setVisited(VisitStatus.DONE); |
| } |
| |
| |
| /** |
| * |
| * @param fstmt |
| * @param fop |
| * @param callvars |
| * @param vars |
| * @param inputSafeNNZ |
| * @param singleton |
| * @throws HopsException |
| */ |
| private void populateLocalVariableMapForFunctionCall( FunctionStatement fstmt, FunctionOp fop, LocalVariableMap callvars, LocalVariableMap vars, Set<Long> inputSafeNNZ, Integer numCalls ) |
| throws HopsException |
| { |
| ArrayList<DataIdentifier> inputVars = fstmt.getInputParams(); |
| ArrayList<Hop> inputOps = fop.getInput(); |
| |
| for( int i=0; i<inputVars.size(); i++ ) |
| { |
| //create mapping between input hops and vars |
| DataIdentifier dat = inputVars.get(i); |
| Hop input = inputOps.get(i); |
| |
| if( input.getDataType()==DataType.MATRIX ) |
| { |
| //propagate matrix characteristics |
| MatrixObject mo = new MatrixObject(ValueType.DOUBLE, null); |
| MatrixCharacteristics mc = new MatrixCharacteristics( |
| input.getDim1(), input.getDim2(), |
| ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), |
| inputSafeNNZ.contains(input.getHopID())?input.getNnz():-1 ); |
| MatrixFormatMetaData meta = new MatrixFormatMetaData(mc,null,null); |
| mo.setMetaData(meta); |
| vars.put(dat.getName(), mo); |
| } |
| else if( input.getDataType()==DataType.SCALAR ) |
| { |
| //always propagate scalar literals into functions |
| //(for multiple calls, literal equivalence already checked) |
| if( input instanceof LiteralOp ) { |
| LiteralOp lit = (LiteralOp)input; |
| ScalarObject scalar = null; |
| switch(input.getValueType()) { |
| case DOUBLE: scalar = new DoubleObject(lit.getDoubleValue()); break; |
| case INT: scalar = new IntObject(lit.getLongValue()); break; |
| case BOOLEAN: scalar = new BooleanObject(lit.getBooleanValue()); break; |
| case STRING: scalar = new StringObject(lit.getStringValue()); break; |
| default: //do nothing |
| } |
| vars.put(dat.getName(), scalar); |
| } |
| //propagate scalar variables into functions if called once |
| //and input scalar is existing variable in symbol table |
| else if( PROPAGATE_SCALAR_VARS_INTO_FUN |
| && numCalls != null && numCalls == 1 |
| && input instanceof DataOp ) |
| { |
| Data scalar = callvars.get(input.getName()); |
| if( scalar != null && scalar instanceof ScalarObject ) { |
| vars.put(dat.getName(), scalar); |
| } |
| } |
| } |
| } |
| } |
| |
| /** |
| * |
| * @param fstmt |
| * @param fop |
| * @param tmpVars |
| * @param callVars |
| * @param overwrite |
| * @throws HopsException |
| */ |
| private void extractFunctionCallReturnStatistics( FunctionStatement fstmt, FunctionOp fop, LocalVariableMap tmpVars, LocalVariableMap callVars, boolean overwrite ) |
| throws HopsException |
| { |
| ArrayList<DataIdentifier> foutputOps = fstmt.getOutputParams(); |
| String[] outputVars = fop.getOutputVariableNames(); |
| String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); |
| |
| try |
| { |
| for( int i=0; i<foutputOps.size(); i++ ) |
| { |
| DataIdentifier di = foutputOps.get(i); |
| String fvarname = di.getName(); //name in function signature |
| String pvarname = outputVars[i]; //name in calling program |
| |
| if( di.getDataType()==DataType.MATRIX && tmpVars.keySet().contains(fvarname) ) |
| { |
| MatrixObject moIn = (MatrixObject) tmpVars.get(fvarname); |
| |
| if( !callVars.keySet().contains(pvarname) || overwrite ) //not existing so far |
| { |
| MatrixObject moOut = createOutputMatrix(moIn.getNumRows(), moIn.getNumColumns(), moIn.getNnz()); |
| callVars.put(pvarname, moOut); |
| } |
| else //already existing: take largest |
| { |
| Data dat = callVars.get(pvarname); |
| if( dat instanceof MatrixObject ) |
| { |
| MatrixObject moOut = (MatrixObject)dat; |
| MatrixCharacteristics mc = moOut.getMatrixCharacteristics(); |
| if( OptimizerUtils.estimateSizeExactSparsity(mc.getRows(), mc.getCols(), (mc.getNonZeros()>0)?((double)mc.getNonZeros())/mc.getRows()/mc.getCols():1.0) |
| < OptimizerUtils.estimateSize(moIn.getNumRows(), moIn.getNumColumns()) ) |
| { |
| //update statistics if necessary |
| mc.setDimension(moIn.getNumRows(), moIn.getNumColumns()); |
| mc.setNonZeros(moIn.getNnz()); |
| } |
| } |
| |
| } |
| } |
| } |
| } |
| catch( Exception ex ) |
| { |
| throw new HopsException( "Failed to extract output statistics of function "+fkey+".", ex); |
| } |
| } |
| |
| /** |
| * |
| * @param fstmt |
| * @param fop |
| * @param callVars |
| * @throws HopsException |
| */ |
| private void extractFunctionCallUnknownReturnStatistics( FunctionStatement fstmt, FunctionOp fop, LocalVariableMap callVars ) |
| throws HopsException |
| { |
| ArrayList<DataIdentifier> foutputOps = fstmt.getOutputParams(); |
| String[] outputVars = fop.getOutputVariableNames(); |
| String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); |
| |
| try |
| { |
| for( int i=0; i<foutputOps.size(); i++ ) |
| { |
| DataIdentifier di = foutputOps.get(i); |
| String pvarname = outputVars[i]; //name in calling program |
| |
| if( di.getDataType()==DataType.MATRIX ) |
| { |
| MatrixObject moOut = createOutputMatrix(-1, -1, -1); |
| callVars.put(pvarname, moOut); |
| } |
| } |
| } |
| catch( Exception ex ) |
| { |
| throw new HopsException( "Failed to extract output statistics of function "+fkey+".", ex); |
| } |
| } |
| |
| /** |
| * |
| * @param fstmt |
| * @param fop |
| * @param callVars |
| * @throws HopsException |
| */ |
| private void extractFunctionCallEquivalentReturnStatistics( FunctionStatement fstmt, FunctionOp fop, LocalVariableMap callVars ) |
| throws HopsException |
| { |
| String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); |
| try { |
| Hop input = fop.getInput().get(0); |
| MatrixObject moOut = createOutputMatrix(input.getDim1(), input.getDim2(), -1); |
| callVars.put(fop.getOutputVariableNames()[0], moOut); |
| } |
| catch( Exception ex ) { |
| throw new HopsException( "Failed to extract output statistics for unary function "+fkey+".", ex); |
| } |
| } |
| |
| /** |
| * |
| * @param fstmt |
| * @param fop |
| * @param callVars |
| * @throws HopsException |
| */ |
| private void extractExternalFunctionCallReturnStatistics( ExternalFunctionStatement fstmt, FunctionOp fop, LocalVariableMap callVars ) |
| throws HopsException |
| { |
| String className = fstmt.getOtherParams().get(ExternalFunctionStatement.CLASS_NAME); |
| |
| if( className.equals(OrderWrapper.class.getName()) |
| || className.equals(DeNaNWrapper.class.getCanonicalName()) |
| || className.equals(DeNegInfinityWrapper.class.getCanonicalName()) ) |
| { |
| Hop input = fop.getInput().get(0); |
| long lnnz = className.equals(OrderWrapper.class.getName()) ? input.getNnz() : -1; |
| MatrixObject moOut = createOutputMatrix(input.getDim1(), input.getDim2(),lnnz); |
| callVars.put(fop.getOutputVariableNames()[0], moOut); |
| } |
| else if( className.equals("org.apache.sysml.udf.lib.EigenWrapper") ) |
| //else if( className.equals(EigenWrapper.class.getName()) ) //string ref for build flexibility |
| { |
| Hop input = fop.getInput().get(0); |
| callVars.put(fop.getOutputVariableNames()[0], createOutputMatrix(input.getDim1(), 1, -1)); |
| callVars.put(fop.getOutputVariableNames()[1], createOutputMatrix(input.getDim1(), input.getDim1(),-1)); |
| } |
| else if( className.equals("org.apache.sysml.udf.lib.LinearSolverWrapperCP") ) |
| //else if( className.equals(LinearSolverWrapperCP.class.getName()) ) //string ref for build flexibility |
| { |
| Hop input = fop.getInput().get(1); |
| callVars.put(fop.getOutputVariableNames()[0], createOutputMatrix(input.getDim1(), 1, -1)); |
| } |
| else if( className.equals(DynamicReadMatrixCP.class.getName()) |
| || className.equals(DynamicReadMatrixRcCP.class.getName()) ) |
| { |
| Hop input1 = fop.getInput().get(1); //rows |
| Hop input2 = fop.getInput().get(2); //cols |
| if( input1 instanceof LiteralOp && input2 instanceof LiteralOp ) |
| callVars.put(fop.getOutputVariableNames()[0], createOutputMatrix(((LiteralOp)input1).getLongValue(), |
| ((LiteralOp)input2).getLongValue(),-1)); |
| } |
| else |
| { |
| extractFunctionCallUnknownReturnStatistics(fstmt, fop, callVars); |
| } |
| } |
| |
| /** |
| * |
| * @param dim1 |
| * @param dim2 |
| * @param nnz |
| * @return |
| */ |
| private MatrixObject createOutputMatrix( long dim1, long dim2, long nnz ) { |
| MatrixObject moOut = new MatrixObject(ValueType.DOUBLE, null); |
| MatrixCharacteristics mc = new MatrixCharacteristics( dim1, dim2, |
| ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), nnz); |
| MatrixFormatMetaData meta = new MatrixFormatMetaData(mc,null,null); |
| moOut.setMetaData(meta); |
| |
| return moOut; |
| } |
| |
| ///////////////////////////// |
| // REMOVE UNUSED FUNCTIONS |
| ////// |
| |
| /** |
| * |
| * @param dmlp |
| * @param fcandKeys |
| * @throws LanguageException |
| */ |
| public void removeUnusedFunctions( DMLProgram dmlp, Set<String> fcandKeys ) |
| throws LanguageException |
| { |
| Set<String> fnamespaces = dmlp.getNamespaces().keySet(); |
| for( String fnspace : fnamespaces ) |
| { |
| HashMap<String, FunctionStatementBlock> fsbs = dmlp.getFunctionStatementBlocks(fnspace); |
| Iterator<Entry<String, FunctionStatementBlock>> iter = fsbs.entrySet().iterator(); |
| while( iter.hasNext() ) |
| { |
| Entry<String, FunctionStatementBlock> e = iter.next(); |
| String fname = e.getKey(); |
| String fKey = DMLProgram.constructFunctionKey(fnspace, fname); |
| //probe function candidates, remove if no candidate |
| if( !fcandKeys.contains(fKey) ) |
| iter.remove(); |
| } |
| } |
| } |
| |
| |
| ///////////////////////////// |
| // FLAG FUNCTIONS FOR RECOMPILE_ONCE |
| ////// |
| |
| /** |
| * TODO call it after construct lops |
| * |
| * @param dmlp |
| * @throws LanguageException |
| */ |
| public void flagFunctionsForRecompileOnce( DMLProgram dmlp ) |
| throws LanguageException |
| { |
| for (String namespaceKey : dmlp.getNamespaces().keySet()) |
| for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) |
| { |
| FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname); |
| if( rFlagFunctionForRecompileOnce( fsblock, false ) ) |
| { |
| fsblock.setRecompileOnce( true ); |
| LOG.debug("IPA: FUNC flagged for recompile-once: " + DMLProgram.constructFunctionKey(namespaceKey, fname)); |
| } |
| } |
| } |
| |
| /** |
| * Returns true if this statementblock requires recompilation inside a |
| * loop statement block. |
| * |
| * |
| * |
| * @param sb |
| */ |
| public boolean rFlagFunctionForRecompileOnce( StatementBlock sb, boolean inLoop ) |
| { |
| boolean ret = false; |
| |
| if (sb instanceof FunctionStatementBlock) |
| { |
| FunctionStatementBlock fsb = (FunctionStatementBlock)sb; |
| FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); |
| for( StatementBlock c : fstmt.getBody() ) |
| ret |= rFlagFunctionForRecompileOnce( c, inLoop ); |
| } |
| else if (sb instanceof WhileStatementBlock) |
| { |
| //recompilation information not available at this point |
| ret = true; |
| |
| /* |
| WhileStatementBlock wsb = (WhileStatementBlock) sb; |
| WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); |
| ret |= (inLoop && wsb.requiresPredicateRecompilation() ); |
| for( StatementBlock c : wstmt.getBody() ) |
| ret |= rFlagFunctionForRecompileOnce( c, true ); |
| */ |
| } |
| else if (sb instanceof IfStatementBlock) |
| { |
| IfStatementBlock isb = (IfStatementBlock) sb; |
| IfStatement istmt = (IfStatement)isb.getStatement(0); |
| ret |= (inLoop && isb.requiresPredicateRecompilation() ); |
| for( StatementBlock c : istmt.getIfBody() ) |
| ret |= rFlagFunctionForRecompileOnce( c, inLoop ); |
| for( StatementBlock c : istmt.getElseBody() ) |
| ret |= rFlagFunctionForRecompileOnce( c, inLoop ); |
| } |
| else if (sb instanceof ForStatementBlock) |
| { |
| //recompilation information not available at this point |
| ret = true; |
| |
| /* |
| ForStatementBlock fsb = (ForStatementBlock) sb; |
| ForStatement fstmt = (ForStatement)fsb.getStatement(0); |
| for( StatementBlock c : fstmt.getBody() ) |
| ret |= rFlagFunctionForRecompileOnce( c, true ); |
| */ |
| } |
| else |
| { |
| ret |= ( inLoop && sb.requiresRecompilation() ); |
| } |
| |
| return ret; |
| } |
| |
| ///////////////////////////// |
| // REMOVE UNNECESSARY CHECKPOINTS |
| ////// |
| |
/**
 * Removes checkpoints on top-level variables that become unnecessary because
 * the variable is updated, and checkpointed again, before being read.
 * 
 * The top-level statement blocks are scanned in program order while a map of
 * checkpoint candidates (variable name to checkpointed hop) is maintained. For
 * each statement block, a candidate is pruned in three cases:
 * (1) the block reads but does not update the variable, double-checked against
 *     the block's hops when they are available;
 * (2) the variable is updated inside an if/while/for statement block;
 * (3) the variable is updated in a last-level block by a root hop of the same
 *     name without a simple read chain (HopRewriteUtils.rHasSimpleReadChain).
 * Afterwards, the checkpoints of the block are collected. If a variable still
 * has a candidate from an earlier block, that earlier checkpoint is disabled
 * and the new checkpoint becomes the candidate.
 * 
 * @param dmlp the DML program
 * @throws HopsException if hop processing fails
 */
private void removeCheckpointBeforeUpdate(DMLProgram dmlp) 
	throws HopsException
{
	//approach: scan over top-level program (guaranteed to be unconditional),
	//collect checkpoints; determine if used before update; remove first checkpoint
	//on second checkpoint if update in between and not used before update
	
	HashMap<String, Hop> chkpointCand = new HashMap<String, Hop>();
	
	for( StatementBlock sb : dmlp.getStatementBlocks() ) 
	{
		//prune candidates (used before updated)
		//(iterates over a copy because candidates are removed inside the loop)
		Set<String> cands = new HashSet<String>(chkpointCand.keySet());
		for( String cand : cands )
			if( sb.variablesRead().containsVariable(cand) 
				&& !sb.variablesUpdated().containsVariable(cand) ) 
			{	
				//note: variableRead might include false positives due to meta 
				//data operations like nrow(X) or operations removed by rewrites 
				//double check hops on basic blocks; otherwise worst-case
				boolean skipRemove = false;
				if( sb.get_hops() !=null ) {
					Hop.resetVisitStatus(sb.get_hops());
					skipRemove = true;
					//skipRemove stays true only if rContainsRead is false for all roots
					for( Hop root : sb.get_hops() )
						skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
				}					
				if( !skipRemove )
					chkpointCand.remove(cand);
			}
		
		//prune candidates (updated in conditional control flow)
		Set<String> cands2 = new HashSet<String>(chkpointCand.keySet());
		if( sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock 
			|| sb instanceof ForStatementBlock )
		{
			for( String cand : cands2 )
				if( sb.variablesUpdated().containsVariable(cand) ) {
					chkpointCand.remove(cand);
				}
		}
		//prune candidates (updated w/ multiple reads) 
		else
		{
			for( String cand : cands2 )
				if( sb.variablesUpdated().containsVariable(cand) && sb.get_hops() != null) 
				{
					Hop.resetVisitStatus(sb.get_hops());
					//only roots bound to the candidate's name are inspected
					for( Hop root : sb.get_hops() )
						if( root.getName().equals(cand) &&
							!HopRewriteUtils.rHasSimpleReadChain(root, cand) ) {
							chkpointCand.remove(cand);
						}
				}	
		}
	
		//collect checkpoints and remove unnecessary checkpoints
		//(a remaining earlier candidate of the same variable is disabled and
		//replaced by the new checkpoint)
		ArrayList<Hop> tmp = collectCheckpoints(sb.get_hops());
		for( Hop chkpoint : tmp ) {
			if( chkpointCand.containsKey(chkpoint.getName()) ) {
				chkpointCand.get(chkpoint.getName()).setRequiresCheckpoint(false);		
			}
			chkpointCand.put(chkpoint.getName(), chkpoint);
		}
		
	}
}
| |
/**
 * Moves checkpoints on top-level variables behind a subsequent update of the
 * same variable if the variable is not used in between (best effort).
 * 
 * The top-level statement blocks are scanned in program order while a map of
 * checkpoint candidates (variable name to checkpointed hop) is maintained. A
 * candidate is pruned if a block reads but does not update the variable
 * (double-checked against the block's hops when available), or if it is
 * updated inside an if/while/for statement block. If a last-level block updates
 * a candidate and a root hop of the same name has a simple read chain
 * (HopRewriteUtils.rHasSimpleReadChain), the checkpoint is moved from the old
 * candidate to that root's first input, which becomes the new candidate.
 * Otherwise, the candidate is pruned. Finally, the checkpoints of each block
 * are collected as new candidates.
 * 
 * @param dmlp the DML program
 * @throws HopsException if hop processing fails
 */
private void moveCheckpointAfterUpdate(DMLProgram dmlp) 
	throws HopsException
{
	//approach: scan over top-level program (guaranteed to be unconditional),
	//collect checkpoints; determine if used before update; move first checkpoint
	//after update if not used before update (best effort move which often avoids
	//the second checkpoint on loops even though used in between)
	
	HashMap<String, Hop> chkpointCand = new HashMap<String, Hop>();
	
	for( StatementBlock sb : dmlp.getStatementBlocks() ) 
	{
		//prune candidates (used before updated)
		//(iterates over a copy because candidates are removed inside the loop)
		Set<String> cands = new HashSet<String>(chkpointCand.keySet());
		for( String cand : cands )
			if( sb.variablesRead().containsVariable(cand) 
				&& !sb.variablesUpdated().containsVariable(cand) ) 
			{	
				//note: variableRead might include false positives due to meta 
				//data operations like nrow(X) or operations removed by rewrites 
				//double check hops on basic blocks; otherwise worst-case
				boolean skipRemove = false;
				if( sb.get_hops() !=null ) {
					Hop.resetVisitStatus(sb.get_hops());
					skipRemove = true;
					//skipRemove stays true only if rContainsRead is false for all roots
					for( Hop root : sb.get_hops() )
						skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
				}					
				if( !skipRemove )
					chkpointCand.remove(cand);
			}
		
		//prune candidates (updated in conditional control flow)
		Set<String> cands2 = new HashSet<String>(chkpointCand.keySet());
		if( sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock 
			|| sb instanceof ForStatementBlock )
		{
			for( String cand : cands2 )
				if( sb.variablesUpdated().containsVariable(cand) ) {
					chkpointCand.remove(cand);
				}
		}
		//move checkpoint after update with simple read chain 
		//(note: right now this only applies if the checkpoints comes from a previous
		//statement block, within-dag checkpoints should be handled during injection)
		else
		{
			for( String cand : cands2 )
				if( sb.variablesUpdated().containsVariable(cand) && sb.get_hops() != null) {
					Hop.resetVisitStatus(sb.get_hops());
					for( Hop root : sb.get_hops() )
						if( root.getName().equals(cand) ) {
							if( HopRewriteUtils.rHasSimpleReadChain(root, cand) ) {
								//disable old checkpoint, checkpoint the root's first
								//input instead and track it as the new candidate
								chkpointCand.get(cand).setRequiresCheckpoint(false);
								root.getInput().get(0).setRequiresCheckpoint(true);
								chkpointCand.put(cand, root.getInput().get(0));
							}
							else
								chkpointCand.remove(cand);		
						}
				}	
		}
	
		//collect checkpoints
		ArrayList<Hop> tmp = collectCheckpoints(sb.get_hops());
		for( Hop chkpoint : tmp ) {
			chkpointCand.put(chkpoint.getName(), chkpoint);
		}
	}
}
| |
| /** |
| * |
| * @param dmlp |
| * @throws HopsException |
| */ |
| private void removeCheckpointReadWrite(DMLProgram dmlp) |
| throws HopsException |
| { |
| List<StatementBlock> sbs = dmlp.getStatementBlocks(); |
| |
| if( sbs.size()==1 & !(sbs.get(0) instanceof IfStatementBlock |
| || sbs.get(0) instanceof WhileStatementBlock |
| || sbs.get(0) instanceof ForStatementBlock) ) |
| { |
| //recursively process all dag roots |
| if( sbs.get(0).get_hops()!=null ) { |
| Hop.resetVisitStatus(sbs.get(0).get_hops()); |
| for( Hop root : sbs.get(0).get_hops() ) |
| rRemoveCheckpointReadWrite(root); |
| } |
| } |
| } |
| |
| /** |
| * |
| * @param roots |
| * @return |
| */ |
| private ArrayList<Hop> collectCheckpoints(ArrayList<Hop> roots) |
| { |
| ArrayList<Hop> ret = new ArrayList<Hop>(); |
| if( roots != null ) { |
| Hop.resetVisitStatus(roots); |
| for( Hop root : roots ) |
| rCollectCheckpoints(root, ret); |
| } |
| |
| return ret; |
| } |
| |
| /** |
| * |
| * @param hop |
| * @param checkpoints |
| */ |
| private void rCollectCheckpoints(Hop hop, ArrayList<Hop> checkpoints) |
| { |
| if( hop.getVisited()==VisitStatus.DONE ) |
| return; |
| |
| //handle leaf node for variable (checkpoint directly bound |
| //to logical variable name and not used) |
| if( hop.requiresCheckpoint() && hop.getParent().size()==1 |
| && hop.getParent().get(0) instanceof DataOp |
| && ((DataOp)hop.getParent().get(0)).getDataOpType()==DataOpTypes.TRANSIENTWRITE) |
| { |
| checkpoints.add(hop); |
| } |
| |
| //recursively process child nodes |
| for( Hop c : hop.getInput() ) |
| rCollectCheckpoints(c, checkpoints); |
| |
| hop.setVisited(Hop.VisitStatus.DONE); |
| } |
| |
| /** |
| * |
| * @param hop |
| */ |
| public static void rRemoveCheckpointReadWrite(Hop hop) |
| { |
| if( hop.getVisited()==VisitStatus.DONE ) |
| return; |
| |
| //remove checkpoint on pread if only consumed by pwrite or uagg |
| if( (hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.PERSISTENTWRITE) |
| || hop instanceof AggUnaryOp ) |
| { |
| //(pwrite|uagg) - pread |
| Hop c0 = hop.getInput().get(0); |
| if( c0.requiresCheckpoint() && c0.getParent().size() == 1 |
| && c0 instanceof DataOp && ((DataOp)c0).getDataOpType()==DataOpTypes.PERSISTENTREAD ) |
| { |
| c0.setRequiresCheckpoint(false); |
| } |
| |
| //(pwrite|uagg) - frame/matri cast - pread |
| if( c0 instanceof UnaryOp && c0.getParent().size() == 1 |
| && (((UnaryOp)c0).getOp()==OpOp1.CAST_AS_FRAME || ((UnaryOp)c0).getOp()==OpOp1.CAST_AS_MATRIX ) |
| && c0.getInput().get(0).requiresCheckpoint() && c0.getInput().get(0).getParent().size() == 1 |
| && c0.getInput().get(0) instanceof DataOp |
| && ((DataOp)c0.getInput().get(0)).getDataOpType()==DataOpTypes.PERSISTENTREAD ) |
| { |
| c0.getInput().get(0).setRequiresCheckpoint(false); |
| } |
| } |
| |
| //recursively process children |
| for( Hop c : hop.getInput() ) |
| rRemoveCheckpointReadWrite(c); |
| |
| hop.setVisited(Hop.VisitStatus.DONE); |
| } |
| |
| ///////////////////////////// |
| // REMOVE CONSTANT BINARY OPS |
| ////// |
| |
| /** |
| * |
| * @param dmlp |
| * @throws HopsException |
| */ |
| private void removeConstantBinaryOps(DMLProgram dmlp) |
| throws HopsException |
| { |
| //approach: scan over top-level program (guaranteed to be unconditional), |
| //collect ones=matrix(1,...); remove b(*)ones if not outer operation |
| HashMap<String, Hop> mOnes = new HashMap<String, Hop>(); |
| |
| for( StatementBlock sb : dmlp.getStatementBlocks() ) |
| { |
| //pruning updated variables |
| for( String var : sb.variablesUpdated().getVariableNames() ) |
| if( mOnes.containsKey( var ) ) |
| mOnes.remove( var ); |
| |
| //replace constant binary ops |
| if( !mOnes.isEmpty() ) |
| rRemoveConstantBinaryOp(sb, mOnes); |
| |
| //collect matrices of ones from last-level statement blocks |
| if( !(sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock |
| || sb instanceof ForStatementBlock) ) |
| { |
| collectMatrixOfOnes(sb.get_hops(), mOnes); |
| } |
| } |
| } |
| |
| /** |
| * |
| * @param roots |
| * @param mOnes |
| */ |
| private void collectMatrixOfOnes(ArrayList<Hop> roots, HashMap<String,Hop> mOnes) |
| { |
| if( roots == null ) |
| return; |
| |
| for( Hop root : roots ) |
| if( root instanceof DataOp && ((DataOp)root).getDataOpType()==DataOpTypes.TRANSIENTWRITE |
| && root.getInput().get(0) instanceof DataGenOp |
| && ((DataGenOp)root.getInput().get(0)).getOp()==DataGenMethod.RAND |
| && ((DataGenOp)root.getInput().get(0)).hasConstantValue(1.0)) |
| { |
| mOnes.put(root.getName(),root.getInput().get(0)); |
| } |
| } |
| |
| /** |
| * |
| * @param sb |
| * @param mOnes |
| * @throws HopsException |
| */ |
| private void rRemoveConstantBinaryOp(StatementBlock sb, HashMap<String,Hop> mOnes) |
| throws HopsException |
| { |
| if( sb instanceof IfStatementBlock ) |
| { |
| IfStatementBlock isb = (IfStatementBlock) sb; |
| IfStatement istmt = (IfStatement)isb.getStatement(0); |
| for( StatementBlock c : istmt.getIfBody() ) |
| rRemoveConstantBinaryOp(c, mOnes); |
| if( istmt.getElseBody() != null ) |
| for( StatementBlock c : istmt.getElseBody() ) |
| rRemoveConstantBinaryOp(c, mOnes); |
| } |
| else if( sb instanceof WhileStatementBlock ) |
| { |
| WhileStatementBlock wsb = (WhileStatementBlock) sb; |
| WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); |
| for( StatementBlock c : wstmt.getBody() ) |
| rRemoveConstantBinaryOp(c, mOnes); |
| } |
| else if( sb instanceof ForStatementBlock ) |
| { |
| ForStatementBlock fsb = (ForStatementBlock) sb; |
| ForStatement fstmt = (ForStatement)fsb.getStatement(0); |
| for( StatementBlock c : fstmt.getBody() ) |
| rRemoveConstantBinaryOp(c, mOnes); |
| } |
| else |
| { |
| if( sb.get_hops() != null ){ |
| Hop.resetVisitStatus(sb.get_hops()); |
| for( Hop hop : sb.get_hops() ) |
| rRemoveConstantBinaryOp(hop, mOnes); |
| } |
| } |
| } |
| |
| /** |
| * |
| * @param hop |
| * @param mOnes |
| */ |
| private void rRemoveConstantBinaryOp(Hop hop, HashMap<String,Hop> mOnes) |
| { |
| if( hop.getVisited()==VisitStatus.DONE ) |
| return; |
| |
| if( hop instanceof BinaryOp && ((BinaryOp)hop).getOp()==OpOp2.MULT |
| && !((BinaryOp) hop).isOuterVectorOperator() |
| && hop.getInput().get(0).getDataType()==DataType.MATRIX |
| && hop.getInput().get(1) instanceof DataOp |
| && mOnes.containsKey(hop.getInput().get(1).getName()) ) |
| { |
| //replace matrix of ones with literal 1 (later on removed by |
| //algebraic simplification rewrites; otherwise more complex |
| //recursive processing of childs and rewiring required) |
| HopRewriteUtils.removeChildReferenceByPos(hop, hop.getInput().get(1), 1); |
| HopRewriteUtils.addChildReference(hop, new LiteralOp(1), 1); |
| } |
| |
| //recursively process child nodes |
| for( Hop c : hop.getInput() ) |
| rRemoveConstantBinaryOp(c, mOnes); |
| |
| hop.setVisited(Hop.VisitStatus.DONE); |
| } |
| } |