blob: de5b4feacc7f16ef8b11e9b12ff3f06c197b0f78 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
/**
* Rule: Constant Folding. For all statement blocks,
* eliminate simple unary and binary expressions of literals within DAGs by
* computing them once and replacing them with a new literal op.
* For the moment, this only applies within a DAG; later this should be
* extended across statement blocks (global, inter-procedural).
*/
public class RewriteConstantFolding extends HopRewriteRule
{
private static final String TMP_VARNAME = "__cf_tmp";
//reuse basic execution runtime
private BasicProgramBlock _tmpPB = null;
private ExecutionContext _tmpEC = null;
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
if( roots == null )
return null;
for( int i=0; i<roots.size(); i++ ) {
Hop h = roots.get(i);
roots.set(i, rule_ConstantFolding(h));
}
return roots;
}
@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
if( root == null )
return null;
return rule_ConstantFolding(root);
}
private Hop rule_ConstantFolding( Hop hop ) {
return rConstantFoldingExpression(hop);
}
private Hop rConstantFoldingExpression( Hop root ) {
if( root.isVisited() )
return root;
//recursively process childs (before replacement to allow bottom-recursion)
//no iterator in order to prevent concurrent modification
for( int i=0; i<root.getInput().size(); i++ ) {
Hop h = root.getInput().get(i);
rConstantFoldingExpression(h);
}
LiteralOp literal = null;
//fold binary op if both are literals / unary op if literal
if( root.getDataType() == DataType.SCALAR //scalar output
&& ( isApplicableBinaryOp(root) || isApplicableUnaryOp(root) ) )
{
literal = evalScalarOperation(root);
}
//fold conjunctive predicate if at least one input is literal 'false'
else if( isApplicableFalseConjunctivePredicate(root) ) {
literal = new LiteralOp(false);
}
//fold disjunctive predicate if at least one input is literal 'true'
else if( isApplicableTrueDisjunctivePredicate(root) ) {
literal = new LiteralOp(true);
}
//replace binary operator with folded constant
if( literal != null ) {
//bottom-up replacement to keep common subexpression elimination
if( !root.getParent().isEmpty() ) { //broot is NOT a DAG root
List<Hop> parents = new ArrayList<>(root.getParent());
for( Hop parent : parents )
HopRewriteUtils.replaceChildReference(parent, root, literal);
}
else { //broot IS a DAG root
root = literal;
}
}
//mark processed
root.setVisited();
return root;
}
/**
* In order to (1) prevent unexpected side effects from constant folding and
* (2) for simplicity with regard to arbitrary value type combinations,
* we use the same compilation and runtime for constant folding as we would
* use for actual instruction execution.
*
* @param bop high-level operator
* @return literal op
*/
private LiteralOp evalScalarOperation( Hop bop )
{
//Timing time = new Timing( true );
DataOp tmpWrite = new DataOp(TMP_VARNAME, bop.getDataType(),
bop.getValueType(), bop, OpOpData.TRANSIENTWRITE, TMP_VARNAME);
//generate runtime instruction
Dag<Lop> dag = new Dag<>();
Recompiler.rClearLops(tmpWrite); //prevent lops reuse
Lop lops = tmpWrite.constructLops(); //reconstruct lops
lops.addToDag( dag );
ArrayList<Instruction> inst = dag.getJobs(null, ConfigurationManager.getDMLConfig());
//execute instructions
ExecutionContext ec = getExecutionContext();
BasicProgramBlock pb = getProgramBlock();
pb.setInstructions( inst );
pb.execute( ec );
//get scalar result (check before invocation) and create literal according
//to observed scalar output type (not hop type) for runtime consistency
ScalarObject so = (ScalarObject) ec.getVariable(TMP_VARNAME);
LiteralOp literal = ScalarObjectFactory.createLiteralOp(so);
//cleanup
tmpWrite.getInput().clear();
bop.getParent().remove(tmpWrite);
pb.setInstructions(null);
ec.getVariables().removeAll();
//set literal properties (scalar)
HopRewriteUtils.setOutputParametersForScalar(literal);
//System.out.println("Constant folded in "+time.stop()+"ms.");
return literal;
}
private BasicProgramBlock getProgramBlock() {
if( _tmpPB == null )
_tmpPB = new BasicProgramBlock(new Program());
return _tmpPB;
}
private ExecutionContext getExecutionContext() {
if( _tmpEC == null )
_tmpEC = ExecutionContextFactory.createContext();
return _tmpEC;
}
private static boolean isApplicableBinaryOp( Hop hop )
{
ArrayList<Hop> in = hop.getInput();
return ( hop instanceof BinaryOp
&& in.get(0) instanceof LiteralOp
&& in.get(1) instanceof LiteralOp
&& ((BinaryOp)hop).getOp()!=OpOp2.CBIND
&& ((BinaryOp)hop).getOp()!=OpOp2.RBIND);
//string append is rejected although possible because it
//messes up the explain runtime output due to introduced \n
}
private static boolean isApplicableUnaryOp( Hop hop ) {
ArrayList<Hop> in = hop.getInput();
return ( hop instanceof UnaryOp
&& in.get(0) instanceof LiteralOp
&& ((UnaryOp)hop).getOp() != OpOp1.EXISTS
&& ((UnaryOp)hop).getOp() != OpOp1.PRINT
&& ((UnaryOp)hop).getOp() != OpOp1.ASSERT
&& ((UnaryOp)hop).getOp() != OpOp1.STOP
&& hop.getDataType() == DataType.SCALAR);
}
private static boolean isApplicableFalseConjunctivePredicate( Hop hop ) {
ArrayList<Hop> in = hop.getInput();
return ( HopRewriteUtils.isBinary(hop, OpOp2.AND) && hop.getDataType().isScalar()
&& ( (in.get(0) instanceof LiteralOp && !((LiteralOp)in.get(0)).getBooleanValue())
||(in.get(1) instanceof LiteralOp && !((LiteralOp)in.get(1)).getBooleanValue())) );
}
private static boolean isApplicableTrueDisjunctivePredicate( Hop hop ) {
ArrayList<Hop> in = hop.getInput();
return ( HopRewriteUtils.isBinary(hop, OpOp2.OR) && hop.getDataType().isScalar()
&& ( (in.get(0) instanceof LiteralOp && ((LiteralOp)in.get(0)).getBooleanValue())
||(in.get(1) instanceof LiteralOp && ((LiteralOp)in.get(1)).getBooleanValue())) );
}
}