blob: 92c81b5ea17efa824525637477b14f5d2be0e7a3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
/**
 * Rule: Simplify program structure by merging sequences of last-level
 * statement blocks in order to create optimization opportunities.
 * <p>
 * Adjacent last-level blocks are merged pairwise (the first block is folded
 * into the second) until no further pair qualifies. Transient writes of the
 * first block that feed transient reads of the second block are replaced by
 * direct data dependencies, and common-subexpression elimination is applied
 * to the merged hop DAG.
 */
public class RewriteMergeBlockSequence extends StatementBlockRewriteRule
{
	//rewriter applied to every merged hop DAG (common-subexpression elimination)
	private final ProgramRewriter rewriter = new ProgramRewriter(
		new RewriteCommonSubexpressionElimination(true));

	@Override
	public boolean createsSplitDag() {
		return false;
	}

	@Override
	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb,
		ProgramRewriteStatus state) {
		//single blocks are returned unchanged; merging happens in rewriteStatementBlocks
		return Arrays.asList(sb);
	}

	/**
	 * Merges adjacent mergeable statement blocks until a fixpoint is reached.
	 *
	 * @param sbs   sequence of statement blocks, may be null or empty
	 * @param state rewrite status (not used by this rule)
	 * @return {@code sbs} itself if null or empty, otherwise a new list in which
	 *         every merged pair has been replaced by the second block of the pair
	 */
	@Override
	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs,
		ProgramRewriteStatus state)
	{
		if( sbs == null || sbs.isEmpty() )
			return sbs;

		//execute binary merging iterations until fixpoint; the scan restarts
		//after every merge because the merged block may enable further merges
		ArrayList<StatementBlock> tmpList = new ArrayList<>(sbs);
		boolean merged = true;
		while( merged ) {
			merged = false;
			for( int i=0; i<tmpList.size()-1; i++ ) {
				StatementBlock sb1 = tmpList.get(i);
				StatementBlock sb2 = tmpList.get(i+1);
				if( isMergeable(sb1, sb2) ) {
					mergeInto(sb1, sb2);
					//remove sb1 from list of statement blocks (sb2 now holds both)
					tmpList.remove(i);
					merged = true;
					break; //for
				}
			}
		}
		return tmpList;
	}

	/**
	 * Decides whether two adjacent statement blocks may be merged.
	 * Both must be last-level blocks that are not split DAGs. A function call
	 * root in either block must not write a variable that the other block
	 * reads or updates.
	 */
	private static boolean isMergeable(StatementBlock sb1, StatementBlock sb2) {
		return HopRewriteUtils.isLastLevelStatementBlock(sb1)
			&& HopRewriteUtils.isLastLevelStatementBlock(sb2)
			&& !sb1.isSplitDag() && !sb2.isSplitDag()
			&& (!hasFunctionOpRoot(sb1) || !hasFunctionIOConflict(sb1, sb2))
			&& (!hasFunctionOpRoot(sb2) || !hasFunctionIOConflict(sb2, sb1));
	}

	/**
	 * Merges the hop DAG and live-variable information of sb1 into sb2.
	 * Afterwards sb1 has an empty hop list, and the caller is expected to
	 * drop sb1 from the program.
	 */
	private void mergeInto(StatementBlock sb1, StatementBlock sb2) {
		//note: we intend to merge sb1 into sb2 to connect data dependencies
		//however, we work with a temporary list of root nodes to preserve
		//the original order of roots, which affects prints w/o dependencies
		ArrayList<Hop> sb1Hops = sb1.getHops();
		ArrayList<Hop> sb2Hops = sb2.getHops();
		ArrayList<Hop> newHops = new ArrayList<>();

		//determine transient reads (all of them, per variable) and written variables of s2
		Hop.resetVisitStatus(sb2Hops);
		HashMap<String, List<Hop>> treads = new HashMap<>();
		HashSet<String> twrites = new HashSet<>();
		for( Hop root : sb2Hops )
			rCollectTransientReadWrites(root, treads, twrites);
		Hop.resetVisitStatus(sb2Hops);

		//merge hop dags of s1 and s2
		Hop.resetVisitStatus(sb1Hops);
		for( Hop root : sb1Hops ) {
			boolean isTWrite = HopRewriteUtils.isData(root, OpOpData.TRANSIENTWRITE);
			String name = root.getName();
			if( isTWrite && treads.containsKey(name) ) {
				//rewire every transient read of s2 to the input of the transient write of s1;
				//all of them must be rewired because the transient write itself is dropped
				Hop in = root.getInput().get(0);
				for( Hop tread : treads.get(name) )
					for( Hop parent : new ArrayList<>(tread.getParent()) )
						HopRewriteUtils.replaceChildReference(parent, tread, in);
				HopRewriteUtils.removeAllChildReferences(root);
				//re-create the transient write if s2 does not overwrite the
				//variable and it is still live after s2
				if( !twrites.contains(name) && sb2.liveOut().containsVariable(name) )
					newHops.add(HopRewriteUtils.createDataOp(name, in, OpOpData.TRANSIENTWRITE));
			}
			//keep remaining roots of s1, except transient writes that are
			//overwritten by s2 or not live after s2
			else if( !(isTWrite && (twrites.contains(name) || !sb2.liveOut().containsVariable(name))) ) {
				newHops.add(root);
			}
		}

		//clear partial hops from the merged statement block to avoid problems with
		//other statement block rewrites that iterate over the original program
		sb1Hops.clear();
		//append all root nodes of s2 after root nodes of s1
		newHops.addAll(sb2Hops);
		sb2.setHops(newHops);

		//run common-subexpression elimination
		Hop.resetVisitStatus(sb2.getHops());
		rewriter.rewriteHopDAG(sb2.getHops(), new ProgramRewriteStatus());

		//modify live variable sets of s2 (liveOut remains unchanged)
		sb2.setLiveIn(sb1.liveIn());
		//gen(s1;s2) = gen(s1) u (gen(s2) - kill(s1)); kill(s1) must not be subtracted
		//from gen(s1), since s1 may read a variable before it updates it
		sb2.setGen(VariableSet.union(sb1.getGen(), VariableSet.minus(sb2.getGen(), sb1.getKill())));
		sb2.setKill(VariableSet.union(sb1.getKill(), sb2.getKill()));
		sb2.setReadVariables(VariableSet.union(sb1.variablesRead(), sb2.variablesRead()));
		sb2.setUpdatedVariables(VariableSet.union(sb1.variablesUpdated(), sb2.variablesUpdated()));

		//log before line numbers are modified so both original ranges are reported
		LOG.debug("Applied mergeStatementBlockSequences "
			+ "(blocks of lines "+sb1.getBeginLine()+"-"+sb1.getEndLine()
			+" and "+sb2.getBeginLine()+"-"+sb2.getEndLine()+").");

		//modify line numbers of s2
		sb2.setBeginLine(sb1.getBeginLine());
		sb2.setBeginColumn(sb1.getBeginColumn());
	}

	/**
	 * Recursively collects, in post-order, all transient reads of the DAG
	 * (grouped by variable name). It also collects the names of all variables
	 * the DAG writes, whether via transient writes or as function call outputs.
	 * Hop visit status ensures shared sub-DAGs are processed only once.
	 */
	private static void rCollectTransientReadWrites(Hop current,
		HashMap<String, List<Hop>> treads, HashSet<String> twrites)
	{
		if( current.isVisited() )
			return;
		//process nodes recursively
		for( Hop c : current.getInput() )
			rCollectTransientReadWrites(c, treads, twrites);
		//collect all transient reads and written variable names
		if( HopRewriteUtils.isData(current, OpOpData.TRANSIENTREAD) ) {
			List<Hop> reads = treads.get(current.getName());
			if( reads == null ) {
				reads = new ArrayList<>();
				treads.put(current.getName(), reads);
			}
			reads.add(current);
		}
		else if( HopRewriteUtils.isData(current, OpOpData.TRANSIENTWRITE) )
			twrites.add(current.getName());
		else if( current instanceof FunctionOp )
			twrites.addAll(Arrays.asList(((FunctionOp)current).getOutputVariableNames()));
		current.setVisited();
	}

	/** Returns true if any root hop of the given block is a function call. */
	private static boolean hasFunctionOpRoot(StatementBlock sb) {
		if( sb == null || sb.getHops() == null )
			return false;
		for( Hop root : sb.getHops() )
			if( root instanceof FunctionOp )
				return true;
		return false;
	}

	/**
	 * Checks whether a function op root in sb1 conflicts with sb2.
	 * A conflict exists if the function op writes a variable that sb2 reads or
	 * updates. The write in sb2 may be a traditional transient write or another
	 * function op.
	 */
	private static boolean hasFunctionIOConflict(StatementBlock sb1, StatementBlock sb2)
	{
		//collect all function output variables of sb1
		HashSet<String> outSb1 = new HashSet<>();
		for( Hop root : sb1.getHops() )
			if( root instanceof FunctionOp )
				outSb1.addAll(Arrays.asList(((FunctionOp)root).getOutputVariableNames()));
		//check all output variables against read/updated sets
		return sb2.variablesRead().containsAnyName(outSb1)
			|| sb2.variablesUpdated().containsAnyName(outSb1);
	}
}