blob: f387657606ce61143512d0d9766a7f8616c9fe0a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.controlprogram.parfor.opt;
import java.util.ArrayList;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.SpoofCompiler.IntegrationType;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
public class ProgramRecompiler
{
public static ArrayList<ProgramBlock> generatePartitialRuntimeProgram(Program rtprog, ArrayList<StatementBlock> sbs)
{
ArrayList<ProgramBlock> ret = new ArrayList<>();
DMLConfig config = ConfigurationManager.getDMLConfig();
//construct lops from hops if not existing
DMLTranslator dmlt = new DMLTranslator(sbs.get(0).getDMLProg());
for( StatementBlock sb : sbs ) {
dmlt.constructLops(sb);
}
//construct runtime program from lops
for( StatementBlock sb : sbs ) {
ret.add(dmlt.createRuntimeProgramBlock(rtprog, sb, config));
}
//enhance runtime program by automatic operator fusion
if( ConfigurationManager.isCodegenEnabled()
&& SpoofCompiler.INTEGRATION==IntegrationType.RUNTIME ) {
for( ProgramBlock pb : ret )
dmlt.codgenHopsDAG(pb);
}
return ret;
}
/**
* NOTE: if force is set, we set and recompile the respective indexing hops;
* otherwise, we release the forced exec type and recompile again. Hence,
* any changes can be exactly reverted with the same access behavior.
*
* @param sb statement block
* @param pb program block
* @param var variable
* @param ec execution context
* @param force if true, set and recompile the respective indexing hops
*/
public static void rFindAndRecompileIndexingHOP( StatementBlock sb, ProgramBlock pb, String var, ExecutionContext ec, boolean force )
{
if( pb instanceof IfProgramBlock && sb instanceof IfStatementBlock ) {
IfProgramBlock ipb = (IfProgramBlock) pb;
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement is = (IfStatement) sb.getStatement(0);
//process if condition
if( isb.getPredicateHops()!=null )
ipb.setPredicate( rFindAndRecompileIndexingHOP(isb.getPredicateHops(),ipb.getPredicate(),var,ec,force) );
//process if branch
int len = is.getIfBody().size();
for( int i=0; i<ipb.getChildBlocksIfBody().size() && i<len; i++ ) {
ProgramBlock lpb = ipb.getChildBlocksIfBody().get(i);
StatementBlock lsb = is.getIfBody().get(i);
rFindAndRecompileIndexingHOP(lsb,lpb,var,ec,force);
}
//process else branch
if( ipb.getChildBlocksElseBody() != null ) {
int len2 = is.getElseBody().size();
for( int i=0; i<ipb.getChildBlocksElseBody().size() && i<len2; i++ ) {
ProgramBlock lpb = ipb.getChildBlocksElseBody().get(i);
StatementBlock lsb = is.getElseBody().get(i);
rFindAndRecompileIndexingHOP(lsb,lpb,var,ec,force);
}
}
}
else if( pb instanceof WhileProgramBlock && sb instanceof WhileStatementBlock ) {
WhileProgramBlock wpb = (WhileProgramBlock) pb;
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement ws = (WhileStatement) sb.getStatement(0);
//process while condition
if( wsb.getPredicateHops()!=null )
wpb.setPredicate( rFindAndRecompileIndexingHOP(wsb.getPredicateHops(),wpb.getPredicate(),var,ec,force) );
//process body
int len = ws.getBody().size(); //robustness for potentially added problem blocks
for( int i=0; i<wpb.getChildBlocks().size() && i<len; i++ )
{
ProgramBlock lpb = wpb.getChildBlocks().get(i);
StatementBlock lsb = ws.getBody().get(i);
rFindAndRecompileIndexingHOP(lsb,lpb,var,ec, force);
}
}
else if( pb instanceof ForProgramBlock && sb instanceof ForStatementBlock ) { //for or parfor
ForProgramBlock fpb = (ForProgramBlock) pb;
ForStatementBlock fsb = (ForStatementBlock)sb;
ForStatement fs = (ForStatement) fsb.getStatement(0);
if( fsb.getFromHops()!=null )
fpb.setFromInstructions( rFindAndRecompileIndexingHOP(fsb.getFromHops(),fpb.getFromInstructions(),var,ec,force) );
if( fsb.getToHops()!=null )
fpb.setToInstructions( rFindAndRecompileIndexingHOP(fsb.getToHops(),fpb.getToInstructions(),var,ec,force) );
if( fsb.getIncrementHops()!=null )
fpb.setIncrementInstructions( rFindAndRecompileIndexingHOP(fsb.getIncrementHops(),fpb.getIncrementInstructions(),var,ec,force) );
//process body
int len = fs.getBody().size(); //robustness for potentially added problem blocks
for( int i=0; i<fpb.getChildBlocks().size() && i<len; i++ ) {
ProgramBlock lpb = fpb.getChildBlocks().get(i);
StatementBlock lsb = fs.getBody().get(i);
rFindAndRecompileIndexingHOP(lsb,lpb,var,ec, force);
}
}
else if( pb instanceof BasicProgramBlock ) {
BasicProgramBlock bpb = (BasicProgramBlock) pb;
try
{
//process actual hops
boolean ret = false;
Hop.resetVisitStatus(sb.getHops());
if( force ) {
//set forced execution type
for( Hop h : sb.getHops() )
ret |= rFindAndSetCPIndexingHOP(h, var);
}
else {
//release forced execution type
for( Hop h : sb.getHops() )
ret |= rFindAndReleaseIndexingHOP(h, var);
}
//recompilation on-demand
if( ret ) {
//construct new instructions
ArrayList<Instruction> newInst = Recompiler.recompileHopsDag(
sb, sb.getHops(), ec.getVariables(), null, true, false, 0);
bpb.setInstructions( newInst );
}
}
catch(Exception ex)
{
throw new DMLRuntimeException(ex);
}
}
}
public static LocalVariableMap getReusableScalarVariables( DMLProgram prog, StatementBlock parforSB, LocalVariableMap vars ) {
LocalVariableMap constVars = new LocalVariableMap();
for( String varname : vars.keySet() )
{
Data dat = vars.get(varname);
if( dat instanceof ScalarObject //scalar
&& isApplicableForReuseVariable(prog, parforSB, varname) ) //constant
{
constVars.put(varname, dat);
}
}
return constVars;
}
public static void replaceConstantScalarVariables( StatementBlock sb, LocalVariableMap vars )
{
if( sb instanceof IfStatementBlock )
{
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement is = (IfStatement) sb.getStatement(0);
replacePredicateLiterals(isb.getPredicateHops(), vars);
for( StatementBlock lsb : is.getIfBody() )
replaceConstantScalarVariables(lsb, vars);
for( StatementBlock lsb : is.getElseBody() )
replaceConstantScalarVariables(lsb, vars);
}
else if( sb instanceof WhileStatementBlock )
{
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement ws = (WhileStatement) sb.getStatement(0);
replacePredicateLiterals(wsb.getPredicateHops(), vars);
for( StatementBlock lsb : ws.getBody() )
replaceConstantScalarVariables(lsb, vars);
}
else if( sb instanceof ForStatementBlock ) //for or parfor
{
ForStatementBlock fsb = (ForStatementBlock)sb;
ForStatement fs = (ForStatement) fsb.getStatement(0);
replacePredicateLiterals(fsb.getFromHops(), vars);
replacePredicateLiterals(fsb.getToHops(), vars);
replacePredicateLiterals(fsb.getIncrementHops(), vars);
for( StatementBlock lsb : fs.getBody() )
replaceConstantScalarVariables(lsb, vars);
}
else //last level block
{
ArrayList<Hop> hops = sb.getHops();
if( hops != null )
{
//replace constant literals
Hop.resetVisitStatus(hops);
for( Hop hopRoot : hops )
Recompiler.rReplaceLiterals( hopRoot, vars, true );
}
}
}
private static void replacePredicateLiterals( Hop pred, LocalVariableMap vars ) {
if( pred != null ){
pred.resetVisitStatus();
Recompiler.rReplaceLiterals(pred, vars, true);
}
}
/**
* This function determines if an parfor input variable is guaranteed to be read-only
* across multiple invocations of parfor optimization (e.g., in a surrounding while loop).
* In case of invariant variables we can reuse partitioned matrices and propagate constants
* for better size estimation.
*
* @param prog dml program
* @param parforSB parfor statement block
* @param var variable
* @return true if can reuse variable
*/
public static boolean isApplicableForReuseVariable( DMLProgram prog, StatementBlock parforSB, String var ) {
boolean ret = false;
for( StatementBlock sb : prog.getStatementBlocks() )
ret |= isApplicableForReuseVariable(sb, parforSB, var);
return ret;
}
private static boolean isApplicableForReuseVariable( StatementBlock sb, StatementBlock parforSB, String var ) {
boolean ret = false;
if( sb instanceof IfStatementBlock ) {
IfStatement is = (IfStatement) sb.getStatement(0);
for( StatementBlock lsb : is.getIfBody() )
ret |= isApplicableForReuseVariable(lsb, parforSB, var);
for( StatementBlock lsb : is.getElseBody() )
ret |= isApplicableForReuseVariable(lsb, parforSB, var);
}
else if( sb instanceof WhileStatementBlock ) {
WhileStatement ws = (WhileStatement) sb.getStatement(0);
for( StatementBlock lsb : ws.getBody() )
ret |= isApplicableForReuseVariable(lsb, parforSB, var);
}
else if( sb instanceof ForStatementBlock ) { //for or parfor
ForStatementBlock fsb = (ForStatementBlock)sb;
ForStatement fs = (ForStatement) fsb.getStatement(0);
if( fsb == parforSB ) {
//found parfor statement
ret = true;
}
else {
for( StatementBlock lsb : fs.getBody() )
ret |= isApplicableForReuseVariable(lsb, parforSB, var);
}
}
return ret && !sb.variablesUpdated().containsVariable(var);
}
public static boolean containsAtLeastOneFunction( ProgramBlock pb )
{
if( pb instanceof IfProgramBlock ) {
IfProgramBlock ipb = (IfProgramBlock) pb;
for( ProgramBlock lpb : ipb.getChildBlocksIfBody() )
if( containsAtLeastOneFunction(lpb) )
return true;
for( ProgramBlock lpb : ipb.getChildBlocksElseBody() )
if( containsAtLeastOneFunction(lpb) )
return true;
}
else if( pb instanceof WhileProgramBlock ) {
WhileProgramBlock wpb = (WhileProgramBlock) pb;
for( ProgramBlock lpb : wpb.getChildBlocks() )
if( containsAtLeastOneFunction(lpb) )
return true;
}
else if( pb instanceof ForProgramBlock ) { //incl parfor
ForProgramBlock fpb = (ForProgramBlock) pb;
for( ProgramBlock lpb : fpb.getChildBlocks() )
if( containsAtLeastOneFunction(lpb) )
return true;
}
else if( pb instanceof BasicProgramBlock ) {
BasicProgramBlock bpb = (BasicProgramBlock) pb;
if( bpb.getInstructions() != null )
for( Instruction inst : bpb.getInstructions() )
if( inst instanceof FunctionCallCPInstruction )
return true;
}
return false;
}
private static ArrayList<Instruction> rFindAndRecompileIndexingHOP( Hop hop, ArrayList<Instruction> in, String var, ExecutionContext ec, boolean force )
{
ArrayList<Instruction> tmp = in;
try
{
boolean ret = false;
hop.resetVisitStatus();
if( force ) //set forced execution type
ret = rFindAndSetCPIndexingHOP(hop, var);
else //release forced execution type
ret = rFindAndReleaseIndexingHOP(hop, var);
//recompilation on-demand
if( ret )
{
//construct new instructions
tmp = Recompiler.recompileHopsDag(
hop, ec.getVariables(), null, true, false, 0);
}
}
catch(Exception ex)
{
throw new DMLRuntimeException(ex);
}
return tmp;
}
private static boolean rFindAndSetCPIndexingHOP(Hop hop, String var)
{
boolean ret = false;
if( hop.isVisited() )
return ret;
ArrayList<Hop> in = hop.getInput();
if( hop instanceof IndexingOp )
{
String inMatrix = hop.getInput().get(0).getName();
if( inMatrix.equals(var) )
{
//NOTE: mem estimate of RIX, set to output size by parfor optmizer
//(rowblock/colblock only applied if in total less than two blocks,
// hence always mem_est<mem_budget)
if( hop.getMemEstimate() < OptimizerUtils.getLocalMemBudget() )
hop.setForcedExecType( LopProperties.ExecType.CP );
else
hop.setForcedExecType( LopProperties.ExecType.CP_FILE );
ret = true;
}
}
//recursive search
if( in != null )
for( Hop hin : in )
ret |= rFindAndSetCPIndexingHOP(hin,var);
hop.setVisited();
return ret;
}
private static boolean rFindAndReleaseIndexingHOP(Hop hop, String var)
{
boolean ret = false;
if( hop.isVisited() )
return ret;
ArrayList<Hop> in = hop.getInput();
if( hop instanceof IndexingOp )
{
String inMatrix = hop.getInput().get(0).getName();
if( inMatrix.equals(var) )
{
hop.setForcedExecType(null);
hop.clearMemEstimate();
ret = true;
}
}
//recursive search
if( in != null )
for( Hop hin : in )
ret |= rFindAndReleaseIndexingHOP(hin,var);
hop.setVisited();
return ret;
}
///////
// additional general-purpose functionalities
protected static ArrayList<Instruction> createNestedParallelismToInstructionSet(String iterVar, String offset) {
//create instruction string
StringBuilder sb = new StringBuilder("CP"+Lop.OPERAND_DELIMITOR+"+"+Lop.OPERAND_DELIMITOR);
sb.append(iterVar);
sb.append(Lop.DATATYPE_PREFIX+"SCALAR"+Lop.VALUETYPE_PREFIX+"INT"+Lop.OPERAND_DELIMITOR);
sb.append(offset);
sb.append(Lop.DATATYPE_PREFIX+"SCALAR"+Lop.VALUETYPE_PREFIX+"INT"+Lop.OPERAND_DELIMITOR);
sb.append(iterVar);
sb.append(Lop.DATATYPE_PREFIX+"SCALAR"+Lop.VALUETYPE_PREFIX+"INT");
String str = sb.toString();
//create instruction set
ArrayList<Instruction> tmp = new ArrayList<>();
Instruction inst = BinaryCPInstruction.parseInstruction(str);
tmp.add(inst);
return tmp;
}
}