blob: f1d9cddb3d982519a2ca013afde2d5f0044e51a2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
/**
 * Rule: Simplify program structure by hoisting loop-invariant operations
 * out of while, for, or parfor loops.
 * <p>
 * An operation in a last-level statement block of the loop body is tagged
 * loop-invariant if it is a computed operation (not a literal, not a transient
 * read/write, not a non-deterministic data generator, and - unless configured
 * otherwise - not a function call) whose inputs are all literals, variables
 * that are read but never updated inside the loop, or other loop-invariant
 * operations. Each invariant sub-DAG consumed by a non-invariant operation is
 * copied into a new statement block placed in front of the loop, written to a
 * temporary variable there, and replaced in the loop body by a transient read
 * of that temporary.
 * <p>
 * Operations inside if-statement blocks and loop predicates are not hoisted.
 * Note that hoisted operations are executed once even if the loop body would
 * be executed zero times.
 */
public class RewriteHoistLoopInvariantOperations extends StatementBlockRewriteRule
{
	/** If true, function calls are treated as side-effect free and may be hoisted. */
	private final boolean _sideEffectFreeFuns;

	/**
	 * Creates the rewrite with function calls treated as potentially
	 * side-effecting, i.e., function calls are never hoisted.
	 */
	public RewriteHoistLoopInvariantOperations() {
		this(false);
	}

	/**
	 * Creates the rewrite.
	 *
	 * @param noSideEffects if true, function calls whose inputs are all
	 *        loop-invariant are considered hoistable as well
	 */
	public RewriteHoistLoopInvariantOperations(boolean noSideEffects) {
		_sideEffectFreeFuns = noSideEffects;
	}

	/**
	 * Returns true because this rule may prepend a new statement block
	 * (holding the hoisted operations) in front of the rewritten loop.
	 */
	@Override
	public boolean createsSplitDag() {
		return true;
	}

	/**
	 * Hoists loop-invariant operations out of the given loop statement block.
	 * The loop block is modified in place (invariant sub-DAGs in its body are
	 * replaced by transient reads, and its live-in set is extended by the new
	 * temporaries).
	 *
	 * @param sb the statement block to rewrite (non-loop blocks are returned as is)
	 * @param state the program rewrite status (unused)
	 * @return {@code [sb]} if nothing was hoisted, otherwise {@code [hoisted, sb]}
	 */
	@Override
	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
		//early abort if possible
		if( sb == null || !HopRewriteUtils.isLoopStatementBlock(sb) )
			return Arrays.asList(sb); //rewrite only applies to loops

		//step 1: determine read-only variables (read but never updated in the loop)
		Set<String> candInputs = sb.variablesRead().getVariableNames().stream()
			.filter(v -> !sb.variablesUpdated().containsVariable(v))
			.collect(Collectors.toSet());

		//step 2: collect loop-invariant operations along with their tmp names
		Map<String, Hop> invariantOps = new HashMap<>();
		collectOperations(sb, candInputs, invariantOps);

