blob: e60d9f111f433eb14f1050ce51ef321a906e804b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.Hop.VisitStatus;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.compile.Dag;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
/**
* Rule: Constant Folding. For all statement blocks,
* eliminate simple binary expressions of literals within dags by
* computing them and replacing them with a new Literal op once.
* For the moment, this only applies within a dag, later this should be
* extended across statements block (global, inter-procedure).
*/
public class RewriteConstantFolding extends HopRewriteRule
{
private static final String TMP_VARNAME = "__cf_tmp";
//reuse basic execution runtime
private static ProgramBlock _tmpPB = null;
private static ExecutionContext _tmpEC = null;
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
throws HopsException
{
if( roots == null )
return null;
for( int i=0; i<roots.size(); i++ )
{
Hop h = roots.get(i);
roots.set(i, rule_ConstantFolding(h));
}
return roots;
}
@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
throws HopsException
{
if( root == null )
return null;
return rule_ConstantFolding(root);
}
/**
*
* @param hop
* @throws HopsException
*/
private Hop rule_ConstantFolding( Hop hop )
throws HopsException
{
return rConstantFoldingExpression(hop);
}
/**
*
* @param root
* @throws HopsException
*/
private Hop rConstantFoldingExpression( Hop root )
throws HopsException
{
if( root.getVisited() == VisitStatus.DONE )
return root;
//recursively process childs (before replacement to allow bottom-recursion)
//no iterator in order to prevent concurrent modification
for( int i=0; i<root.getInput().size(); i++ )
{
Hop h = root.getInput().get(i);
rConstantFoldingExpression(h);
}
LiteralOp literal = null;
//fold binary op if both are literals / unary op if literal
if( root.getDataType() == DataType.SCALAR //scalar ouput
&& ( isApplicableBinaryOp(root) || isApplicableUnaryOp(root) ) )
{
//core constant folding via runtime instructions
try {
literal = evalScalarOperation(root);
}
catch(Exception ex)
{
LOG.error("Failed to execute constant folding instructions. No abort.", ex);
}
}
//fold conjunctive predicate if at least one input is literal 'false'
else if( isApplicableFalseConjunctivePredicate(root) )
{
literal = new LiteralOp(false);
}
//fold disjunctive predicate if at least one input is literal 'true'
else if( isApplicableTrueDisjunctivePredicate(root) )
{
literal = new LiteralOp(true);
}
//replace binary operator with folded constant
if( literal != null )
{
//reverse replacement in order to keep common subexpression elimination
int plen = root.getParent().size();
if( plen > 0 ) //broot is NOT a DAG root
{
for( int i=0; i<root.getParent().size(); i++ ) //for all parents
{
Hop parent = root.getParent().get(i);
for( int j=0; j<parent.getInput().size(); j++ )
{
Hop child = parent.getInput().get(j);
if( root == child )
{
//replace operator
//root to parent link cannot be removed within this loop, as loop iterates over list containing parents.
parent.getInput().remove(j);
HopRewriteUtils.addChildReference(parent, literal,j);
}
}
}
root.getParent().clear();
}
else //broot IS a DAG root
{
root = literal;
}
}
//mark processed
root.setVisited( VisitStatus.DONE );
return root;
}
/**
* In order to (1) prevent unexpected side effects from constant folding and
* (2) for simplicity with regard to arbitrary value type combinations,
* we use the same compilation and runtime for constant folding as we would
* use for actual instruction execution.
*
* @return
* @throws IOException
* @throws LopsException
* @throws DMLRuntimeException
* @throws HopsException
*/
private LiteralOp evalScalarOperation( Hop bop )
throws LopsException, DMLRuntimeException, IOException, HopsException
{
//Timing time = new Timing( true );
DataOp tmpWrite = new DataOp(TMP_VARNAME, bop.getDataType(), bop.getValueType(), bop, DataOpTypes.TRANSIENTWRITE, TMP_VARNAME);
//generate runtime instruction
Dag<Lop> dag = new Dag<Lop>();
Recompiler.rClearLops(tmpWrite); //prevent lops reuse
Lop lops = tmpWrite.constructLops(); //reconstruct lops
lops.addToDag( dag );
ArrayList<Instruction> inst = dag.getJobs(null, ConfigurationManager.getDMLConfig());
//execute instructions
ExecutionContext ec = getExecutionContext();
ProgramBlock pb = getProgramBlock();
pb.setInstructions( inst );
pb.execute( ec );
//get scalar result (check before invocation) and create literal according
//to observed scalar output type (not hop type) for runtime consistency
ScalarObject so = (ScalarObject) ec.getVariable(TMP_VARNAME);
LiteralOp literal = null;
switch( so.getValueType() ){
case DOUBLE: literal = new LiteralOp(so.getDoubleValue()); break;
case INT: literal = new LiteralOp(so.getLongValue()); break;
case BOOLEAN: literal = new LiteralOp(so.getBooleanValue()); break;
case STRING: literal = new LiteralOp(so.getStringValue()); break;
default:
throw new HopsException("Unsupported literal value type: "+bop.getValueType());
}
//cleanup
tmpWrite.getInput().clear();
bop.getParent().remove(tmpWrite);
pb.setInstructions(null);
ec.getVariables().removeAll();
//set literal properties (scalar)
literal.setDim1(0);
literal.setDim2(0);
literal.setRowsInBlock(-1);
literal.setColsInBlock(-1);
//System.out.println("Constant folded in "+time.stop()+"ms.");
return literal;
}
/**
*
* @return
* @throws DMLRuntimeException
*/
private static ProgramBlock getProgramBlock()
throws DMLRuntimeException
{
if( _tmpPB == null )
_tmpPB = new ProgramBlock( new Program() );
return _tmpPB;
}
/**
*
* @return
*/
private static ExecutionContext getExecutionContext()
{
if( _tmpEC == null )
_tmpEC = ExecutionContextFactory.createContext();
return _tmpEC;
}
/**
*
* @param hop
* @return
*/
private boolean isApplicableBinaryOp( Hop hop )
{
ArrayList<Hop> in = hop.getInput();
return ( hop instanceof BinaryOp
&& in.get(0) instanceof LiteralOp
&& in.get(1) instanceof LiteralOp
&& ((BinaryOp)hop).getOp()!=OpOp2.CBIND
&& ((BinaryOp)hop).getOp()!=OpOp2.RBIND);
//string append is rejected although possible because it
//messes up the explain runtime output due to introduced \n
}
/**
*
* @param hop
* @return
*/
private boolean isApplicableUnaryOp( Hop hop )
{
ArrayList<Hop> in = hop.getInput();
return ( hop instanceof UnaryOp
&& in.get(0) instanceof LiteralOp
&& ((UnaryOp)hop).getOp() != OpOp1.PRINT
&& ((UnaryOp)hop).getOp() != OpOp1.STOP
&& hop.getDataType() == DataType.SCALAR);
}
/**
*
* @param hop
* @return
* @throws HopsException
*/
private boolean isApplicableFalseConjunctivePredicate( Hop hop )
throws HopsException
{
ArrayList<Hop> in = hop.getInput();
return ( hop instanceof BinaryOp
&& ((BinaryOp)hop).getOp()==OpOp2.AND
&& ( (in.get(0) instanceof LiteralOp && !((LiteralOp)in.get(0)).getBooleanValue())
||(in.get(1) instanceof LiteralOp && !((LiteralOp)in.get(1)).getBooleanValue())) );
}
/**
*
* @param hop
* @return
* @throws HopsException
*/
private boolean isApplicableTrueDisjunctivePredicate( Hop hop )
throws HopsException
{
ArrayList<Hop> in = hop.getInput();
return ( hop instanceof BinaryOp
&& ((BinaryOp)hop).getOp()==OpOp2.OR
&& ( (in.get(0) instanceof LiteralOp && ((LiteralOp)in.get(0)).getBooleanValue())
||(in.get(1) instanceof LiteralOp && ((LiteralOp)in.get(1)).getBooleanValue())) );
}
}