blob: f4fc86ce24f5728a3e83eefda30ec3acb5811dfa [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
/**
 * Statement-block rewrite that looks at for- and while-loop blocks and switches
 * off lineage caching ({@code setRequiresLineageCaching(false)}) on every hop in
 * the loop body that depends on a loop variable, either directly or through
 * variables written by loop-dependent DAGs.
 *
 * <p>The rule is a no-op for non-loop statement blocks and whenever
 * {@code LineageCacheConfig.ReuseCacheType.isNone()} holds. It never changes the
 * list of statement blocks; only flags on the hops are modified.
 */
public class MarkForLineageReuse extends StatementBlockRewriteRule
{
	/**
	 * @return always {@code false}
	 */
	@Override
	public boolean createsSplitDag() {
		return false;
	}

	/**
	 * Determines the loop variables of the given block and unmarks all hops in
	 * the loop body that depend on them.
	 *
	 * <ul>
	 * <li>For-loops: the loop variable is the iteration variable of the
	 *     iterable predicate.</li>
	 * <li>While-loops: the loop variables are those variables that are both
	 *     updated by the block and read by the conditional predicate.</li>
	 * </ul>
	 *
	 * @param sb     statement block to inspect
	 * @param status rewrite status (not used by this rule)
	 * @return a single-element list holding {@code sb} itself
	 */
	@Override
	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status)
	{
		// only loops are of interest, and only if lineage reuse is enabled
		boolean applicable = HopRewriteUtils.isLoopStatementBlock(sb)
			&& !LineageCacheConfig.ReuseCacheType.isNone();

		if (applicable) {
			if (sb instanceof ForStatementBlock) {
				ForStatement forLoop = (ForStatement) sb.getStatement(0);
				Set<String> iterVars = new HashSet<>();
				iterVars.add(forLoop.getIterablePredicate().getIterVar().getName());
				rUnmarkLoopDepVarsSB(forLoop.getBody(), new HashSet<String>(), iterVars);
			}
			if (sb instanceof WhileStatementBlock) {
				WhileStatement whileLoop = (WhileStatement) sb.getStatement(0);
				// loop variables = updated variables that the condition reads
				Set<String> condVars = new HashSet<>();
				for (String updated : sb.variablesUpdated().getVariableNames()) {
					if (whileLoop.getConditionalPredicate().variablesRead().containsVariable(updated))
						condVars.add(updated);
				}
				rUnmarkLoopDepVarsSB(whileLoop.getBody(), new HashSet<String>(), condVars);
			}
		}
		return Arrays.asList(sb);
	}

	/**
	 * Repeatedly walks over the given statement blocks and unmarks loop-dependent
	 * hops until the set of loop-dependent root names stops changing, or until
	 * as many passes as there are statement blocks have been made (at least one
	 * pass is always executed).
	 *
	 * <p>Nested control-flow bodies work on a per-pass snapshot of the roots;
	 * the snapshot is merged back into {@code deproots} at the end of each pass.
	 *
	 * @param sbs      statement blocks of a loop, branch or function body
	 * @param deproots names of variables written by loop-dependent DAGs (updated in place)
	 * @param loopVar  names of the loop variables
	 */
	private void rUnmarkLoopDepVarsSB(ArrayList<StatementBlock> sbs, HashSet<String> deproots, Set<String> loopVar)
	{
		HashSet<String> snapshot = new HashSet<>();
		int passes = 0;
		while (true) {
			// start every pass from the roots known so far
			snapshot.clear();
			snapshot.addAll(deproots);

			for (StatementBlock current : sbs)
				unmarkInStatementBlock(current, deproots, snapshot, loopVar);

			deproots.addAll(snapshot);
			passes++;

			// iterate to propagate the loop dependents across all the SBs
			boolean exhausted = passes >= sbs.size();
			if (exhausted || (!deproots.isEmpty() && deproots.equals(snapshot)))
				break;
		}
	}

	/**
	 * Handles one statement block of a pass: control-flow blocks are descended
	 * into with the pass snapshot, all other blocks are processed hop by hop
	 * against {@code deproots}.
	 *
	 * @param sb          statement block to process
	 * @param deproots    loop-dependent root names used for basic blocks
	 * @param nestedRoots per-pass snapshot handed to nested bodies
	 * @param loopVar     names of the loop variables
	 */
	private void unmarkInStatementBlock(StatementBlock sb, HashSet<String> deproots,
		HashSet<String> nestedRoots, Set<String> loopVar)
	{
		if (sb instanceof ForStatementBlock) {
			ForStatement inner = (ForStatement) sb.getStatement(0);
			rUnmarkLoopDepVarsSB(inner.getBody(), nestedRoots, loopVar);
			//TODO: nested loops.
			return;
		}
		if (sb instanceof WhileStatementBlock) {
			WhileStatement inner = (WhileStatement) sb.getStatement(0);
			rUnmarkLoopDepVarsSB(inner.getBody(), nestedRoots, loopVar);
			return;
		}
		if (sb instanceof IfStatementBlock) {
			IfStatement branch = (IfStatement) sb.getStatement(0);
			rUnmarkLoopDepVarsSB(branch.getIfBody(), nestedRoots, loopVar);
			if (branch.getElseBody() != null)
				rUnmarkLoopDepVarsSB(branch.getElseBody(), nestedRoots, loopVar);
			return;
		}
		if (sb instanceof FunctionStatementBlock) {
			FunctionStatement function = (FunctionStatement) sb.getStatement(0);
			rUnmarkLoopDepVarsSB(function.getBody(), nestedRoots, loopVar);
			return;
		}
		unmarkInBasicBlock(sb, deproots, loopVar);
	}

	/**
	 * Processes the hop DAGs of a block without nested bodies. The DAG roots are
	 * traversed in rounds, at most one round per updated variable, so that
	 * dependencies can propagate across all DAGs of the block. The rounds stop
	 * early once a round leaves a non-empty root set unchanged.
	 *
	 * @param sb       statement block whose hops are inspected (skipped if it has none)
	 * @param deproots loop-dependent root names (updated in place)
	 * @param loopVar  names of the loop variables
	 */
	private void unmarkInBasicBlock(StatementBlock sb, HashSet<String> deproots, Set<String> loopVar)
	{
		if (sb.getHops() == null)
			return;

		for (int round = 0; round < sb.variablesUpdated().getSize(); round++) {
			HashSet<String> grown = new HashSet<>(deproots);
			for (Hop root : sb.getHops()) {
				// find the loop dependent DAG roots; each root gets a fresh
				// traversal state and its own set of dependent hop ids
				Hop.resetVisitStatus(sb.getHops());
				rUnmarkLoopDepVars(root, loopVar, grown, new HashSet<Long>());
			}
			// stop once the loop dependent DAGs have converged to an unvarying set
			boolean converged = !deproots.isEmpty() && deproots.equals(grown);
			if (converged)
				break;
			// otherwise iterate to propagate the loop dependents across all the DAGs in this SB
			deproots.addAll(grown);
		}
	}

	/**
	 * Post-order traversal of a hop DAG that disables lineage caching on every
	 * hop which is loop dependent. A hop counts as loop dependent if its name is
	 * a loop variable, if it is a transient read of a name in {@code deproots},
	 * or if any of its inputs is loop dependent.
	 *
	 * <p>When a transient write is reached and at least one dependent hop has
	 * been recorded in {@code dephops}, the written name is added to
	 * {@code deproots}. Already visited hops are skipped; each processed hop is
	 * marked as visited.
	 *
	 * @param hop      current hop
	 * @param loopVar  names of the loop variables
	 * @param deproots names written by loop-dependent DAGs (updated in place)
	 * @param dephops  ids of the loop-dependent hops found so far (updated in place)
	 */
	private void rUnmarkLoopDepVars(Hop hop, Set<String> loopVar, HashSet<String> deproots, HashSet<Long> dephops)
	{
		if (hop.isVisited())
			return;

		// descend first; an input's dependency status is final once its call returns
		boolean dependsOnLoop = false;
		for (Hop input : hop.getInput()) {
			rUnmarkLoopDepVars(input, loopVar, deproots, dephops);
			if (dephops.contains(input.getHopID()))
				dependsOnLoop = true;
		}

		// unmark operation if this itself or any of its inputs are loop dependent
		final String name = hop.getName();
		if (!dependsOnLoop) {
			boolean readsDependentRoot = HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD)
				&& deproots.contains(name);
			dependsOnLoop = loopVar.contains(name) || readsDependentRoot;
		}
		if (dependsOnLoop) {
			dephops.add(hop.getHopID());
			hop.setRequiresLineageCaching(false);
			//TODO: extend all the hops to propagate till variablecp output
		}

		// TODO: logic to separate out partially reusable cases (e.g cbind-tsmm)
		boolean writesDependentResult = HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE)
			&& !dephops.isEmpty();
		if (writesDependentResult) {
			// copy to propagate across
			deproots.add(name);
		}

		hop.setVisited();
	}

	/**
	 * No list-level rewrite is performed.
	 *
	 * @param sbs    statement blocks
	 * @param status rewrite status (not used)
	 * @return the given list, unchanged
	 */
	@Override
	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus status) {
		return sbs;
	}
}