		//step 3: create new statement block for all temporary intermediates
		return invariantOps.isEmpty() ? Arrays.asList(sb) :
			Arrays.asList(createStatementBlock(sb, invariantOps), sb);
	}

	/**
	 * This rule operates on individual statement blocks only; lists are
	 * returned unchanged.
	 */
	@Override
	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
		return sbs;
	}

	/**
	 * Recursively walks the loop body and, for every last-level statement
	 * block, tags loop-invariant hops and replaces them by transient reads
	 * of fresh temporaries.
	 *
	 * @param sb current statement block
	 * @param candInputs names of variables that are read-only within the loop
	 * @param invariantOps output map from temporary name to copied invariant sub-DAG
	 */
	private void collectOperations(StatementBlock sb, Set<String> candInputs, Map<String, Hop> invariantOps) {
		if( sb instanceof WhileStatementBlock ) {
			WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
			for( StatementBlock csb : wstmt.getBody() )
				collectOperations(csb, candInputs, invariantOps);
		}
		else if( sb instanceof ForStatementBlock ) {
			ForStatement fstmt = (ForStatement) sb.getStatement(0);
			for( StatementBlock csb : fstmt.getBody() )
				collectOperations(csb, candInputs, invariantOps);
		}
		else if( sb instanceof IfStatementBlock ) {
			//note: for now we do not pull loop-invariant code out of
			//if statement blocks because these operations are conditionally
			//executed, so unconditional execution might be counter productive
		}
		else if( sb.getHops() != null ) {
			//step a: bottom-up flagging of loop-invariant operations
			//(these are defined operations whose inputs are read only
			//variables or other loop-invariant operations)
			Hop.resetVisitStatus(sb.getHops());
			Set<Long> memo = new HashSet<>();
			for( Hop hop : sb.getHops() )
				rTagLoopInvariantOperations(hop, candInputs, memo);

			//step b: copy hop sub dag and replace it via tread
			//(count actually hoisted sub-DAGs, not all tagged hops)
			int numBefore = invariantOps.size();
			Hop.resetVisitStatus(sb.getHops());
			for( Hop hop : sb.getHops() )
				rCollectAndReplaceOperations(hop, candInputs, memo, invariantOps);
			int numHoisted = invariantOps.size() - numBefore;

			if( numHoisted > 0 ) {
				LOG.debug("Applied hoistLoopInvariantOperations (lines "
					+sb.getBeginLine()+"-"+sb.getEndLine()+"): "+numHoisted+".");
			}
		}
	}

	/**
	 * Depth-first tagging of loop-invariant hops: a hop is added to
	 * {@code memo} if it is itself eligible and all of its inputs are
	 * literals, reads of read-only variables, or already-tagged hops.
	 *
	 * @param hop current hop
	 * @param candInputs names of variables that are read-only within the loop
	 * @param memo output set of hop IDs tagged as loop-invariant
	 */
	private void rTagLoopInvariantOperations(Hop hop, Set<String> candInputs, Set<Long> memo) {
		if( hop.isVisited() )
			return;

		//process inputs first (depth first)
		for( Hop c : hop.getInput() )
			rTagLoopInvariantOperations(c, candInputs, memo);

		//flag operation if all inputs are loop invariant; literals are
		//constants (handled via the input check below) and must not be
		//hoisted into temporaries themselves, and non-deterministic data
		//generators must be re-evaluated in every iteration
		boolean invariant = !(hop instanceof LiteralOp)
			&& !isNonDeterministicDataGen(hop)
			&& (!(hop instanceof FunctionOp) || _sideEffectFreeFuns)
			&& !HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD)
			&& !HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE);
		for( Hop c : hop.getInput() ) {
			//NOTE(review): the name check presumes only reads of read-only
			//variables carry such names - confirm against hop naming in DMLTranslator
			invariant &= (candInputs.contains(c.getName())
				|| memo.contains(c.getHopID()) || c instanceof LiteralOp);
		}
		if( invariant )
			memo.add(hop.getHopID());

		hop.setVisited();
	}

	/**
	 * Indicates whether the given hop is a data generator whose result may
	 * differ between evaluations (random generation, sampling, current time).
	 *
	 * @param hop the hop to check
	 * @return true if the hop must not be hoisted out of a loop
	 */
	private static boolean isNonDeterministicDataGen(Hop hop) {
		return HopRewriteUtils.isDataGenOp(hop, OpOpDG.RAND)
			|| HopRewriteUtils.isDataGenOp(hop, OpOpDG.SAMPLE)
			|| HopRewriteUtils.isDataGenOp(hop, OpOpDG.TIME);
	}

	/**
	 * Top-down replacement: every input tagged in {@code memo} is deep-copied
	 * into {@code invariantOps} under a fresh temporary name and, in all of
	 * its parents, replaced by a transient read of that temporary. Untagged
	 * inputs are processed recursively.
	 *
	 * @param hop current hop
	 * @param candInputs names of variables that are read-only within the loop (unused here)
	 * @param memo hop IDs tagged as loop-invariant
	 * @param invariantOps output map from temporary name to copied invariant sub-DAG
	 */
	private void rCollectAndReplaceOperations(Hop hop, Set<String> candInputs, Set<Long> memo, Map<String, Hop> invariantOps) {
		if( hop.isVisited() )
			return;

		//replace amenable inputs or process recursively
		//(without iterators due to parent-child modifications)
		for( int i=0; i<hop.getInput().size(); i++ ) {
			Hop c = hop.getInput().get(i);
			if( memo.contains(c.getHopID()) ) {
				String tmpName = createCutVarName(false);
				Hop tmp = Recompiler.deepCopyHopsDag(c);
				tmp.getParent().clear();
				invariantOps.put(tmpName, tmp);

				//create read and replace all parent references
				DataOp tread = HopRewriteUtils.createTransientRead(tmpName, c);
				List<Hop> parents = new ArrayList<>(c.getParent());
				for( Hop p : parents )
					HopRewriteUtils.replaceChildReference(p, c, tread);
			}
			else {
				rCollectAndReplaceOperations(c, candInputs, memo, invariantOps);
			}
		}

		hop.setVisited();
	}

	/**
	 * Creates the statement block that evaluates all hoisted operations in
	 * front of the loop and updates live variable information of both blocks.
	 *
	 * @param sb the loop statement block (its live-in set is extended in place)
	 * @param invariantOps map from temporary name to hoisted sub-DAG
	 * @return new last-level statement block with one transient write per temporary
	 */
	private static StatementBlock createStatementBlock(StatementBlock sb, Map<String, Hop> invariantOps) {
		//create empty last-level statement block
		StatementBlock ret = new StatementBlock();
		ret.setDMLProg(sb.getDMLProg());
		ret.setParseInfo(sb);
		ret.setLiveIn(new VariableSet(sb.liveIn()));
		//the new block directly precedes the loop, so its live-out is the
		//loop's live-in (extended by the temporaries below)
		ret.setLiveOut(new VariableSet(sb.liveIn()));

		//append one transient write per hoisted sub-DAG
		ArrayList<Hop> hops = new ArrayList<>();
		for( Entry<String, Hop> e : invariantOps.entrySet() ) {
			Hop h = e.getValue();
			DataOp twrite = HopRewriteUtils.createTransientWrite(e.getKey(), h);
			hops.add(twrite);
			//update live variable analysis
			DataIdentifier diVar = new DataIdentifier(e.getKey());
			diVar.setDimensions(h.getDim1(), h.getDim2());
			diVar.setBlocksize(h.getBlocksize());
			diVar.setDataType(h.getDataType());
			diVar.setValueType(h.getValueType());
			ret.liveOut().addVariable(e.getKey(), diVar);
			sb.liveIn().addVariable(e.getKey(), diVar);
		}
		ret.setHops(hops);
		return ret;
	}
}