blob: 07b6fca0c7ab1db48a40dea0904186c2935debfc [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.ipa;
import java.util.List;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.FunctionOp.FunctionType;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.controlprogram.Program;
/**
* This rewrite applies static hop dag and statement block
* rewrites such as constant folding and branch removal
* in order to simplify statistic propagation.
*
*/
public class IPAPassReplaceEvalFunctionCalls extends IPAPass
{
@Override
public boolean isApplicable(FunctionCallGraph fgraph) {
return fgraph.containsSecondOrderCall()
&& OptimizerUtils.ALLOW_EVAL_FCALL_REPLACEMENT;
}
@Override
public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
//note: we only replace eval calls that feed into a twrite/pwrite call
//(i.e., after statement-block rewrites for splitting after val have been
//applied) - this approach ensures that the requirements of fcalls are met
// for each namespace, handle function statement blocks
boolean ret = false;
for (String namespaceKey : prog.getNamespaces().keySet())
for (String fname : prog.getFunctionStatementBlocks(namespaceKey).keySet()) {
FunctionStatementBlock fsblock = prog.getFunctionStatementBlock(namespaceKey,fname);
ret |= rewriteStatementBlock(prog, fsblock, fgraph);
}
// handle regular statement blocks in "main" method
for(StatementBlock sb : prog.getStatementBlocks())
ret |= rewriteStatementBlock(prog, sb, fgraph);
return ret;
}
private static boolean rewriteStatementBlock(DMLProgram prog, StatementBlock sb, FunctionCallGraph fgraph) {
boolean ret = false;
if (sb instanceof FunctionStatementBlock) {
FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
for (StatementBlock csb : fstmt.getBody())
ret |= rewriteStatementBlock(prog, csb, fgraph);
}
else if (sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
for (StatementBlock csb : wstmt.getBody())
ret |= rewriteStatementBlock(prog, csb, fgraph);
}
else if (sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement)isb.getStatement(0);
for (StatementBlock csb : istmt.getIfBody())
ret |= rewriteStatementBlock(prog, csb, fgraph);
for (StatementBlock csb : istmt.getElseBody())
ret |= rewriteStatementBlock(prog, csb, fgraph);
}
else if (sb instanceof ForStatementBlock) { //incl parfor
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement)fsb.getStatement(0);
for (StatementBlock csb : fstmt.getBody())
ret |= rewriteStatementBlock(prog, csb, fgraph);
}
else { //generic (last-level)
ret |= checkAndReplaceEvalFunctionCall(prog, sb, fgraph);
}
return ret;
}
private static boolean checkAndReplaceEvalFunctionCall(DMLProgram prog, StatementBlock sb, FunctionCallGraph fgraph) {
if( sb.getHops() == null )
return false;
List<Hop> roots = sb.getHops();
boolean ret = false;
for( int i=0; i<roots.size(); i++ ) {
Hop root = roots.get(i);
if( HopRewriteUtils.isData(root, OpOpData.TRANSIENTWRITE, OpOpData.PERSISTENTWRITE)
&& HopRewriteUtils.isNary(root.getInput(0), OpOpN.EVAL)
&& root.getInput(0).getInput(0) instanceof LiteralOp //constant name
&& root.getInput(0).getParent().size() == 1)
{
Hop eval = root.getInput(0);
String outvar = ((DataOp)root).getName();
//get function name and namespace
String fname = ((LiteralOp)eval.getInput(0)).getStringValue();
String fnamespace = prog.getDefaultFunctionDictionary().containsFunction(fname) ?
DMLProgram.DEFAULT_NAMESPACE : DMLProgram.BUILTIN_NAMESPACE;
if( fname.contains(Program.KEY_DELIM) ) {
String[] fparts = DMLProgram.splitFunctionKey(fname);
fnamespace = fparts[0];
fname = fparts[1];
}
fname = fnamespace.equals(DMLProgram.BUILTIN_NAMESPACE) ?
Builtins.getInternalFName(fname, eval.getInput(1).getDataType()) : fname;
//obtain functions and abort if inputs passed via list or output not a matrix
FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fnamespace, fname);
FunctionStatement fstmt = fsb!=null ? (FunctionStatement)fsb.getStatement(0) : null;
if( eval.getInput().size() > 1 && eval.getInput(1).getDataType().isList()
&& (fstmt==null || !fstmt.getInputParams().get(0).getDataType().isList())) {
LOG.warn("IPA: eval("+fnamespace+"::"+fname+") "
+ "applicable for replacement, but list inputs not yet supported.");
continue;
}
if( eval.getDataType().isList() ) {
LOG.warn("IPA: eval("+fnamespace+"::"+fname+") "
+ "applicable for replacement, but list output not yet supported.");
continue;
}
if( fstmt.getOutputParams().size() != 1 || !fstmt.getOutputParams().get(0).getDataType().isMatrix() ) {
LOG.warn("IPA: eval("+fnamespace+"::"+fname+") "
+ "applicable for replacement, but function output is not a matrix.");
continue;
}
//construct direct function call
FunctionOp fop = new FunctionOp(FunctionType.DML, fnamespace, fname,
fstmt.getInputParamNames(), eval.getInput().subList(1, eval.getInput().size()),
new String[]{outvar}, true);
HopRewriteUtils.copyLineNumbers(eval, fop);
HopRewriteUtils.removeAllChildReferences(eval);
roots.set(i, fop); //replaced
ret = true;
}
}
return ret;
}
}