blob: 2cbf1228a9472af09dd0ed5f43efb2eebcb9d437 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.controlprogram.parfor.opt;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Set;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.IfProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
public class OptTreePlanChecker
{
public static void checkProgramCorrectness( ProgramBlock pb, StatementBlock sb, Set<String> fnStack )
{
Program prog = pb.getProgram();
DMLProgram dprog = sb.getDMLProg();
if (pb instanceof FunctionProgramBlock && sb instanceof FunctionStatementBlock ) {
FunctionProgramBlock fpb = (FunctionProgramBlock)pb;
FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
for( int i=0; i<fpb.getChildBlocks().size(); i++ ) {
ProgramBlock pbc = fpb.getChildBlocks().get(i);
StatementBlock sbc = fstmt.getBody().get(i);
checkProgramCorrectness(pbc, sbc, fnStack);
}
}
else if (pb instanceof WhileProgramBlock && sb instanceof WhileStatementBlock) {
WhileProgramBlock wpb = (WhileProgramBlock) pb;
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
checkHopDagCorrectness(prog, dprog, wsb.getPredicateHops(), wpb.getPredicate(), fnStack);
for( int i=0; i<wpb.getChildBlocks().size(); i++ ) {
ProgramBlock pbc = wpb.getChildBlocks().get(i);
StatementBlock sbc = wstmt.getBody().get(i);
checkProgramCorrectness(pbc, sbc, fnStack);
}
checkLinksProgramStatementBlock(wpb, wsb);
}
else if (pb instanceof IfProgramBlock && sb instanceof IfStatementBlock) {
IfProgramBlock ipb = (IfProgramBlock) pb;
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement) isb.getStatement(0);
checkHopDagCorrectness(prog, dprog, isb.getPredicateHops(), ipb.getPredicate(), fnStack);
for( int i=0; i<ipb.getChildBlocksIfBody().size(); i++ ) {
ProgramBlock pbc = ipb.getChildBlocksIfBody().get(i);
StatementBlock sbc = istmt.getIfBody().get(i);
checkProgramCorrectness(pbc, sbc, fnStack);
}
for( int i=0; i<ipb.getChildBlocksElseBody().size(); i++ ) {
ProgramBlock pbc = ipb.getChildBlocksElseBody().get(i);
StatementBlock sbc = istmt.getElseBody().get(i);
checkProgramCorrectness(pbc, sbc, fnStack);
}
checkLinksProgramStatementBlock(ipb, isb);
}
else if (pb instanceof ForProgramBlock && sb instanceof ForStatementBlock) { //incl parfor
ForProgramBlock fpb = (ForProgramBlock) pb;
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement) sb.getStatement(0);
checkHopDagCorrectness(prog, dprog, fsb.getFromHops(), fpb.getFromInstructions(), fnStack);
checkHopDagCorrectness(prog, dprog, fsb.getToHops(), fpb.getToInstructions(), fnStack);
checkHopDagCorrectness(prog, dprog, fsb.getIncrementHops(), fpb.getIncrementInstructions(), fnStack);
for( int i=0; i<fpb.getChildBlocks().size(); i++ ) {
ProgramBlock pbc = fpb.getChildBlocks().get(i);
StatementBlock sbc = fstmt.getBody().get(i);
checkProgramCorrectness(pbc, sbc, fnStack);
}
checkLinksProgramStatementBlock(fpb, fsb);
}
else if( pb instanceof BasicProgramBlock ) {
BasicProgramBlock bpb = (BasicProgramBlock) pb;
checkHopDagCorrectness(prog, dprog, sb.getHops(), bpb.getInstructions(), fnStack);
}
}
private static void checkHopDagCorrectness( Program prog, DMLProgram dprog, ArrayList<Hop> roots, ArrayList<Instruction> inst, Set<String> fnStack ) {
if( roots != null )
for( Hop hop : roots )
checkHopDagCorrectness(prog, dprog, hop, inst, fnStack);
}
private static void checkHopDagCorrectness( Program prog, DMLProgram dprog, Hop root, ArrayList<Instruction> inst, Set<String> fnStack ) {
//set of checks to perform
checkFunctionNames(prog, dprog, root, inst, fnStack);
}
private static void checkLinksProgramStatementBlock( ProgramBlock pb, StatementBlock sb ) {
if( pb.getStatementBlock() != sb )
throw new DMLRuntimeException("Links between programblocks and statementblocks are incorrect ("+pb+").");
}
private static void checkFunctionNames( Program prog, DMLProgram dprog, Hop root, ArrayList<Instruction> inst, Set<String> fnStack ) {
//reset visit status of dag
root.resetVisitStatus();
//get all function op in this dag
HashMap<String, FunctionOp> fops = new HashMap<>();
getAllFunctionOps(root, fops);
for( Instruction linst : inst )
if( linst instanceof FunctionCallCPInstruction )
{
FunctionCallCPInstruction flinst = (FunctionCallCPInstruction) linst;
String fnamespace = flinst.getNamespace();
String fname = flinst.getFunctionName();
String key = DMLProgram.constructFunctionKey(fnamespace, fname);
//check 1: instruction name equal to hop name
if( !fops.containsKey(key) )
throw new DMLRuntimeException( "Function Check: instruction and hop names differ ("+key+", "+fops.keySet()+")" );
//check 2: function exists
if( !prog.getFunctionProgramBlocks().containsKey(key) )
throw new DMLRuntimeException( "Function Check: function does not exits ("+key+")" );
//check 3: recursive program check
FunctionProgramBlock fpb = prog.getFunctionProgramBlock(fnamespace, fname);
FunctionStatementBlock fsb = dprog.getFunctionStatementBlock(fnamespace, fname);
if( !fnStack.contains(key) )
{
fnStack.add(key);
checkProgramCorrectness(fpb, fsb, fnStack);
fnStack.remove(key);
}
}
}
private static void getAllFunctionOps( Hop hop, HashMap<String, FunctionOp> memo )
{
if( hop.isVisited() )
return;
//process functionop
if( hop instanceof FunctionOp ) {
FunctionOp fop = (FunctionOp) hop;
memo.put(fop.getFunctionKey(), fop);
}
//process children
for( Hop in : hop.getInput() )
getAllFunctionOps(in, memo);
hop.setVisited();
}
}