| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.parser; |
| |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.LinkedList; |
| import java.util.List; |
| import java.util.stream.Collectors; |
| |
| import org.apache.sysds.conf.ConfigurationManager; |
| import org.apache.sysds.hops.Hop; |
| import org.apache.sysds.hops.IndexingOp; |
| import org.apache.sysds.hops.LiteralOp; |
| import org.apache.sysds.hops.OptimizerUtils; |
| import org.apache.sysds.hops.rewrite.HopRewriteUtils; |
| import org.apache.sysds.common.Builtins; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.common.Types.OpOp1; |
| import org.apache.sysds.common.Types.OpOp2; |
| import org.apache.sysds.common.Types.OpOpData; |
| import org.apache.sysds.common.Types.ValueType; |
| import org.apache.sysds.parser.Expression.BinaryOp; |
| import org.apache.sysds.parser.PrintStatement.PRINTTYPE; |
| import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat; |
| import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PDataPartitioner; |
| import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PExecMode; |
| import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.POptMode; |
| import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PResultMerge; |
| import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PTaskPartitioner; |
| import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PartitionFormat; |
| import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; |
| import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; |
| import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; |
| import org.apache.sysds.runtime.util.UtilFunctions; |
| |
| /** |
| * This ParForStatementBlock is essentially identical to a ForStatementBlock, except for an extended validate |
| * that checks/sets the optional parfor parameters and runs the loop dependency analysis. |
| * |
| */ |
| public class ParForStatementBlock extends ForStatementBlock |
| { |
| //external parameter names (keys accepted in the parfor header; registered in the static initializer) |
| private static HashSet<String> _paramNames; |
| public static final String CHECK = "check"; //run loop dependency analysis |
| public static final String PAR = "par"; //number of parallel workers |
| public static final String TASK_PARTITIONER = "taskpartitioner"; //task partitioner |
| public static final String TASK_SIZE = "tasksize"; //number of tasks |
| public static final String DATA_PARTITIONER = "datapartitioner"; //data partitioner |
| public static final String RESULT_MERGE = "resultmerge"; //result merge strategy |
| public static final String EXEC_MODE = "mode"; //runtime execution mode |
| public static final String OPT_MODE = "opt"; //optimization mode |
| public static final String OPT_LOG = "log"; //parfor logging mode |
| public static final String PROFILE = "profile"; //monitor and report parfor performance profile |
| |
| //default external parameter values (filled in by validate for all parameters the script did not specify) |
| private static HashMap<String, String> _paramDefaults; |
| private static HashMap<String, String> _paramDefaults2; //for constrained opt (has no entry for OPT_MODE, which is always specified in that case) |
| |
| //internal parameter values (compile-time switches; ABORT_ON_FIRST_DEPENDENCY stops the analysis at the first detected dependency) |
| private static final boolean NORMALIZE = false; //normalize FOR from to incr |
| private static final boolean USE_FN_CACHE = false; //useful for larger scripts (due to O(n^2)) |
| private static final boolean ABORT_ON_FIRST_DEPENDENCY = true; |
| private static final boolean CONSERVATIVE_CHECK = false; //include FOR into dep analysis, reject unknown vars (otherwise use internal vars for whole row or column) |
| |
| public static final String INTERAL_FN_INDEX_ROW = "__ixr"; //pseudo index for range indexing row |
| public static final String INTERAL_FN_INDEX_COL = "__ixc"; //pseudo index for range indexing col |
| |
| private static final IDSequence _idSeq = new IDSequence(); |
| private static final IDSequence _idSeqfn = new IDSequence(); |
| |
| private static HashMap<String, LinearFunction> _fncache; //only created if USE_FN_CACHE; slower for most (small) cases |
| |
| //instance members (_PID is assigned in the constructor; _vsParent and _bounds are set by validate) |
| private final long _PID; |
| private VariableSet _vsParent = null; |
| private ArrayList<ResultVar> _resultVars = null; |
| private Bounds _bounds = null; |
| |
| static |
| { |
|     //register every parameter name that is accepted in a parfor header |
|     _paramNames = new HashSet<>(); |
|     for( String name : new String[]{ CHECK, PAR, TASK_PARTITIONER, TASK_SIZE, DATA_PARTITIONER, |
|             RESULT_MERGE, EXEC_MODE, OPT_MODE, PROFILE, OPT_LOG } ) { |
|         _paramNames.add( name ); |
|     } |
| |
|     //defaults for the regular case (one entry per parameter name) |
|     HashMap<String, String> regular = new HashMap<>(); |
|     regular.put( CHECK, "1" ); |
|     regular.put( PAR, String.valueOf(InfrastructureAnalyzer.getLocalParallelism()) ); |
|     regular.put( TASK_PARTITIONER, String.valueOf(PTaskPartitioner.FIXED) ); |
|     regular.put( TASK_SIZE, "1" ); |
|     regular.put( DATA_PARTITIONER, String.valueOf(PDataPartitioner.NONE) ); |
|     regular.put( RESULT_MERGE, String.valueOf(PResultMerge.LOCAL_AUTOMATIC) ); |
|     regular.put( EXEC_MODE, String.valueOf(PExecMode.LOCAL) ); |
|     regular.put( OPT_MODE, String.valueOf(POptMode.RULEBASED) ); |
|     regular.put( PROFILE, "0" ); |
|     regular.put( OPT_LOG, OptimizerUtils.getDefaultLogLevel().toString() ); |
|     _paramDefaults = regular; |
| |
|     //defaults for the constrained optimizer: unspecified parameters are marked as such; |
|     //there is deliberately no OPT_MODE entry because it is always specified in this case |
|     HashMap<String, String> constrained = new HashMap<>(); |
|     constrained.put( CHECK, "1" ); |
|     constrained.put( PAR, "-1" ); |
|     constrained.put( TASK_PARTITIONER, String.valueOf(PTaskPartitioner.UNSPECIFIED) ); |
|     constrained.put( TASK_SIZE, "-1" ); |
|     constrained.put( DATA_PARTITIONER, String.valueOf(PDataPartitioner.UNSPECIFIED) ); |
|     constrained.put( RESULT_MERGE, String.valueOf(PResultMerge.UNSPECIFIED) ); |
|     constrained.put( EXEC_MODE, String.valueOf(PExecMode.UNSPECIFIED) ); |
|     constrained.put( PROFILE, "0" ); |
|     constrained.put( OPT_LOG, OptimizerUtils.getDefaultLogLevel().toString() ); |
|     _paramDefaults2 = constrained; |
| |
|     //the linear function cache only exists if it is switched on |
|     if( USE_FN_CACHE ) { |
|         _fncache = new HashMap<>(); |
|     } |
| } |
| |
| /** |
|  * Creates a parfor statement block with an empty list of result variables |
|  * and a fresh id drawn from the static id sequence. |
|  */ |
| public ParForStatementBlock() { |
|     _resultVars = new ArrayList<>(); |
|     _PID = _idSeq.getNextID(); |
| |
|     LOG.trace("PARFOR("+_PID+"): ParForStatementBlock instance created"); |
| } |
| |
| /** |
|  * Returns the id that was assigned to this block in the constructor. |
|  * |
|  * @return parfor block id |
|  */ |
| public long getID() { |
|     return this._PID; |
| } |
| |
| /** |
|  * Returns the internal (live, not copied) list of result variables of this block. |
|  * |
|  * @return list of result variables |
|  */ |
| public ArrayList<ResultVar> getResultVariables() { |
|     return this._resultVars; |
| } |
| |
| /** |
|  * Wraps the given name and accumulator flag into a ResultVar and adds it |
|  * to the result variables unless an equal entry is already present. |
|  * |
|  * @param var variable name |
|  * @param accum accumulator flag passed on to the ResultVar |
|  */ |
| private void addToResultVariablesNoDup( String var, boolean accum ) { |
|     ResultVar wrapped = new ResultVar(var, accum); |
|     addToResultVariablesNoDup(wrapped); |
| } |
| |
| /** |
|  * Adds the given result variable unless the list already contains it |
|  * (as decided by the list's contains check). |
|  * |
|  * @param var result variable to add |
|  */ |
| private void addToResultVariablesNoDup( ResultVar var ) { |
|     if( _resultVars.contains( var ) ) { |
|         return; //already registered |
|     } |
|     _resultVars.add( var ); |
| } |
| |
| /** |
|  * Validates this parfor block. On top of the inherited for-loop validation this method |
|  * (1) checks the names of the specified parfor parameters and fills in defaults for all |
|  * unspecified ones, (2) runs the loop dependency analysis unless the 'check' parameter |
|  * is not 1, and (3) collects the result variables of this block. |
|  * |
|  * @param dmlProg the program, used to determine loop bounds over all functions and statement blocks |
|  * @param ids variables visible before this block; a copy is kept as the set of non-local variables |
|  * @param constVars constant variables, passed on to the inherited validation |
|  * @param conditional conditional flag, passed on to the inherited validation |
|  * @return the variable set returned by the inherited validation |
|  */ |
| @Override |
| public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String,ConstIdentifier> constVars, boolean conditional) |
| { |
|     LOG.trace("PARFOR("+_PID+"): validating ParForStatementBlock."); |
| |
|     //keep a copy of the incoming variables; these are the non-local (parent) variables |
|     _vsParent = new VariableSet( ids ); |
| |
|     //note: A is matrix, and A[i,1] is scalar |
|     if( LOG.isTraceEnabled() ) { |
|         for( DataIdentifier parentVar : _vsParent.getVariables().values() ) { |
|             LOG.trace("PARFOR: non-local "+parentVar._name+": "+parentVar.getDataType().toString()+" with rowDim = "+parentVar.getDim1()); |
|         } |
|     } |
| |
|     //regular (sequential) validation via the for statement block |
|     //NOTES: |
|     // * validate/dependency checking of nested parfor-loops happens at this point |
|     // * validate also includes constant propagation for from, to, incr expressions |
|     // * this also includes function inlining |
|     VariableSet vs = super.validate(dmlProg, ids, constVars, conditional); |
| |
|     //check the specified parfor parameter names and complete them with defaults |
|     ParForStatement pfs = (ParForStatement) _statements.get(0); |
|     IterablePredicate predicate = pfs.getIterablePredicate(); |
|     HashMap<String, String> params = predicate.getParForParams(); |
|     if( params == null ) { |
|         //nothing specified: start from a copy of all defaults |
|         params = new HashMap<>( _paramDefaults ); |
|         predicate.setParForParams(params); |
|     } |
|     else { |
|         //reject unknown parameter names (always unconditional) |
|         for( String key : params.keySet() ) { |
|             if( !_paramNames.contains(key) ) { |
|                 raiseValidateError("PARFOR: The specified parameter '"+key+"' is no valid parfor parameter.", false); |
|             } |
|         } |
| |
|         //fill in all non-specified values; the constrained optimizer gets its own |
|         //defaults in order to distinguish specified from unspecified parameters |
|         boolean constrained = params.containsKey( OPT_MODE ) |
|             && params.get( OPT_MODE ).equalsIgnoreCase(POptMode.CONSTRAINED.name()); |
|         HashMap<String, String> defaults = constrained ? _paramDefaults2 : _paramDefaults; |
|         for( String key : _paramNames ) { |
|             if( !params.containsKey(key) ) { |
|                 params.put(key, defaults.get(key)); |
|             } |
|         } |
|     } |
| |
|     //start time measurement for normalization and dependency analysis |
|     Timing time = new Timing(true); |
| |
|     // LOOP DEPENDENCY ANALYSIS (test for dependency existence) |
|     // no false negatives guaranteed, but possibly false positives |
| |
|     /* Basic intuition: WRITES to NON-local variables are only permitted iff there is |
|      * - no data dependency (no read other than own iteration w i < r j) |
|      * - no anti dependency (no read other than own iteration w i > r j) |
|      * - no output dependency (no write other than own iteration) |
|      * |
|      * ALGORITHM: |
|      * 1) Determine candidates C (writes to non-local variables) |
|      * 2) Prune all c from C that have no dependencies --> C' |
|      * 3) Raise an exception/warning if C' is not the empty set |
|      * |
|      * RESTRICTIONS: |
|      * - array subscripts of non-local variables must be linear functions of the form |
|      *   a0+ a1*i + ... + a2*j, where i and j are for or parfor indexes. |
|      * - for and parfor increments must be integer values |
|      * - only static (integer lower, upper bounds) range indexing |
|      * - only input variables are considered as potential candidates for checking |
|      * |
|      * (TODO: in order to remove the last restriction, dependencies must be checked again after |
|      * live variable analysis against LIVEOUT) |
|      * |
|      * NOTE: validity is only checked during compilation, i.e., for dynamic from, to, incr |
|      * the MIN MAX values are assumed. |
|      */ |
| |
|     LOG.trace("PARFOR: running loop dependency analysis ..."); |
| |
|     //### Step 1 ###: determine the candidate set |
|     HashSet<Candidate> candidates = new HashSet<>(); |
|     HashSet<Candidate> dependent = new HashSet<>(); |
|     //NOTE(review): Integer is immutable, so increments in the callees are not visible here |
|     Integer sCount = 0; |
|     rDetermineCandidates(pfs.getBody(), candidates, sCount); |
|     if( LOG.isTraceEnabled() ) { |
|         for( Candidate cand : candidates ) { |
|             LOG.trace("PARFOR: dependency candidate: var '"+cand._var+"' (accum="+cand._isAccum+")"); |
|         } |
|     } |
| |
|     boolean check = (Integer.parseInt(params.get(CHECK)) == 1); |
|     if( !check ) { |
|         LOG.debug("INFO: PARFOR("+_PID+"): loop dependency analysis skipped."); |
|     } |
|     else |
|     { |
|         //### Step 2 ###: prune candidates without dependencies |
|         _bounds = new Bounds(); |
|         for( FunctionStatementBlock fsb : dmlProg.getFunctionStatementBlocks() ) { |
|             rDetermineBounds( fsb, false ); //writes to _bounds |
|         } |
|         rDetermineBounds( dmlProg.getStatementBlocks(), false ); //writes to _bounds |
| |
|         for( Candidate cand : candidates ) |
|         { |
|             //data type of the parent variable (might be different in the DataIdentifier) |
|             DataType cdt = _vsParent.getVariables().get(cand._var).getDataType(); |
| |
|             //assume no dependency: flags are {output, data, anti} |
|             sCount = 0; |
|             boolean[] dep = new boolean[3]; |
|             rCheckCandidates(cand, cdt, pfs.getBody(), sCount, dep); |
| |
|             if( LOG.isTraceEnabled() ) { |
|                 if( dep[0] ) { |
|                     LOG.trace("PARFOR: output dependency detected for var '"+cand._var+"'."); |
|                 } |
|                 if( dep[1] ) { |
|                     LOG.trace("PARFOR: data dependency detected for var '"+cand._var+"'."); |
|                 } |
|                 if( dep[2] ) { |
|                     LOG.trace("PARFOR: anti dependency detected for var '"+cand._var+"'."); |
|                 } |
|             } |
| |
|             boolean anyDependency = dep[0] || dep[1] || dep[2]; |
|             if( anyDependency ) { |
|                 dependent.add(cand); |
|                 if( ABORT_ON_FIRST_DEPENDENCY ) { |
|                     break; |
|                 } |
|             } |
|         } |
| |
|         //### Step 3 ###: raise an exception / warning |
|         if( dependent.isEmpty() ) { |
|             LOG.trace("PARFOR: no loop dependencies detected."); |
|         } |
|         else |
|         { |
|             LOG.trace("PARFOR: loop dependencies detected."); |
| |
|             //comma-separated list of all variables with dependencies |
|             StringBuilder depVars = new StringBuilder(); |
|             for( Candidate cand : dependent ) { |
|                 if( depVars.length() > 0 ) { |
|                     depVars.append(", "); |
|                 } |
|                 depVars.append(cand._var); |
|             } |
| |
|             //always unconditional (to ensure we always raise dependency issues) |
|             raiseValidateError("PARFOR loop dependency analysis: " + |
|                 "inter-iteration (loop-carried) dependencies detected for variable(s): " + |
|                 depVars.toString() +". \n " + |
|                 "Please, ensure independence of iterations.", false); |
|         } |
|     } |
| |
|     //if successful, prepare result variables (all distinct vars in all candidates) |
|     //a) own candidates; without check, scalar candidates are left out |
|     for( Candidate cand : candidates ) { |
|         if( check || cand._dat.getDataType() != DataType.SCALAR ) { |
|             addToResultVariablesNoDup( cand._var, cand._isAccum ); |
|         } |
|     } |
|     //b) result variables of nested parfor blocks, as far as they are parent variables |
|     ArrayList<ResultVar> childVars = new ArrayList<>(); |
|     rConsolidateResultVars(pfs.getBody(), childVars); |
|     for( ResultVar childVar : childVars ) { |
|         if( _vsParent.containsVariable(childVar._name) ) { |
|             addToResultVariablesNoDup(childVar); |
|         } |
|     } |
|     if( LOG.isDebugEnabled() ) { |
|         for( ResultVar rvar : _resultVars ) { |
|             LOG.debug("INFO: PARFOR final result variable: "+rvar._name); |
|         } |
|     } |
| |
|     //cleanup function cache in order to prevent side effects between parfor statements |
|     if( USE_FN_CACHE ) { |
|         _fncache.clear(); |
|     } |
| |
|     LOG.debug("INFO: PARFOR("+_PID+"): validate successful (no dependencies) in "+time.stop()+"ms."); |
| |
|     return vs; |
| } |
| |
| public List<String> getReadOnlyParentMatrixVars() { |
| VariableSet read = variablesRead(); |
| VariableSet updated = variablesUpdated(); |
| return liveIn().getVariableNames().stream() //read-only vars |
| .filter(var -> read.containsVariable(var) && !updated.containsVariable(var)) |
| .filter(var -> read.isMatrix(var)).collect(Collectors.toList()); |
| } |
| |
| /** |
|  * Determines the partition format for the given variable from the way it is accessed |
|  * inside the body of this parfor block. All accesses are collected as candidates; a |
|  * format other than NONE is only returned if every candidate agrees on it. If there is |
|  * no candidate at all, or if collecting the candidates raises a LanguageException, |
|  * NONE is returned. |
|  * |
|  * @param var name of the variable to analyze |
|  * @return partition format |
|  */ |
| public PartitionFormat determineDataPartitionFormat(String var) |
| { |
|     List<PartitionFormat> candidates = new LinkedList<>(); |
| |
|     try |
|     { |
|         //collect one candidate per access found in the body |
|         ParForStatement pfs = (ParForStatement) _statements.get(0); |
|         rDeterminePartitioningCandidates(var, pfs.getBody(), candidates); |
| |
|         //reduce to a single format; any disagreement collapses to NONE |
|         PartitionFormat consensus = null; |
|         for( PartitionFormat cand : candidates ) { |
|             if( consensus == null || consensus.equals(cand) ) { |
|                 consensus = cand; |
|             } |
|             else { |
|                 consensus = PartitionFormat.NONE; |
|             } |
|         } |
|         return (consensus != null) ? consensus : PartitionFormat.NONE; |
|     } |
|     catch (LanguageException e) { |
|         LOG.trace( "Unable to determine partitioning candidates.", e ); |
|         return PartitionFormat.NONE; |
|     } |
| } |
| |
| /** |
|  * Recursively determines the candidates for output, data and anti dependencies. |
|  * A candidate is a write to a non-local variable, i.e. to a variable contained in |
|  * the parent variable set. stop() and assert() statements are rejected with a |
|  * validate error. |
|  * |
|  * @param asb list of statement blocks |
|  * @param C set of candidates (output) |
|  * @param sCount statement count; incremented on the local copy only |
|  */ |
| private void rDetermineCandidates(ArrayList<StatementBlock> asb, HashSet<Candidate> C, Integer sCount) |
| { |
|     for( StatementBlock block : asb ) |
|     { |
|         for( Statement stmt : block._statements ) |
|         { |
|             sCount++; |
|             if( stmt instanceof ForStatement ) { //incl parfor |
|                 //despite the separate dependency analysis of each nested parfor, nested |
|                 //parfor loops are traversed as well in order to ensure correctness of |
|                 //the constant checks with regard to outer indexes |
|                 rDetermineCandidates(((ForStatement)stmt).getBody(), C, sCount); |
|             } |
|             else if( stmt instanceof WhileStatement ) { |
|                 rDetermineCandidates(((WhileStatement)stmt).getBody(), C, sCount); |
|             } |
|             else if( stmt instanceof IfStatement ) { |
|                 IfStatement ifStmt = (IfStatement) stmt; |
|                 rDetermineCandidates(ifStmt.getIfBody(), C, sCount); |
|                 rDetermineCandidates(ifStmt.getElseBody(), C, sCount); |
|             } |
|             else if( stmt instanceof FunctionStatement ) { |
|                 rDetermineCandidates(((FunctionStatement)stmt).getBody(), C, sCount); |
|             } |
|             else if( stmt instanceof PrintStatement && ((PrintStatement)stmt).getType() == PRINTTYPE.STOP ) { |
|                 raiseValidateError("PARFOR loop dependency analysis: " + |
|                     "stop() statement is not allowed inside a parfor loop body.", false); |
|             } |
|             else if( stmt instanceof PrintStatement && ((PrintStatement)stmt).getType() == PRINTTYPE.ASSERT ) { |
|                 raiseValidateError("PARFOR loop dependency analysis: " + |
|                     "assert() statement is not allowed inside a parfor loop body.", false); |
|             } |
|             else { |
|                 VariableSet updated = stmt.variablesUpdated(); |
|                 if( updated != null ) { |
|                     for( String write : updated.getVariableNames() ) { |
|                         //only writes to non-local variables become candidates |
|                         if( _vsParent.containsVariable(write) ) { |
|                             boolean accum = (stmt instanceof AssignmentStatement |
|                                 && ((AssignmentStatement)stmt).isAccumulator()); |
|                             //one candidate per target identifier of the statement |
|                             for( DataIdentifier target : getDataIdentifiers( stmt, true ) ) { |
|                                 C.add( new Candidate(write, target, accum) ); |
|                             } |
|                         } |
|                     } |
|                 } |
|             } |
|         } |
|     } |
| } |
| |
| /** |
|  * Recursively determines the partitioning candidates for the given variable. |
|  * For control-flow blocks the predicate hops are scanned first and the nested |
|  * bodies afterwards; for all other blocks the hop roots of the block are scanned. |
|  * |
|  * @param var name of the variable to analyze |
|  * @param asb list of statement blocks |
|  * @param C list of partition formats (output) |
|  */ |
| private void rDeterminePartitioningCandidates(String var, ArrayList<StatementBlock> asb, List<PartitionFormat> C) |
| { |
|     for( StatementBlock block : asb ) |
|     { |
|         if( block instanceof FunctionStatementBlock ) { |
|             FunctionStatement fstmt = (FunctionStatement) block.getStatement(0); |
|             rDeterminePartitioningCandidates(var, fstmt.getBody(), C); |
|             continue; |
|         } |
| |
|         if( block instanceof ForStatementBlock ) { //incl parfor |
|             ForStatementBlock forBlock = (ForStatementBlock) block; |
|             ForStatement forStmt = (ForStatement) forBlock.getStatement(0); |
|             //predicate: from, to and increment |
|             List<Hop> predReads = new ArrayList<>(); |
|             rGetDataIdentifiers(resetVisitStatus(forBlock.getFromHops()), predReads); |
|             rGetDataIdentifiers(resetVisitStatus(forBlock.getToHops()), predReads); |
|             rGetDataIdentifiers(resetVisitStatus(forBlock.getIncrementHops()), predReads); |
|             rDeterminePartitioningCandidates(var, predReads, C); |
|             //for / parfor body |
|             rDeterminePartitioningCandidates(var, forStmt.getBody(), C); |
|             continue; |
|         } |
| |
|         if( block instanceof WhileStatementBlock ) { |
|             WhileStatementBlock whileBlock = (WhileStatementBlock) block; |
|             WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0); |
|             //predicate |
|             List<Hop> predReads = new ArrayList<>(); |
|             rGetDataIdentifiers(resetVisitStatus(whileBlock.getPredicateHops()), predReads); |
|             rDeterminePartitioningCandidates(var, predReads, C); |
|             //while body |
|             rDeterminePartitioningCandidates(var, whileStmt.getBody(), C); |
|             continue; |
|         } |
| |
|         if( block instanceof IfStatementBlock ) { |
|             IfStatementBlock ifBlock = (IfStatementBlock) block; |
|             IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0); |
|             //predicate |
|             List<Hop> predReads = new ArrayList<>(); |
|             rGetDataIdentifiers(resetVisitStatus(ifBlock.getPredicateHops()), predReads); |
|             rDeterminePartitioningCandidates(var, predReads, C); |
|             //if and else branch |
|             rDeterminePartitioningCandidates(var, ifStmt.getIfBody(), C); |
|             rDeterminePartitioningCandidates(var, ifStmt.getElseBody(), C); |
|             continue; |
|         } |
| |
|         //any other block: scan its hop roots, if there are any |
|         if( block.getHops() != null ) { |
|             Hop.resetVisitStatus(block.getHops()); |
|             List<Hop> reads = new ArrayList<>(); |
|             for( Hop root : block.getHops() ) { |
|                 rGetDataIdentifiers(root, reads); |
|             } |
|             rDeterminePartitioningCandidates(var, reads, C); |
|         } |
|     } |
| } |
| |
| /** |
|  * Derives partitioning candidates from a list of hops: an indexing operation whose |
|  * first input carries the given name contributes its access pattern, a transient |
|  * read of the given name contributes NONE. A null list is ignored. |
|  * |
|  * @param var name of the variable to analyze |
|  * @param datsRead hops to inspect, may be null |
|  * @param C list of partition formats (output) |
|  */ |
| private void rDeterminePartitioningCandidates(String var, List<Hop> datsRead, List<PartitionFormat> C) { |
|     if( datsRead != null ) { |
|         for( Hop hop : datsRead ) { |
|             boolean indexedRead = hop instanceof IndexingOp |
|                 && var.equals( hop.getInput().get(0).getName() ); |
|             if( indexedRead ) { |
|                 C.add( determineAccessPattern((IndexingOp) hop) ); |
|             } |
|             else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) && var.equals(hop.getName()) ) { |
|                 C.add( PartitionFormat.NONE ); |
|             } |
|         } |
|     } |
| } |
| |
| /** |
|  * Null-tolerant wrapper around the visit status reset of a hop. |
|  * |
|  * @param hop hop to reset, may be null |
|  * @return null for a null input, otherwise the result of the hop's resetVisitStatus call |
|  */ |
| private static Hop resetVisitStatus(Hop hop) { |
|     if( hop == null ) { |
|         return null; |
|     } |
|     return hop.resetVisitStatus(); |
| } |
| |
| /** |
|  * Maps a single indexing operation to a partition format. Inputs 1 to 4 of the |
|  * indexing operation are taken as row lower, row upper, column lower and column |
|  * upper bound. The block-wise formats are only considered in spark execution mode |
|  * and only if the bounds form an aligned blocking; everything else yields NONE. |
|  * Any exception raised during the decision is rethrown wrapped in a RuntimeException. |
|  * |
|  * @param rix indexing operation |
|  * @return partition format |
|  */ |
| private PartitionFormat determineAccessPattern( IndexingOp rix ) { |
|     boolean isSpark = OptimizerUtils.isSparkExecutionMode(); |
|     int blksz = ConfigurationManager.getBlocksize(); |
| |
|     //1) get all bounds expressions of the index access |
|     Hop rowL = rix.getInput().get(1); |
|     Hop rowU = rix.getInput().get(2); |
|     Hop colL = rix.getInput().get(3); |
|     Hop colU = rix.getInput().get(4); |
| |
|     try { |
|         //2) decide on the access pattern; lower == upper (same hop object) means single index |
|         //COLUMN_WISE if all rows and access to a single column |
|         if( rix.isAllRows() && colL == colU ) { |
|             return PartitionFormat.COLUMN_WISE; |
|         } |
|         //ROW_WISE if all cols and access to a single row |
|         if( rix.isAllCols() && rowL == rowU ) { |
|             return PartitionFormat.ROW_WISE; |
|         } |
|         //COLUMN_BLOCK_WISE |
|         if( isSpark && rix.isAllRows() && colL != colU ) { |
|             LinearFunction lower = getLinearFunction(colL, true); |
|             LinearFunction upper = getLinearFunction(colU, true); |
|             return isAlignedBlocking(lower, upper, blksz) ? |
|                 new PartitionFormat(PDataPartitionFormat.COLUMN_BLOCK_WISE_N, (int)lower._b[0]) : |
|                 PartitionFormat.NONE; |
|         } |
|         //ROW_BLOCK_WISE |
|         if( isSpark && rix.isAllCols() && rowL != rowU ) { |
|             LinearFunction lower = getLinearFunction(rowL, true); |
|             LinearFunction upper = getLinearFunction(rowU, true); |
|             return isAlignedBlocking(lower, upper, blksz) ? |
|                 new PartitionFormat(PDataPartitionFormat.ROW_BLOCK_WISE_N, (int)lower._b[0]) : |
|                 PartitionFormat.NONE; |
|         } |
|         //NONE otherwise (conservative) |
|         return PartitionFormat.NONE; |
|     } |
|     catch(Exception ex) { |
|         throw new RuntimeException(ex); |
|     } |
| } |
| |
| /** |
|  * Checks whether a lower and an upper bound function describe index ranges that |
|  * are aligned with blocks of the given size. This requires both functions to exist |
|  * and to have the same slope, a single index variable, a positive slope that does |
|  * not exceed and evenly divides the block size, a range length equal to the slope, |
|  * and an upper bound that evaluates to the slope for index 1. |
|  * |
|  * The positive-slope guard rejects a slope of zero before it is used as a divisor |
|  * and rejects negative slopes, which the callers would otherwise turn into a |
|  * negative block size of the partition format. |
|  * |
|  * @param l1 linear function of the lower bound, may be null |
|  * @param l2 linear function of the upper bound, may be null |
|  * @param blksz block size |
|  * @return true if the bounds form an aligned blocking |
|  */ |
| private static boolean isAlignedBlocking(LinearFunction l1, LinearFunction l2, int blksz) { |
|     return (l1!=null && l2!=null && l1.equalSlope(l2) //same slope |
|         && l1._b.length==1 //single index |
|         && l1._b[0]>0 && l1._b[0]<=blksz //positive slope within a single block |
|         && (l2._a - l1._a + 1 == l1._b[0]) //intercept difference is slope |
|         && (blksz/l1._b[0])*l1._b[0] == blksz //aligned slope |
|         && l2.eval(1L) == l1._b[0] ); //aligned intercept |
| } |
| |
| /** |
|  * Recursively collects the result variables of all parfor statement blocks |
|  * found in the given statement blocks and their nested bodies. |
|  * |
|  * @param asb list of statement blocks |
|  * @param vars list of result variables (output) |
|  */ |
| private void rConsolidateResultVars(ArrayList<StatementBlock> asb, ArrayList<ResultVar> vars) |
| { |
|     for( StatementBlock block : asb ) |
|     { |
|         //a nested parfor block contributes its own result variables |
|         if( block instanceof ParForStatementBlock ) { |
|             ParForStatementBlock nested = (ParForStatementBlock) block; |
|             vars.addAll( nested.getResultVariables() ); |
|         } |
| |
|         //descend into the bodies of all control-flow statements |
|         for( Statement stmt : block._statements ) { |
|             if( stmt instanceof ForStatement || stmt instanceof ParForStatement ) { |
|                 rConsolidateResultVars(((ForStatement)stmt).getBody(), vars); |
|             } |
|             else if( stmt instanceof WhileStatement ) { |
|                 rConsolidateResultVars(((WhileStatement)stmt).getBody(), vars); |
|             } |
|             else if( stmt instanceof IfStatement ) { |
|                 IfStatement ifStmt = (IfStatement) stmt; |
|                 rConsolidateResultVars(ifStmt.getIfBody(), vars); |
|                 rConsolidateResultVars(ifStmt.getElseBody(), vars); |
|             } |
|             else if( stmt instanceof FunctionStatement ) { |
|                 rConsolidateResultVars(((FunctionStatement)stmt).getBody(), vars); |
|             } |
|         } |
|     } |
| } |
| |
| /** |
|  * Recursively checks a candidate against the given statement blocks for output, data |
|  * and anti dependencies and records the findings in the dep array. The analysis is |
|  * meant to be conservative: no false negatives (undetected dependencies), but possibly |
|  * false positives (misdetected dependencies). With ABORT_ON_FIRST_DEPENDENCY the method |
|  * returns as soon as the first dependency has been recorded. |
|  * |
|  * @param c candidate |
|  * @param cdt candidate data type |
|  * @param asb list of statement blocks |
|  * @param sCount statement count; incremented on the local copy only |
|  * @param dep dependency flags in the order output, data, anti (output) |
|  * @throws LanguageException if a matching write or read is found for a data type that cannot be analyzed |
|  */ |
| private void rCheckCandidates(Candidate c, DataType cdt, |
|     ArrayList<StatementBlock> asb, Integer sCount, boolean[] dep) |
| { |
|     //check the candidate itself (output dependency if scalar or constant matrix subscript) |
|     if( cdt == DataType.SCALAR || cdt == DataType.UNKNOWN ) //dat2 checked for other candidate |
|     { |
|         //every write to a scalar or complete data object is an output dependency |
|         dep[0] = true; |
|         if( ABORT_ON_FIRST_DEPENDENCY ) { |
|             return; |
|         } |
|     } |
|     else if( cdt == DataType.MATRIX ) |
|     { |
|         //a constant subscript is written by every iteration, unless it is an accumulation |
|         if( runConstantCheck(c._dat) && !c._isAccum ) { |
|             if( LOG.isTraceEnabled() ) { |
|                 LOG.trace("PARFOR: Possible output dependency detected via constant self-check: var '"+c._var+"'."); |
|             } |
|             dep[0] = true; |
|             if( ABORT_ON_FIRST_DEPENDENCY ) { |
|                 return; |
|             } |
|         } |
|     } |
| |
|     //check the candidate against all statements |
|     for( StatementBlock block : asb ) |
|     { |
|         for( Statement stmt : block._statements ) |
|         { |
|             sCount++; |
|             if( stmt instanceof ForStatement ) { //incl parfor |
|                 //despite the separate dependency analysis of each nested parfor, nested |
|                 //parfor loops are traversed as well in order to ensure correctness of |
|                 //the constant checks with regard to outer indexes |
|                 rCheckCandidates(c, cdt, ((ForStatement)stmt).getBody(), sCount, dep); |
|             } |
|             else if( stmt instanceof WhileStatement ) { |
|                 rCheckCandidates(c, cdt, ((WhileStatement)stmt).getBody(), sCount, dep); |
|             } |
|             else if( stmt instanceof IfStatement ) { |
|                 IfStatement ifStmt = (IfStatement) stmt; |
|                 rCheckCandidates(c, cdt, ifStmt.getIfBody(), sCount, dep); |
|                 rCheckCandidates(c, cdt, ifStmt.getElseBody(), sCount, dep); |
|             } |
|             else if( stmt instanceof FunctionStatement ) { |
|                 rCheckCandidates(c, cdt, ((FunctionStatement)stmt).getBody(), sCount, dep); |
|             } |
|             else |
|             { |
|                 //CHECK output dependencies: other writes to the candidate variable |
|                 List<DataIdentifier> writes = getDataIdentifiers(stmt, true); |
|                 if( writes != null ) { |
|                     for( DataIdentifier write : writes ) { |
|                         if( c._var.equals( write.getName() ) ) { |
|                             if( cdt != DataType.MATRIX && cdt != DataType.LIST ) { |
|                                 //cannot infer type, need to exit (conservative approach) |
|                                 throw new LanguageException("PARFOR loop dependency analysis: cannot check " |
|                                     + "for dependencies due to unknown datatype of var '"+c._var+"': "+cdt.name()+"."); |
|                             } |
| |
|                             //the candidate's own identifier is skipped; writes with the same index |
|                             //function (intra-iteration output dependencies) are OK |
|                             if( c._dat != write && !runEqualsCheck(c._dat, write) |
|                                 && runBanerjeeGCDTest( c._dat, write ) ) |
|                             { |
|                                 LOG.trace("PARFOR: Possible output dependency detected via GCD/Banerjee: var '"+write+"'."); |
|                                 dep[0] = true; |
|                                 if( ABORT_ON_FIRST_DEPENDENCY ) { |
|                                     return; |
|                                 } |
|                             } |
|                         } |
|                     } |
|                 } |
| |
|                 //CHECK data and anti dependencies: reads of the candidate variable |
|                 List<DataIdentifier> reads = getDataIdentifiers(stmt, false); |
|                 if( reads != null ) { |
|                     for( DataIdentifier read : reads ) |
|                     { |
|                         if( !c._var.equals( read.getName() ) ) { |
|                             continue; |
|                         } |
|                         DataType readType = _vsParent.getVariables().get(read.getName()).getDataType(); |
| |
|                         boolean scalarInvolved = cdt == DataType.SCALAR || cdt == DataType.UNKNOWN |
|                             || readType == DataType.SCALAR || readType == DataType.UNKNOWN; |
|                         boolean sameStructuredType = (cdt == DataType.MATRIX && readType == DataType.MATRIX) |
|                             || (cdt == DataType.LIST && readType == DataType.LIST); |
| |
|                         if( scalarInvolved ) |
|                         { |
|                             //every write, read combination involving a scalar is a data dependency |
|                             dep[1] = true; |
|                             if( ABORT_ON_FIRST_DEPENDENCY ) { |
|                                 return; |
|                             } |
|                         } |
|                         else if( sameStructuredType ) |
|                         { |
|                             //read/write on the same index is OK unless the index is constant |
|                             //(constant writes are checked for output); otherwise the GCD/Banerjee |
|                             //test decides, and any remaining non-indexed access to the candidate |
|                             //result variable is always a dependency |
|                             boolean invalid = runEqualsCheck(c._dat, read) ? |
|                                 runConstantCheck(read) : |
|                                 (runBanerjeeGCDTest( c._dat, read ) || !(read instanceof IndexedIdentifier)); |
| |
|                             if( invalid ) { |
|                                 LOG.trace("PARFOR: Possible data/anti dependency detected via GCD/Banerjee: var '"+read+"'."); |
|                                 dep[1] = true; |
|                                 dep[2] = true; |
|                                 if( ABORT_ON_FIRST_DEPENDENCY ) { |
|                                     return; |
|                                 } |
|                             } |
|                         } |
|                         else { |
|                             //cannot infer type, need to exit (conservative approach) |
|                             throw new LanguageException("PARFOR loop dependency analysis: cannot check " |
|                                 + "for dependencies due to unknown datatype of var '"+c._var+"': "+cdt.name()+"."); |
|                         } |
|                     } |
|                 } |
|             } |
|         } |
|     } |
| } |
| |
| /** |
|  * Gets all target or source data identifiers of the given statement. Supported are |
|  * assignment, function, multi-assignment and print statements; for any other |
|  * statement null is returned. |
|  * |
|  * NOTE(review): for print statements the target flag is ignored and the identifiers |
|  * of all print expressions are returned in both cases; confirm that this is intended. |
|  * |
|  * @param s statement |
|  * @param target if true, get targets; otherwise get sources |
|  * @return list of data identifiers, or null for unsupported statements |
|  */ |
| private List<DataIdentifier> getDataIdentifiers(Statement s, boolean target) |
| { |
|     if( s instanceof AssignmentStatement ) { |
|         AssignmentStatement assign = (AssignmentStatement) s; |
|         return target ? assign.getTargetList() : |
|             rGetDataIdentifiers(assign.getSource()); |
|     } |
|     if( s instanceof FunctionStatement ) { |
|         FunctionStatement fstmt = (FunctionStatement) s; |
|         return target ? fstmt.getOutputParams() : |
|             fstmt.getInputParams(); |
|     } |
|     if( s instanceof MultiAssignmentStatement ) { |
|         MultiAssignmentStatement multi = (MultiAssignmentStatement) s; |
|         return target ? multi.getTargetList() : |
|             rGetDataIdentifiers(multi.getSource()); |
|     } |
|     if( s instanceof PrintStatement ) { |
|         List<DataIdentifier> printed = new ArrayList<>(); |
|         for( Expression expr : ((PrintStatement) s).getExpressions() ) { |
|             printed.addAll(rGetDataIdentifiers(expr)); |
|         } |
|         return printed; |
|     } |
| |
|     //potentially extend this with other statements if required |
|     //(e.g., IOStatement, RandStatement) |
|     return null; |
| } |
| |
| /** |
|  * Checks the row lower and upper bounds of dat1 and then of dat2 via checkLower |
|  * (with the row pseudo index prefix) and stops at the first failing check. |
|  * |
|  * NOTE(review): the first argument of checkLower, which only serves as null guard, |
|  * is always the bound of dat1, also when the bound of dat2 is inspected; confirm |
|  * that this is intended. |
|  * |
|  * @param dat1 first indexed identifier |
|  * @param dat2 second indexed identifier |
|  * @return true if all four checks pass |
|  */ |
| private boolean isRowIgnorable(IndexedIdentifier dat1, IndexedIdentifier dat2) { |
|     return checkLower(dat1.getRowLowerBound(), dat1.getRowLowerBound(), INTERAL_FN_INDEX_ROW) |
|         && checkLower(dat1.getRowUpperBound(), dat1.getRowUpperBound(), INTERAL_FN_INDEX_ROW) |
|         && checkLower(dat1.getRowLowerBound(), dat2.getRowLowerBound(), INTERAL_FN_INDEX_ROW) |
|         && checkLower(dat1.getRowUpperBound(), dat2.getRowUpperBound(), INTERAL_FN_INDEX_ROW); |
| } |
| |
| /** |
|  * Checks the column lower and upper bounds of dat1 and then of dat2 via checkLower |
|  * (with the column pseudo index prefix) and stops at the first failing check. |
|  * |
|  * NOTE(review): the first argument of checkLower, which only serves as null guard, |
|  * is always the bound of dat1, also when the bound of dat2 is inspected; confirm |
|  * that this is intended. |
|  * |
|  * @param dat1 first indexed identifier |
|  * @param dat2 second indexed identifier |
|  * @return true if all four checks pass |
|  */ |
| private boolean isColumnIgnorable(IndexedIdentifier dat1, IndexedIdentifier dat2) { |
|     return checkLower(dat1.getColLowerBound(), dat1.getColLowerBound(), INTERAL_FN_INDEX_COL) |
|         && checkLower(dat1.getColUpperBound(), dat1.getColUpperBound(), INTERAL_FN_INDEX_COL) |
|         && checkLower(dat1.getColLowerBound(), dat2.getColLowerBound(), INTERAL_FN_INDEX_COL) |
|         && checkLower(dat1.getColUpperBound(), dat2.getColUpperBound(), INTERAL_FN_INDEX_COL); |
| } |
| |
| /** |
|  * Scans the data identifiers of expr2 and fails if one of them has a registered |
|  * lower bound and a name that does not start with the given prefix. The scan is |
|  * skipped, and the check passes, if expr1 is null. |
|  * |
|  * @param expr1 guard expression, only tested for null |
|  * @param expr2 expression whose data identifiers are scanned |
|  * @param ix name prefix of the identifiers that are tolerated |
|  * @return false if a non-tolerated identifier with a registered lower bound is found, true otherwise |
|  */ |
| private boolean checkLower(Expression expr1, Expression expr2, String ix) { |
|     if( expr1 == null ) { |
|         return true; |
|     } |
|     for( DataIdentifier sub : rGetDataIdentifiers(expr2) ) { |
|         String name = sub.getName(); |
|         if( _bounds._lower.containsKey(name) && !name.startsWith(ix) ) { |
|             return false; |
|         } |
|     } |
|     return true; |
| } |
| |
| /** |
|  * Recursively collects the data identifiers referenced by an expression. Plain data |
|  * identifiers are returned as they are; function calls, binary, boolean and relational |
|  * expressions as well as (parameterized) builtin functions are traversed. All other |
|  * expressions, including null, yield an empty list. |
|  * |
|  * NOTE(review): for builtin functions only the first three argument expressions are |
|  * traversed; confirm that builtins with more arguments do not need to be covered. |
|  * |
|  * @param e expression, may be null |
|  * @return list of data identifiers, never null |
|  */ |
| private List<DataIdentifier> rGetDataIdentifiers(Expression e) |
| { |
|     List<DataIdentifier> found = new ArrayList<>(); |
| |
|     //function calls and builtins are data identifiers as well, but are traversed below |
|     boolean callLike = e instanceof FunctionCallIdentifier |
|         || e instanceof BuiltinFunctionExpression || e instanceof ParameterizedBuiltinFunctionExpression; |
| |
|     if( e instanceof DataIdentifier && !callLike ) { |
|         found.add( (DataIdentifier)e ); |
|     } |
|     else if( e instanceof FunctionCallIdentifier ) { |
|         for( ParameterExpression param : ((FunctionCallIdentifier)e).getParamExprs() ) { |
|             found.addAll(rGetDataIdentifiers( param.getExpr() )); |
|         } |
|     } |
|     else if( e instanceof BinaryExpression ) { |
|         BinaryExpression binary = (BinaryExpression) e; |
|         found.addAll( rGetDataIdentifiers(binary.getLeft()) ); |
|         found.addAll( rGetDataIdentifiers(binary.getRight()) ); |
|     } |
|     else if( e instanceof BooleanExpression ) { |
|         BooleanExpression bool = (BooleanExpression) e; |
|         found.addAll( rGetDataIdentifiers(bool.getLeft()) ); |
|         found.addAll( rGetDataIdentifiers(bool.getRight()) ); |
|     } |
|     else if( e instanceof BuiltinFunctionExpression ) { |
|         BuiltinFunctionExpression builtin = (BuiltinFunctionExpression) e; |
|         //disregard meta data ops nrow/ncol on a data identifier (to exclude from candidates) |
|         boolean metaDataOp = (builtin.getOpCode() == Builtins.NROW || builtin.getOpCode() == Builtins.NCOL) |
|             && builtin.getFirstExpr() instanceof DataIdentifier; |
|         if( !metaDataOp ) { |
|             found.addAll( rGetDataIdentifiers(builtin.getFirstExpr()) ); |
|             found.addAll( rGetDataIdentifiers(builtin.getSecondExpr()) ); |
|             found.addAll( rGetDataIdentifiers(builtin.getThirdExpr()) ); |
|         } |
|     } |
|     else if( e instanceof ParameterizedBuiltinFunctionExpression ) { |
|         ParameterizedBuiltinFunctionExpression pbuiltin = (ParameterizedBuiltinFunctionExpression) e; |
|         for( Expression arg : pbuiltin.getVarParams().values() ) { |
|             found.addAll( rGetDataIdentifiers(arg) ); |
|         } |
|     } |
|     else if( e instanceof RelationalExpression ) { |
|         RelationalExpression rel = (RelationalExpression) e; |
|         found.addAll( rGetDataIdentifiers(rel.getLeft()) ); |
|         found.addAll( rGetDataIdentifiers(rel.getRight()) ); |
|     } |
| |
|     return found; |
| } |
| |
| private List<Hop> rGetDataIdentifiers(Hop root, List<Hop> direads) { |
| if( root == null || root.isVisited() ) |
| return direads; |
| //process children recursively (but disregard meta data ops and indexing) |
| if( !((HopRewriteUtils.isUnary(root, OpOp1.NROW, OpOp1.NCOL) |
| && isDataIdentifier(root.getInput().get(0))) || isDataIdentifier(root)) ) { |
| for( Hop c : root.getInput() ) |
| rGetDataIdentifiers(c, direads); |
| } |
| //handle transient read and right indexing over transient read |
| if( isDataIdentifier(root) ) |
| direads.add(root); |
| root.setVisited(); |
| return direads; |
| } |
| |
| private static boolean isDataIdentifier(Hop hop) { |
| return HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) |
| || (hop instanceof IndexingOp && HopRewriteUtils.isData( |
| hop.getInput().get(0), OpOpData.TRANSIENTREAD)) |
| || hop instanceof LiteralOp; |
| } |
| |
| private void rDetermineBounds( ArrayList<StatementBlock> sbs, boolean flag ) { |
| for( StatementBlock sb : sbs ) |
| rDetermineBounds(sb, flag); |
| } |
| |
| /** |
| * Determines the lower/upper bounds of all nested for/parfor indexes. |
| * |
| * @param sb statement block |
| * @param flag indicates that method is already in subtree of THIS. |
| */ |
| private void rDetermineBounds( StatementBlock sb, boolean flag ) |
| { |
| // catch all known for/ parfor bounds |
| // (all unknown bounds are assumed to be +-infinity) |
| |
| for( Statement s : sb._statements ) |
| { |
| boolean lFlag = flag; |
| if( s instanceof ParForStatement || (s instanceof ForStatement && CONSERVATIVE_CHECK) ) //incl. for if conservative |
| { |
| ForStatement fs = (ForStatement)s; |
| IterablePredicate ip = fs._predicate; |
| |
| //checks for position in overall tree |
| if( sb==this ) |
| lFlag = true; |
| |
| if( lFlag || rIsParent(sb,this) ) //add only if in subtree of this |
| { |
| //check for internal names |
| if( ip.getIterVar()._name.equals( INTERAL_FN_INDEX_ROW ) |
| || ip.getIterVar()._name.equals( INTERAL_FN_INDEX_COL )) |
| { |
| |
| throw new LanguageException(" The iteration variable must not use the " + |
| "internal iteration variable name prefix '"+ip.getIterVar()._name+"'."); |
| } |
| |
| long low = Integer.MIN_VALUE; |
| long up = Integer.MAX_VALUE; |
| long incr = -1; |
| |
| if( ip.getFromExpr()instanceof IntIdentifier) |
| low = ((IntIdentifier)ip.getFromExpr()).getValue(); |
| if( ip.getToExpr()instanceof IntIdentifier) |
| up = ((IntIdentifier)ip.getToExpr()).getValue(); |
| |
| //NOTE: conservative approach: include all index variables (also from for) |
| if( ip.getIncrementExpr() instanceof IntIdentifier ) |
| incr = ((IntIdentifier)ip.getIncrementExpr()).getValue(); |
| else |
| incr = ( low <= up ) ? 1 : -1; |
| |
| _bounds._lower.put(ip.getIterVar()._name, low); |
| _bounds._upper.put(ip.getIterVar()._name, up); |
| _bounds._increment.put(ip.getIterVar()._name, incr); |
| if( lFlag ) //if local (required for constant check) |
| _bounds._local.add(ip.getIterVar()._name); |
| } |
| |
| //recursive invocation (but not for nested parfors due to constant check) |
| if( !lFlag ) |
| if( fs.getBody() != null ) |
| rDetermineBounds(fs.getBody(), lFlag); |
| } |
| else if( s instanceof ForStatement ) { |
| ArrayList<StatementBlock> tmp = ((ForStatement) s).getBody(); |
| if( tmp != null ) |
| rDetermineBounds(tmp, lFlag); |
| } |
| else if( s instanceof WhileStatement ) { |
| ArrayList<StatementBlock> tmp = ((WhileStatement) s).getBody(); |
| if( tmp != null ) |
| rDetermineBounds(tmp, lFlag); |
| } |
| else if( s instanceof IfStatement ) { |
| ArrayList<StatementBlock> tmp = ((IfStatement) s).getIfBody(); |
| if( tmp != null ) |
| rDetermineBounds(tmp, lFlag); |
| ArrayList<StatementBlock> tmp2 = ((IfStatement) s).getElseBody(); |
| if( tmp2 != null ) |
| rDetermineBounds(tmp2, lFlag); |
| } |
| else if( s instanceof FunctionStatement ) { |
| ArrayList<StatementBlock> tmp = ((FunctionStatement) s).getBody(); |
| if( tmp != null ) |
| rDetermineBounds(tmp, lFlag); |
| } |
| } |
| } |
| |
| private boolean rIsParent( ArrayList<StatementBlock> cParent, StatementBlock cChild) { |
| return cParent.stream().anyMatch(sb -> rIsParent(sb, cChild)); |
| } |
| |
| private boolean rIsParent( StatementBlock cParent, StatementBlock cChild) |
| { |
| |
| if( cParent == cChild ) |
| return true; |
| |
| boolean ret = false; |
| for( Statement s : cParent.getStatements() ) { |
| //check all the complex control flow constructs |
| if( s instanceof ForStatement ) //for, parfor |
| ret = rIsParent( ((ForStatement) s).getBody(), cChild ); |
| else if( s instanceof WhileStatement ) |
| ret = rIsParent( ((WhileStatement) s).getBody(), cChild ); |
| else if( s instanceof IfStatement ) { |
| ret = rIsParent( ((IfStatement) s).getIfBody(), cChild ); |
| ret |= rIsParent( ((IfStatement) s).getElseBody(), cChild ); |
| } |
| |
| //early return if already found |
| if( ret ) break; |
| } |
| |
| return ret; |
| } |
| |
| /** |
| * Runs a combination of GCD and Banerjee test for a two potentially conflicting |
| * data identifiers. See below for a detailed explanation. |
| * |
| * NOTE: simply enumerating all combinations of iteration variable values and probing for |
| * duplicates is not applicable due to (1) arbitrary nested program blocks with potentially |
| * dynamic lower, upper, and increment expressions, and (2) therefore potentially large |
| * overheads in the general case. |
| * |
| * @param dat1 data identifier 1 |
| * @param dat2 data identifier 2 |
| * @return true if "anti or data dependency" |
| */ |
| private boolean runBanerjeeGCDTest(DataIdentifier dat1, DataIdentifier dat2) |
| { |
| /* The GCD (greatest common denominator) and the Banerjee test are two commonly used tests |
| * for determining loop-carried dependencies. Both rely on (1) linear index expressions of the |
| * form y = a + bx, where x is the loop index variable, and (2) conservative approaches that |
| * guarantee no false negatives (no missed dependencies) but possibly false positives. The GCD |
| * test probes for integer solutions without bounds, while the Banerjee test probes for real |
| * solutions with bounds. |
| * |
| * We use a combination of both: |
| * - the GCD test checks if dependencies are possible |
| * - the Banerjee test checks if those dependencies may arise within the given bounds |
| * |
| * NOTES: |
| * - #1 possible false positives may arise if there is a real solution within the bounds |
| * and an integer solution outside the bounds. This will lead to a detected dependencies |
| * although no integer solution within the bounds exists. |
| * - #2 for the sake of simplicity, we do not distinguish between anti and data dependencies, |
| * although possible in general |
| * - more advanced tests than GCD and Banerjee available (e.g., with symbolic checking for |
| * non-linear functions) but this is a tradeoff between number of false positives and overhead |
| */ |
| |
| LOG.trace("PARFOR: runBanerjeeGCDCheck."); |
| |
| boolean ret = true; //anti or data dependency |
| |
| //Step 1: analyze index expressions and transform them into linear functions |
| LinearFunction f1 = getLinearFunction(dat1); |
| LinearFunction f2 = getLinearFunction(dat2); |
| forceConsistency(f1,f2); |
| |
| LOG.trace("PARFOR: f1: " + f1.toString()); |
| LOG.trace("PARFOR: f2: " + f2.toString()); |
| |
| /////// |
| //Step 2: run GCD Test |
| /////// |
| long lgcd = f1._b[0]; |
| for( int i=1; i<f1._b.length; i++ ) |
| lgcd = determineGCD( lgcd, f1._b[i] ); |
| for( int i=0; i<f2._b.length; i++ ) |
| lgcd = determineGCD( lgcd, f2._b[i] ); |
| |
| if( (Math.abs(f1._a-f2._a) % lgcd) != 0 ) { //if GCD divides the intercepts |
| //no integer solution exists -> no dependency |
| ret = false; |
| } |
| |
| LOG.trace("PARFOR: GCD result: "+ret); |
| |
| if( !CONSERVATIVE_CHECK && ret ) //only if not already no dependency |
| { |
| //NOTE: cases both and none negligible already covered (constant check, general case) |
| boolean ixid = (dat1 instanceof IndexedIdentifier && dat2 instanceof IndexedIdentifier); |
| boolean ignoreRow = ixid && isRowIgnorable((IndexedIdentifier)dat1, (IndexedIdentifier)dat2); |
| boolean ignoreCol = ixid && isColumnIgnorable((IndexedIdentifier)dat1, (IndexedIdentifier)dat2); |
| |
| LinearFunction f1p = null, f2p = null; |
| if( ignoreRow ) { |
| f1p = getColLinearFunction(dat1); |
| f2p = getColLinearFunction(dat2); |
| } |
| if( ignoreCol ) { |
| f1p = getRowLinearFunction(dat1); |
| f2p = getRowLinearFunction(dat2); |
| } |
| |
| LOG.trace("PARFOR: f1p: "+((f1p==null)?"null":f1p.toString())); |
| LOG.trace("PARFOR: f2p: "+((f2p==null)?"null":f2p.toString())); |
| |
| if( f1p!=null && f2p!=null ) |
| { |
| forceConsistency(f1p, f2p); |
| |
| long lgcd2 = f1p._b[0]; |
| for( int i=1; i<f1p._b.length; i++ ) |
| lgcd2 = determineGCD( lgcd2, f1p._b[i] ); |
| for( int i=0; i<f2p._b.length; i++ ) |
| lgcd2 = determineGCD( lgcd2, f2p._b[i] ); |
| |
| if( (Math.abs(f1p._a-f2p._a) % lgcd2) != 0 ) { //if GCD divides the intercepts |
| //no integer solution exists -> no dependency |
| ret = false; |
| } |
| |
| LOG.trace("PARFOR: GCD result: "+ret); |
| } |
| } |
| |
| |
| /////// |
| //Step 3: run Banerjee Test |
| /////// |
| if( ret ) //only if GCD found possible dependencies |
| { |
| //determining anti/data dependencies |
| long lintercept = f2._a - f1._a; |
| long lmax=0; |
| long lmin=0; |
| |
| //min/max bound |
| int len = Math.max(f1._b.length, f2._b.length); |
| boolean invalid = false; |
| for( int i=0; i<len; i++ ) |
| { |
| String var=(f1._b.length>i) ? f1._vars[i] : f2._vars[i]; |
| if( !_bounds._lower.containsKey(var) || !_bounds._upper.containsKey(var) ) { |
| invalid = true; break; |
| } |
| |
| //get lower and upper bound for specific var or internal var |
| long lower = _bounds._lower.get(var); //bounds equal for f1 and f2 |
| long upper = _bounds._upper.get(var); |
| |
| //max bound |
| if( f1._b.length>i ) |
| lmax += (f1._b[i]>0) ? f1._b[i]*upper : f1._b[i]*lower; |
| if( f2._b.length>i ) |
| lmax -= (f2._b[i]>0) ? f2._b[i]*lower : f2._b[i]*upper; |
| |
| //min bound (unequal indexes) |
| if( f1._b.length>i ) |
| lmin += (f1._b[i]>0) ? f1._b[i]*lower : f1._b[i]*upper; |
| if( f2._b.length>i ) |
| lmin -= (f2._b[i]>0) ? f2._b[i]*upper : f2._b[i]*lower; |
| } |
| |
| if( LOG.isTraceEnabled() ) |
| LOG.trace("PARFOR: Banerjee lintercept=" + lintercept+", lmax="+lmax+", lmin="+lmin+", invalid="+invalid); |
| |
| if( !invalid && (!(lmin <= lintercept && lintercept <= lmax) || lmin==lmax) ) { |
| //dependency not within the bounds of the arrays |
| ret = false; |
| } |
| |
| LOG.trace("PARFOR: Banerjee result: "+ret); |
| } |
| |
| return ret; |
| } |
| |
| /** |
| * Runs a constant check for a single data identifier (target of assignment). If constant, then every |
| * iteration writes to the same cell. |
| * |
| * @param dat1 data identifier |
| * @return true if dependency |
| */ |
| private boolean runConstantCheck(DataIdentifier dat1) |
| { |
| LOG.trace("PARFOR: runConstantCheck."); |
| |
| boolean ret = true; //data dependency to itself |
| LinearFunction f1 = getLinearFunction(dat1); |
| if( f1 == null ) |
| return true; //dependency |
| |
| LOG.trace("PARFOR: f1: "+f1.toString()); |
| |
| // no output dependency to itself if no index access will happen twice |
| // hence we check for: (all surrounding indexes are used by f1 and all intercepts != 0 ) |
| boolean gcheck=true; |
| for( String var : _bounds._local ) //check only local, nested checked from parent |
| { |
| if( var.startsWith(INTERAL_FN_INDEX_ROW) |
| || var.startsWith(INTERAL_FN_INDEX_COL)) |
| { |
| continue; //skip internal vars for range indexing |
| } |
| |
| boolean lcheck = false; |
| for( int i=0; i<f1._vars.length; i++ ) |
| if( var.equals(f1._vars[i]) ) |
| if( f1._b[i] != 0 ) |
| lcheck = true; |
| if( !lcheck ) |
| { |
| gcheck=false; |
| break; |
| } |
| } |
| |
| if( gcheck ) // output dependencies impossible |
| ret = false; |
| |
| return ret; |
| } |
| |
| /** |
| * Runs an equality check for two data identifiers. If equal, there there are no |
| * inter-iteration (loop-carried) but only intra-iteration dependencies. |
| * |
| * @param dat1 data identifier 1 |
| * @param dat2 data identifier 2 |
| * @return true if equal data identifiers |
| */ |
| private boolean runEqualsCheck(DataIdentifier dat1, DataIdentifier dat2) |
| { |
| LOG.trace("PARFOR: runEqualsCheck."); |
| |
| //check if both data identifiers of same type |
| if(dat1 instanceof IndexedIdentifier != dat2 instanceof IndexedIdentifier) |
| return false; |
| |
| //general case function comparison |
| boolean ret = true; //true if equal index functions |
| LinearFunction f1 = getLinearFunction(dat1); |
| LinearFunction f2 = getLinearFunction(dat2); |
| forceConsistency(f1, f2); |
| ret = f1.equals(f2); |
| |
| LOG.trace("PARFOR: f1: " + f1.toString()); |
| LOG.trace("PARFOR: f2: " + f2.toString()); |
| LOG.trace("PARFOR: (f1==f2): " + ret); |
| |
| //additional check if cols/rows could be ignored |
| if( !CONSERVATIVE_CHECK && !ret ) //only if not already equal |
| { |
| //NOTE: cases both and none negligible already covered (constant check, general case) |
| boolean ixid = (dat1 instanceof IndexedIdentifier && dat2 instanceof IndexedIdentifier); |
| boolean ignoreRow = ixid && isRowIgnorable((IndexedIdentifier)dat1, (IndexedIdentifier)dat2); |
| boolean ignoreCol = ixid && isColumnIgnorable((IndexedIdentifier)dat1, (IndexedIdentifier)dat2); |
| |
| LinearFunction f1p = null, f2p = null; |
| if( ignoreRow ) { |
| f1p = getColLinearFunction(dat1); |
| f2p = getColLinearFunction(dat2); |
| } |
| if( ignoreCol ) { |
| f1p = getRowLinearFunction(dat1); |
| f2p = getRowLinearFunction(dat2); |
| } |
| |
| if( f1p!=null && f2p!=null ) { |
| forceConsistency(f1p, f2p); |
| ret = f1p.equals(f2p); |
| |
| LOG.trace("PARFOR: f1p: " + f1p.toString()); |
| LOG.trace("PARFOR: f2p: " + f2p.toString()); |
| LOG.trace("PARFOR: (f1p==f2p): " + ret); |
| } |
| } |
| |
| return ret; |
| } |
| |
| /** |
| * This is the Euclid's algorithm for GCD (greatest common denominator), |
| * required for the GCD test. |
| * |
| * @param a first value |
| * @param b second value |
| * @return greatest common denominator |
| */ |
| private long determineGCD(long a, long b) { |
| return (b==0) ? a : determineGCD(b, a % b); |
| } |
| |
| /** |
| * Creates or reuses a linear function for a given data identifier, where identifiers with equal |
| * names and matrix subscripts result in exactly the same linear function. |
| * |
| * Identifiers without subscripts yield a function with zero intercept and a zero coefficient for |
| * the identifier name. For indexed identifiers, the row function is created first, scaled by the |
| * number of columns of the variable (if known, i.e., non-negative), and finally merged with the |
| * column function. If USE_FN_CACHE is set, the result is looked up in and stored into the |
| * function cache under the ID computed by getFunctionID. |
| * |
| * Side effects: range indexing (and, if not CONSERVATIVE_CHECK, subscripts with non-index |
| * variables) registers a new internal index variable in the bounds; if NORMALIZE is set, the |
| * upper bounds of variables with an increment other than 0 or 1 are overwritten. |
| * |
| * @param dat data identifier |
| * @return linear function, or null if a row or column subscript could not be parsed |
| * @throws LanguageException if verifyFunction rejects the created function |
| */ |
| private LinearFunction getLinearFunction(DataIdentifier dat) |
| { |
| /* Notes: |
| * - Currently, this function supports 2dim matrix subscripts with arbitrary linear functions |
| * however, this could be extended to d-dim if necessary |
| * - Trick for range indexing: introduce a pseudo index variable with lower and upper according to |
| * the index range (e.g., [1:4,...]) or matrix dimensionality (e.g., [:,...]). This allows us to |
| * apply existing tests even for range indexing (multi-value instead of single-value functions) |
| * - A subscript is handled as single index only if its lower and upper bound are the very same |
| * expression object (reference comparison); everything else is handled as range indexing |
| * - In both range indexing cases, the final else branch (out = null) cannot be reached because |
| * the two preceding conditions are complementary |
| * - Any exception raised while parsing a subscript is caught and logged on debug level, and |
| * results in a null function |
| */ |
| |
| LinearFunction out = null; |
| |
| if( ! (dat instanceof IndexedIdentifier ) ) //happens if matrix is now used as scalar |
| return new LinearFunction(0,0,dat.getName()); |
| |
| IndexedIdentifier idat = (IndexedIdentifier) dat; |
| |
| if( USE_FN_CACHE ) { |
| out = _fncache.get( getFunctionID(idat) ); |
| if( out != null ) |
| return out; |
| } |
| |
| Expression sub1 = idat.getRowLowerBound(); |
| Expression sub2 = idat.getColLowerBound(); |
| |
| //parse row expressions |
| try |
| { |
| //loop index or constant (default case) |
| if( idat.getRowLowerBound()!=null && idat.getRowUpperBound()!=null && |
| idat.getRowLowerBound() == idat.getRowUpperBound() ) |
| { |
| if( sub1 instanceof IntIdentifier ) |
| out = new LinearFunction(((IntIdentifier)sub1).getValue(), 0, null); |
| else if( sub1 instanceof DataIdentifier ) |
| out = new LinearFunction(0, 1, ((DataIdentifier)sub1)._name); |
| else |
| out = rParseBinaryExpression((BinaryExpression)sub1); |
| |
| if( !CONSERVATIVE_CHECK ) |
| if(out.hasNonIndexVariables()) |
| { |
| String id = INTERAL_FN_INDEX_ROW+_idSeqfn.getNextID(); |
| out = new LinearFunction(0, 1L, id); |
| |
| _bounds._lower.put(id, 1L); |
| _bounds._upper.put(id, _vsParent.getVariable(idat._name).getDim1()); //row dim |
| _bounds._increment.put(id, 1L); |
| } |
| } |
| else //range indexing |
| { |
| Expression sub1a = sub1; |
| Expression sub1b = idat.getRowUpperBound(); |
| |
| String id = INTERAL_FN_INDEX_ROW+_idSeqfn.getNextID(); |
| out = new LinearFunction(0, 1L, id); |
| |
| if( sub1a == null && sub1b == null //: operator |
| || !(sub1a instanceof IntIdentifier) || !(sub1b instanceof IntIdentifier) ) { //for robustness |
| _bounds._lower.put(id, 1L); |
| _bounds._upper.put(id, _vsParent.getVariable(idat._name).getDim1()); //row dim |
| _bounds._increment.put(id, 1L); |
| } |
| else if( sub1a instanceof IntIdentifier && sub1b instanceof IntIdentifier ) { |
| _bounds._lower.put(id, ((IntIdentifier)sub1a).getValue()); |
| _bounds._upper.put(id, ((IntIdentifier)sub1b).getValue()); |
| _bounds._increment.put(id, 1L); |
| } |
| else { |
| out = null; |
| } |
| } |
| |
| //scale row function 'out' with col dimensionality |
| long colDim = _vsParent.getVariable(idat._name).getDim2(); |
| if( colDim >= 0 ) { |
| out.scale( colDim ); |
| } |
| else { |
| //NOTE: we could mark sb for deferred validation and evaluate on execute (see ParForProgramBlock) |
| LOG.debug("PARFOR: Warning - matrix dimensionality of '"+idat._name+"' unknown, cannot scale linear functions."); |
| } |
| } |
| catch(Exception ex) { |
| LOG.debug("PARFOR: Unable to parse MATRIX subscript expression for '"+String.valueOf(sub1)+"'.", ex); |
| out = null; //let dependency analysis fail |
| } |
| |
| //parse col expression and merge functions |
| if( out!=null ) |
| { |
| try |
| { |
| LinearFunction tmpOut = null; |
| |
| //loop index or constant (default case) |
| if( idat.getColLowerBound()!=null && idat.getColUpperBound()!=null && |
| idat.getColLowerBound() == idat.getColUpperBound() ) |
| { |
| if( sub2 instanceof IntIdentifier ) |
| out.addConstant( ((IntIdentifier)sub2).getValue() ); |
| else if( sub2 instanceof DataIdentifier ) |
| tmpOut = new LinearFunction(0, 1, ((DataIdentifier)sub2)._name) ; |
| else |
| tmpOut = rParseBinaryExpression((BinaryExpression)sub2); |
| |
| if( !CONSERVATIVE_CHECK ) |
| if(tmpOut!=null && tmpOut.hasNonIndexVariables()) |
| { |
| String id = INTERAL_FN_INDEX_COL+_idSeqfn.getNextID(); |
| tmpOut = new LinearFunction(0, 1L, id); |
| _bounds._lower.put(id, 1l); |
| _bounds._upper.put(id, _vsParent.getVariable(idat._name).getDim2()); //col dim |
| _bounds._increment.put(id, 1L); |
| } |
| } |
| else //range indexing |
| { |
| Expression sub2a = sub2; |
| Expression sub2b = idat.getColUpperBound(); |
| |
| String id = INTERAL_FN_INDEX_COL+_idSeqfn.getNextID(); |
| tmpOut = new LinearFunction(0, 1L, id); |
| |
| if( sub2a == null && sub2b == null //: operator |
| || !(sub2a instanceof IntIdentifier) || !(sub2b instanceof IntIdentifier) ) //for robustness |
| { |
| _bounds._lower.put(id, 1L); |
| _bounds._upper.put(id, _vsParent.getVariable(idat._name).getDim2()); //col dim |
| _bounds._increment.put(id, 1L); |
| } |
| else if( sub2a instanceof IntIdentifier && sub2b instanceof IntIdentifier ) |
| { |
| _bounds._lower.put(id, ((IntIdentifier)sub2a).getValue()); |
| _bounds._upper.put(id, ((IntIdentifier)sub2b).getValue()); |
| _bounds._increment.put(id, 1L); |
| } |
| else |
| { |
| out = null; |
| } |
| } |
| |
| //final merge of row and col functions |
| if( tmpOut != null ) |
| out.addFunction(tmpOut); |
| } |
| catch(Exception ex) |
| { |
| LOG.debug("PARFOR: Unable to parse MATRIX subscript expression for '"+String.valueOf(sub2)+"'.", ex); |
| out = null; //let dependency analysis fail |
| } |
| } |
| |
| //post processing after creation |
| if( out != null ) |
| { |
| //cleanup and verify created function; raise exceptions if needed |
| cleanupFunction(out); |
| verifyFunction(out); |
| |
| // pseudo loop normalization of functions (incr=1, from=1 not necessary due to Banerjee) |
| // (precondition for GCD test) |
| if( NORMALIZE ) { |
| int index=0; |
| for( String var : out._vars ) { |
| long low = _bounds._lower.get(var); |
| long up = _bounds._upper.get(var); |
| long incr = _bounds._increment.get(var); |
| if( incr < 0 || 1 < incr ) { //does never apply to internal (artificial) vars |
| out.normalize(index,low,incr); // normalize linear functions |
| _bounds._upper.put(var,(long)Math.ceil(((double)up)/incr)); // normalize upper bound |
| } |
| index++; |
| } |
| } |
| |
| //put into cache |
| if( USE_FN_CACHE ) |
| _fncache.put( getFunctionID(idat), out ); |
| } |
| |
| return out; |
| } |
| |
| private LinearFunction getRowLinearFunction(DataIdentifier dat) |
| { |
| //NOTE: would require separate function cache, not realized due to inexpensive operations |
| |
| LinearFunction out = null; |
| IndexedIdentifier idat = (IndexedIdentifier) dat; |
| Expression sub1 = idat.getRowLowerBound(); |
| |
| try |
| { |
| //loop index or constant (default case) |
| if( idat.getRowLowerBound()!=null && idat.getRowUpperBound()!=null && |
| idat.getRowLowerBound() == idat.getRowUpperBound() ) |
| { |
| if( sub1 instanceof IntIdentifier ) |
| out = new LinearFunction(((IntIdentifier)sub1).getValue(), 0, null); |
| else if( sub1 instanceof DataIdentifier ) |
| out = new LinearFunction(0, 1, ((DataIdentifier)sub1).getName()); |
| else |
| out = rParseBinaryExpression((BinaryExpression)sub1); |
| } |
| } |
| catch(Exception ex) { |
| LOG.debug("PARFOR: Unable to parse MATRIX subscript expression for '"+String.valueOf(sub1)+"'.", ex); |
| out = null; //let dependency analysis fail |
| } |
| |
| //post processing after creation |
| if( out != null ) { |
| //cleanup and verify created function; raise exceptions if needed |
| cleanupFunction(out); |
| verifyFunction(out); |
| } |
| |
| return out; |
| } |
| |
| private LinearFunction getColLinearFunction(DataIdentifier dat) |
| { |
| //NOTE: would require separate function cache, not realized due to inexpensive operations |
| |
| LinearFunction out = null; |
| IndexedIdentifier idat = (IndexedIdentifier) dat; |
| Expression sub1 = idat.getColLowerBound(); |
| |
| try |
| { |
| //loop index or constant (default case) |
| if( idat.getColLowerBound()!=null && idat.getColUpperBound()!=null && |
| idat.getColLowerBound() == idat.getColUpperBound() ) |
| { |
| if( sub1 instanceof IntIdentifier ) |
| out = new LinearFunction(((IntIdentifier)sub1).getValue(), 0, null); |
| else if( sub1 instanceof DataIdentifier ) |
| out = new LinearFunction(0, 1, ((DataIdentifier)sub1).getName()); |
| else |
| out = rParseBinaryExpression((BinaryExpression)sub1); |
| } |
| } |
| catch(Exception ex) { |
| LOG.debug("PARFOR: Unable to parse MATRIX subscript expression for '"+String.valueOf(sub1)+"'.", ex); |
| out = null; //let dependency analysis fail |
| } |
| |
| //post processing after creation |
| if( out != null ) { |
| //cleanup and verify created function; raise exceptions if needed |
| cleanupFunction(out); |
| verifyFunction(out); |
| } |
| |
| return out; |
| } |
| |
| @SuppressWarnings("unused") |
| private LinearFunction getLinearFunction(Expression expr, boolean ignoreMinWithConstant) { |
| if( expr instanceof IntIdentifier ) |
| return new LinearFunction(((IntIdentifier)expr).getValue(), 0, null); |
| else if( expr instanceof BinaryExpression ) |
| return rParseBinaryExpression((BinaryExpression)expr); |
| else if( expr instanceof BuiltinFunctionExpression && ignoreMinWithConstant ) { |
| //note: builtin function expression is also a data identifier and hence order before |
| BuiltinFunctionExpression bexpr = (BuiltinFunctionExpression) expr; |
| if( bexpr.getOpCode()==Builtins.MIN ) { |
| if( bexpr.getFirstExpr() instanceof BinaryExpression ) |
| return rParseBinaryExpression((BinaryExpression)bexpr.getFirstExpr()); |
| else if( bexpr.getSecondExpr() instanceof BinaryExpression ) |
| return rParseBinaryExpression((BinaryExpression)bexpr.getSecondExpr()); |
| } |
| } |
| else if( expr instanceof DataIdentifier ) |
| return new LinearFunction(0, 1, ((DataIdentifier)expr).getName()); |
| |
| return null; |
| } |
| |
| private LinearFunction getLinearFunction(Hop hop, boolean ignoreMinWithConstant) { |
| if( hop instanceof LiteralOp && hop.getValueType()==ValueType.INT64 ) |
| return new LinearFunction(HopRewriteUtils.getIntValue((LiteralOp)hop), 0, null); |
| else if( HopRewriteUtils.isBinary(hop, OpOp2.PLUS, OpOp2.MINUS, OpOp2.MULT) ) |
| return rParseBinaryExpression(hop); |
| else if( HopRewriteUtils.isBinary(hop, OpOp2.MIN) && ignoreMinWithConstant ) { |
| //note: builtin function expression is also a data identifier and hence order before |
| if( hop.getInput().get(0) instanceof org.apache.sysds.hops.BinaryOp ) |
| return rParseBinaryExpression(hop.getInput().get(0)); |
| else if( hop.getInput().get(1) instanceof org.apache.sysds.hops.BinaryOp ) |
| return rParseBinaryExpression(hop.getInput().get(1)); |
| } |
| else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) ) |
| return new LinearFunction(0, 1, hop.getName()); |
| |
| return null; |
| } |
| |
| /** |
| * Creates a functionID for a given data identifier (mainly used for caching purposes), |
| * where data identifiers with equal name and matrix subscripts results in equal |
| * functionIDs. |
| * |
| * @param dat indexed identifier |
| * @return string function id |
| */ |
| private static String getFunctionID( IndexedIdentifier dat ) |
| { |
| // note: using dat.hashCode can be different for same functions, |
| // hence, we use a custom String ID |
| IndexedIdentifier idat = dat; |
| Expression ex1a = idat.getRowLowerBound(); |
| Expression ex1b = idat.getRowUpperBound(); |
| Expression ex2a = idat.getColLowerBound(); |
| Expression ex2b = idat.getColUpperBound(); |
| |
| StringBuilder sb = new StringBuilder(); |
| sb.append(String.valueOf(ex1a)); |
| sb.append(','); |
| sb.append(String.valueOf(ex1b)); |
| sb.append(','); |
| sb.append(String.valueOf(ex2a)); |
| sb.append(','); |
| sb.append(String.valueOf(ex2b)); |
| return sb.toString(); |
| } |
| |
| |
| |
| /** |
| * Removes all zero intercepts created by recursive computation. |
| * |
| * @param f1 linear function |
| */ |
| private static void cleanupFunction( LinearFunction f1 ) { |
| for( int i=0; i<f1._b.length; i++ ) |
| if( f1._vars[i]==null ) { |
| f1.removeVar(i); |
| i--; |
| continue; |
| } |
| } |
| |
| /** |
| * Simply verification check of created linear functions, mainly used for |
| * robustness purposes. |
| * |
| * @param f1 linear function |
| */ |
| private void verifyFunction(LinearFunction f1) |
| { |
| //check for required form of linear functions |
| if( f1 == null || f1._b.length != f1._vars.length ) { |
| if( LOG.isTraceEnabled() && f1!=null ) |
| LOG.trace("PARFOR: f1: "+f1.toString()); |
| throw new LanguageException("PARFOR loop dependency analysis: " + |
| "MATRIX subscripts are not in linear form (a0 + a1*x)."); |
| } |
| |
| //check all function variables to be index variables |
| for( String var : f1._vars ) |
| { |
| if( !_bounds._lower.containsKey(var) ) { |
| LOG.trace("PARFOR: not allowed variable in matrix subscript: "+var); |
| throw new LanguageException("PARFOR loop dependency analysis: " + |
| "MATRIX subscripts use non-index variables."); |
| } |
| } |
| } |
| |
| /** |
| * Tries to obtain consistent linear functions by forcing the same variable ordering for |
| * efficient comparison: f2 is modified in a way that it matches the sequence of variables in f1. |
| * |
| * @param f1 linear function 1 |
| * @param f2 linear function 2 |
| */ |
| private static void forceConsistency(LinearFunction f1, LinearFunction f2) |
| { |
| boolean warn = false; |
| |
| for( int i=0; i<f1._b.length; i++ ) |
| { |
| if( f2._b.length<(i+1) ) |
| break; |
| |
| if( !f1._vars[i].equals(f2._vars[i]) |
| &&!(f1._vars[i].startsWith(INTERAL_FN_INDEX_ROW) && f2._vars[i].startsWith(INTERAL_FN_INDEX_ROW)) |
| &&!(f1._vars[i].startsWith(INTERAL_FN_INDEX_COL) && f2._vars[i].startsWith(INTERAL_FN_INDEX_COL))) |
| { |
| boolean exchange = false; |
| //scan |
| for( int j=i+1; j<f2._b.length; j++ ) |
| if( f1._vars[i].equals(f2._vars[j]) |
| ||(f1._vars[i].startsWith(INTERAL_FN_INDEX_ROW) && f2._vars[j].startsWith(INTERAL_FN_INDEX_ROW)) |
| ||(f1._vars[i].startsWith(INTERAL_FN_INDEX_COL) && f2._vars[j].startsWith(INTERAL_FN_INDEX_COL)) ) |
| { |
| //exchange |
| long btmp = f2._b[i]; |
| String vartmp = f2._vars[i]; |
| f2._b[i] = f2._b[j]; |
| f2._vars[i] = f2._vars[j]; |
| f2._b[j] = btmp; |
| f2._vars[j] = vartmp; |
| exchange = true; |
| } |
| if( !exchange ) |
| warn = true; |
| } |
| } |
| |
| |
| if( warn && LOG.isTraceEnabled() ) |
| LOG.trace( "PARFOR: Warning - index functions f1 and f2 cannot be made consistent." ); |
| } |
| |
| /** |
| * Recursively creates a linear function for a single BinaryExpression, where PLUS, MINUS, MULT |
| * are allowed as operators. |
| * |
| * @param be binary expression |
| * @return linear function |
| */ |
| private LinearFunction rParseBinaryExpression(BinaryExpression be) { |
| Expression l = be.getLeft(); |
| Expression r = be.getRight(); |
| if( be.getOpCode() == BinaryOp.PLUS || be.getOpCode() == BinaryOp.MINUS ) { |
| boolean plus = be.getOpCode() == BinaryOp.PLUS; |
| //parse binary expressions |
| if( l instanceof BinaryExpression) { |
| LinearFunction f = rParseBinaryExpression((BinaryExpression) l); |
| Long cvalR = parseLongConstant(r); |
| if( f != null && cvalR != null ) |
| return f.addConstant(cvalR * (plus?1:-1)); |
| } |
| else if (r instanceof BinaryExpression) { |
| LinearFunction f = rParseBinaryExpression((BinaryExpression) r); |
| Long cvalL = parseLongConstant(l); |
| if( f != null && cvalL != null ) |
| return f.scale(plus?1:-1).addConstant(cvalL); |
| } |
| else { // atomic case |
| //change everything to plus if necessary |
| Long cvalL = parseLongConstant(l); |
| Long cvalR = parseLongConstant(r); |
| if( cvalL != null ) |
| return new LinearFunction(cvalL,plus?1:-1,((DataIdentifier)r)._name); |
| else if( cvalR != null ) |
| return new LinearFunction(cvalR*(plus?1:-1),1,((DataIdentifier)l)._name); |
| } |
| } |
| else if( be.getOpCode() == BinaryOp.MULT ) { |
| //atomic case (only recursion for MULT expressions, where one side is a constant) |
| Long cvalL = parseLongConstant(l); |
| Long cvalR = parseLongConstant(r); |
| if( cvalL != null && r instanceof DataIdentifier ) |
| return new LinearFunction(0, cvalL,((DataIdentifier)r)._name); |
| else if( cvalR != null && l instanceof DataIdentifier ) |
| return new LinearFunction(0, cvalR,((DataIdentifier)l)._name); |
| else if( cvalL != null && r instanceof BinaryExpression ) |
| return rParseBinaryExpression((BinaryExpression)r).scale(cvalL); |
| else if( cvalR != null && l instanceof BinaryExpression ) |
| return rParseBinaryExpression((BinaryExpression)l).scale(cvalR); |
| } |
| return null; //let dependency analysis fail |
| } |
| |
| private LinearFunction rParseBinaryExpression(Hop hop) { |
| org.apache.sysds.hops.BinaryOp bop = (org.apache.sysds.hops.BinaryOp) hop; |
| Hop l = bop.getInput().get(0); |
| Hop r = bop.getInput().get(1); |
| if( bop.getOp()==OpOp2.PLUS || bop.getOp()==OpOp2.MINUS ) { |
| boolean plus = bop.getOp() == OpOp2.PLUS; |
| //parse binary expressions |
| if( l instanceof org.apache.sysds.hops.BinaryOp) { |
| LinearFunction f = rParseBinaryExpression(l); |
| Long cvalR = parseLongConstant(r); |
| if( f != null && cvalR != null ) |
| return f.addConstant(cvalR * (plus?1:-1)); |
| } |
| else if (r instanceof org.apache.sysds.hops.BinaryOp) { |
| LinearFunction f = rParseBinaryExpression(r); |
| Long cvalL = parseLongConstant(l); |
| if( f != null && cvalL != null ) |
| return f.scale(plus?1:-1).addConstant(cvalL); |
| } |
| else { // atomic case |
| //change everything to plus if necessary |
| Long cvalL = parseLongConstant(l); |
| Long cvalR = parseLongConstant(r); |
| if( cvalL != null ) |
| return new LinearFunction(cvalL, plus?1:-1, r.getName() ); |
| else if( cvalR != null ) |
| return new LinearFunction(cvalR*(plus?1:-1), 1, l.getName()); |
| } |
| } |
| else if( bop.getOp() == OpOp2.MULT ) { |
| //atomic case (only recursion for MULT expressions, where one side is a constant) |
| Long cvalL = parseLongConstant(l); |
| Long cvalR = parseLongConstant(r); |
| if( cvalL != null && HopRewriteUtils.isData(r, OpOpData.TRANSIENTREAD) ) |
| return new LinearFunction(0, cvalL, r.getName()); |
| else if( cvalR != null && HopRewriteUtils.isData(l, OpOpData.TRANSIENTREAD) ) |
| return new LinearFunction(0, cvalR, l.getName()); |
| else if( cvalL != null && r instanceof org.apache.sysds.hops.BinaryOp ) |
| return rParseBinaryExpression(r).scale(cvalL); |
| else if( cvalR != null && l instanceof org.apache.sysds.hops.BinaryOp ) |
| return rParseBinaryExpression(l).scale(cvalR); |
| } |
| return null; //let dependency analysis fail |
| } |
| |
| private static Long parseLongConstant(Expression expr) { |
| if( expr instanceof IntIdentifier ) { |
| return ((IntIdentifier) expr).getValue(); |
| } |
| else if( expr instanceof DoubleIdentifier ) { |
| double tmp = ((DoubleIdentifier) expr).getValue(); |
| if( tmp == Math.floor(tmp) ) //ensure int |
| return UtilFunctions.toLong(tmp); |
| } |
| return null; |
| } |
| |
| private static Long parseLongConstant(Hop hop) { |
| if( hop instanceof LiteralOp && hop.getValueType()==ValueType.INT64 ) { |
| return HopRewriteUtils.getIntValue((LiteralOp)hop); |
| } |
| else if( hop instanceof LiteralOp && hop.getValueType()==ValueType.FP64 ) { |
| double tmp = HopRewriteUtils.getDoubleValue((LiteralOp)hop); |
| if( tmp == Math.floor(tmp) ) //ensure int |
| return UtilFunctions.toLong(tmp); |
| } |
| return null; |
| } |
| |
| public static class ResultVar { |
| public final String _name; |
| public final boolean _isAccum; |
| public ResultVar(String name, boolean accum) { |
| _name = name; |
| _isAccum = accum; |
| } |
| @Override |
| public boolean equals(Object that) { |
| String varname = (that instanceof ResultVar) ? |
| ((ResultVar)that)._name : that.toString(); |
| return _name.equals(varname); |
| } |
| @Override |
| public int hashCode() { |
| return _name.hashCode(); |
| } |
| @Override |
| public String toString() { |
| return _name; |
| } |
| public static boolean contains(Collection<ResultVar> list, String varName) { |
| //helper function which is necessary because list.contains checks |
| //varName.equals(rvar) which always returns false because it not a string |
| return list.stream().anyMatch(rvar -> rvar._name.equals(varName)); |
| } |
| } |
| |
| private static class Candidate { |
| private final String _var; // variable name |
| private final DataIdentifier _dat; // _var data identifier |
| private final boolean _isAccum; |
| public Candidate(String var, DataIdentifier di, boolean accum) { |
| _var = var; |
| _dat = di; |
| _isAccum = accum; |
| } |
| } |
| |
| /** |
| * Helper class for representing all lower, upper bounds of (potentially nested) |
| * loop constructs. |
| * |
| */ |
| private static class Bounds { |
| HashMap<String, Long> _lower = new HashMap<>(); |
| HashMap<String, Long> _upper = new HashMap<>(); |
| HashMap<String, Long> _increment = new HashMap<>(); |
| //contains all local variable names (subset of lower/upper/incr sets) |
| HashSet<String> _local = new HashSet<>(); |
| } |
| |
| /** |
| * Helper class for representing linear functions of matrix subscripts. |
| * The allowed form is 'y = a + b1x1 + ... = bnxn', which is required by |
| * the applied GCD and Banerjee tests. |
| * |
| */ |
| private class LinearFunction { |
| long _a; // intercept |
| long[] _b; // slopes |
| String[] _vars; // b variable names |
| |
| LinearFunction( long a, long b, String name ) { |
| _a = a; |
| _b = new long[1]; |
| _b[0] = b; |
| _vars = new String[1]; |
| _vars[0] = name; |
| } |
| |
| public LinearFunction addConstant(long value) { |
| _a += value; |
| return this; |
| } |
| |
| public LinearFunction addFunction( LinearFunction f2) { |
| _a = _a + f2._a; |
| long[] tmpb = new long[_b.length+f2._b.length]; |
| System.arraycopy( _b, 0, tmpb, 0, _b.length ); |
| System.arraycopy( f2._b, 0, tmpb, _b.length, f2._b.length ); |
| _b = tmpb; |
| String[] tmpvars = new String[_vars.length+f2._vars.length]; |
| System.arraycopy( _vars, 0, tmpvars, 0, _vars.length ); |
| System.arraycopy( f2._vars, 0, tmpvars, _vars.length, f2._vars.length ); |
| _vars = tmpvars; |
| return this; |
| } |
| |
| public LinearFunction removeVar( int i ) { |
| long[] tmpb = new long[_b.length-1]; |
| System.arraycopy( _b, 0, tmpb, 0, i ); |
| System.arraycopy( _b, i+1, tmpb, i, _b.length-i-1 ); |
| _b = tmpb; |
| String[] tmpvars = new String[_vars.length-1]; |
| System.arraycopy( _vars, 0, tmpvars, 0, i ); |
| System.arraycopy( _vars, i+1, tmpvars, i, _vars.length-i-1 ); |
| _vars = tmpvars; |
| return this; |
| } |
| |
| public LinearFunction scale( long scale ) { |
| _a *= scale; |
| for( int i=0; i<_b.length; i++ ) |
| _b[i] *= scale; |
| return this; |
| } |
| |
| public LinearFunction normalize(int index, long lower, long increment) { |
| _a -= (_b[index] * lower); |
| _b[index] *= increment; |
| return this; |
| } |
| |
| public long eval(Long... x) { |
| long ret = _a; |
| for( int i=0; i<_b.length; i++ ) |
| ret += _b[i] *= x[i]; |
| return ret; |
| } |
| |
| @Override |
| public String toString() { |
| StringBuilder sb = new StringBuilder(); |
| sb.append("("); |
| sb.append(_a); |
| sb.append(") + "); |
| sb.append("("); |
| for( int i=0; i<_b.length; i++ ) { |
| if( i>0 ) |
| sb.append("+"); |
| sb.append("("); |
| sb.append(_b[i]); |
| sb.append(" * "); |
| sb.append(_vars[i]); |
| sb.append(")"); |
| } |
| sb.append(")"); |
| return sb.toString(); |
| } |
| |
| @Override |
| public boolean equals( Object o2 ) { |
| if( o2 == null || !(o2 instanceof LinearFunction) ) |
| return false; |
| LinearFunction f2 = (LinearFunction)o2; |
| return ( _a == f2._a ) |
| && equalSlope(f2); |
| } |
| |
| public boolean equalSlope(LinearFunction f2) { |
| boolean ret = ( _b.length == f2._b.length ); |
| for( int i=0; i<_b.length && ret; i++ ) { |
| ret &= (_b[i] == f2._b[i] ); |
| //note robustness for null var names |
| String var1 = String.valueOf(_vars[i]); |
| String var2 = String.valueOf(f2._vars[i]); |
| ret &= (var1.equals(var2) |
| ||(var1.startsWith(INTERAL_FN_INDEX_ROW) && var2.startsWith(INTERAL_FN_INDEX_ROW)) |
| ||(var1.startsWith(INTERAL_FN_INDEX_COL) && var2.startsWith(INTERAL_FN_INDEX_COL))); |
| } |
| return ret; |
| } |
| |
| @Override |
| public int hashCode() { |
| return super.hashCode(); //identity |
| } |
| |
| public boolean hasNonIndexVariables() { |
| for( String var : _vars ) |
| if( var!=null && !_bounds._lower.containsKey(var) ) |
| return true; |
| return false; |
| } |
| } |
| } |