| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysml.hops.globalopt.gdfgraph; |
| |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.Map.Entry; |
| import java.util.Set; |
| |
| import org.apache.sysml.hops.DataOp; |
| import org.apache.sysml.hops.Hop; |
| import org.apache.sysml.hops.Hop.DataOpTypes; |
| import org.apache.sysml.hops.globalopt.Summary; |
| import org.apache.sysml.hops.HopsException; |
| import org.apache.sysml.parser.ForStatementBlock; |
| import org.apache.sysml.parser.IfStatementBlock; |
| import org.apache.sysml.parser.StatementBlock; |
| import org.apache.sysml.parser.WhileStatementBlock; |
| import org.apache.sysml.runtime.DMLRuntimeException; |
| import org.apache.sysml.runtime.controlprogram.ForProgramBlock; |
| import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; |
| import org.apache.sysml.runtime.controlprogram.IfProgramBlock; |
| import org.apache.sysml.runtime.controlprogram.Program; |
| import org.apache.sysml.runtime.controlprogram.ProgramBlock; |
| import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; |
| import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; |
| import org.apache.sysml.utils.Explain; |
| |
| /** |
| * GENERAL 'GDF GRAPH' STRUCTURE, by MB: |
| * 1) Each hop is represented by an GDFNode |
| * 2) Each loop is represented by a structured GDFLoopNode |
| * 3) Transient Read/Write connections are represented via CrossBlockNodes, |
| * a) type PLAIN: single input crossblocknode represents unconditional data flow |
| * b) type MERGE: two inputs crossblocknode represent conditional data flow merge |
| * |
| * In detail, the graph builder essentially does a single pass over the entire program |
| * and constructs the global data flow graph bottom up. We create crossblocknodes for |
| * every transient write, loop nodes for for/while programblocks, and crossblocknodes |
| * after every if programblock. |
| * |
| */ |
| public class GraphBuilder |
| { |
| |
| private static final boolean IGNORE_UNBOUND_UPDATED_VARS = true; |
| |
| /** |
| * |
| * @param prog |
| * @return |
| * @throws DMLRuntimeException |
| * @throws HopsException |
| */ |
| public static GDFGraph constructGlobalDataFlowGraph( Program prog, Summary summary ) |
| throws DMLRuntimeException, HopsException |
| { |
| Timing time = new Timing(true); |
| |
| HashMap<String, GDFNode> roots = new HashMap<String, GDFNode>(); |
| for( ProgramBlock pb : prog.getProgramBlocks() ) |
| constructGDFGraph( pb, roots ); |
| |
| //create GDF graph root nodes |
| ArrayList<GDFNode> ret = new ArrayList<GDFNode>(); |
| for( GDFNode root : roots.values() ) |
| if( !(root instanceof GDFCrossBlockNode) ) |
| ret.add(root); |
| |
| //create GDF graph |
| GDFGraph graph = new GDFGraph(prog, ret); |
| |
| summary.setTimeGDFGraph(time.stop()); |
| return graph; |
| } |
| |
| /** |
| * |
| * @param pb |
| * @param roots |
| * @throws DMLRuntimeException |
| * @throws HopsException |
| */ |
| @SuppressWarnings("unchecked") |
| private static void constructGDFGraph( ProgramBlock pb, HashMap<String, GDFNode> roots ) |
| throws DMLRuntimeException, HopsException |
| { |
| if (pb instanceof FunctionProgramBlock ) |
| { |
| throw new DMLRuntimeException("FunctionProgramBlocks not implemented yet."); |
| } |
| else if (pb instanceof WhileProgramBlock) |
| { |
| WhileProgramBlock wpb = (WhileProgramBlock) pb; |
| WhileStatementBlock wsb = (WhileStatementBlock) pb.getStatementBlock(); |
| //construct predicate node (conceptually sequence of from/to/incr) |
| GDFNode pred = constructGDFGraph(wsb.getPredicateHops(), wpb, new HashMap<Long, GDFNode>(), roots); |
| HashMap<String,GDFNode> inputs = constructLoopInputNodes(wpb, wsb, roots); |
| HashMap<String,GDFNode> lroots = (HashMap<String, GDFNode>) inputs.clone(); |
| //process childs blocks |
| for( ProgramBlock pbc : wpb.getChildBlocks() ) |
| constructGDFGraph(pbc, lroots); |
| HashMap<String,GDFNode> outputs = constructLoopOutputNodes(wsb, lroots); |
| GDFLoopNode lnode = new GDFLoopNode(wpb, pred, inputs, outputs ); |
| //construct crossblock nodes |
| constructLoopOutputCrossBlockNodes(wsb, lnode, outputs, roots, wpb); |
| } |
| else if (pb instanceof IfProgramBlock) |
| { |
| IfProgramBlock ipb = (IfProgramBlock) pb; |
| IfStatementBlock isb = (IfStatementBlock) pb.getStatementBlock(); |
| //construct predicate |
| if( isb.getPredicateHops()!=null ) { |
| Hop pred = isb.getPredicateHops(); |
| roots.put(pred.getName(), constructGDFGraph(pred, ipb, new HashMap<Long,GDFNode>(), roots)); |
| } |
| //construct if and else branch separately |
| HashMap<String,GDFNode> ifRoots = (HashMap<String, GDFNode>) roots.clone(); |
| HashMap<String,GDFNode> elseRoots = (HashMap<String, GDFNode>) roots.clone(); |
| for( ProgramBlock pbc : ipb.getChildBlocksIfBody() ) |
| constructGDFGraph(pbc, ifRoots); |
| if( ipb.getChildBlocksElseBody()!=null ) |
| for( ProgramBlock pbc : ipb.getChildBlocksElseBody() ) |
| constructGDFGraph(pbc, elseRoots); |
| //merge data flow roots (if no else, elseRoots refer to original roots) |
| reconcileMergeIfProgramBlockOutputs(ifRoots, elseRoots, roots, ipb); |
| } |
| else if (pb instanceof ForProgramBlock) //incl parfor |
| { |
| ForProgramBlock fpb = (ForProgramBlock) pb; |
| ForStatementBlock fsb = (ForStatementBlock)pb.getStatementBlock(); |
| //construct predicate node (conceptually sequence of from/to/incr) |
| GDFNode pred = constructForPredicateNode(fpb, fsb, roots); |
| HashMap<String,GDFNode> inputs = constructLoopInputNodes(fpb, fsb, roots); |
| HashMap<String,GDFNode> lroots = (HashMap<String, GDFNode>) inputs.clone(); |
| //process childs blocks |
| for( ProgramBlock pbc : fpb.getChildBlocks() ) |
| constructGDFGraph(pbc, lroots); |
| HashMap<String,GDFNode> outputs = constructLoopOutputNodes(fsb, lroots); |
| GDFLoopNode lnode = new GDFLoopNode(fpb, pred, inputs, outputs ); |
| //construct crossblock nodes |
| constructLoopOutputCrossBlockNodes(fsb, lnode, outputs, roots, fpb); |
| } |
| else //last-level program block |
| { |
| StatementBlock sb = pb.getStatementBlock(); |
| ArrayList<Hop> hops = sb.get_hops(); |
| if( hops != null ) |
| { |
| //create new local memo structure for local dag |
| HashMap<Long, GDFNode> lmemo = new HashMap<Long, GDFNode>(); |
| for( Hop hop : hops ) |
| { |
| //recursively construct GDF graph for hop dag root |
| GDFNode root = constructGDFGraph(hop, pb, lmemo, roots); |
| if( root == null ) |
| throw new HopsException( "GDFGraphBuilder: failed to constuct dag root for: "+Explain.explain(hop) ); |
| |
| //create cross block nodes for all transient writes |
| if( hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.TRANSIENTWRITE ) |
| root = new GDFCrossBlockNode(hop, pb, root, hop.getName()); |
| |
| //add GDF root node to global roots |
| roots.put(hop.getName(), root); |
| } |
| } |
| |
| } |
| } |
| |
| /** |
| * |
| * @param hop |
| * @param pb |
| * @param lmemo |
| * @param roots |
| * @return |
| */ |
| private static GDFNode constructGDFGraph( Hop hop, ProgramBlock pb, HashMap<Long, GDFNode> lmemo, HashMap<String, GDFNode> roots ) |
| { |
| if( lmemo.containsKey(hop.getHopID()) ) |
| return lmemo.get(hop.getHopID()); |
| |
| //process childs recursively first |
| ArrayList<GDFNode> inputs = new ArrayList<GDFNode>(); |
| for( Hop c : hop.getInput() ) |
| inputs.add( constructGDFGraph(c, pb, lmemo, roots) ); |
| |
| //connect transient reads to existing roots of data flow graph |
| if( hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.TRANSIENTREAD ){ |
| inputs.add(roots.get(hop.getName())); |
| } |
| |
| //add current hop |
| GDFNode gnode = new GDFNode(hop, pb, inputs); |
| |
| //add GDF node of updated variables to global roots (necessary for loops, where updated local |
| //variables might never be bound to their logical variables names |
| if( !IGNORE_UNBOUND_UPDATED_VARS ) { |
| //NOTE: currently disabled because unnecessary, if no transientwrite by definition included in other transientwrite |
| if( pb.getStatementBlock()!=null && pb.getStatementBlock().variablesUpdated().containsVariable(hop.getName()) ) { |
| roots.put(hop.getName(), gnode); |
| } |
| } |
| |
| //memoize current node |
| lmemo.put(hop.getHopID(), gnode); |
| |
| return gnode; |
| } |
| |
| /** |
| * |
| * @param fpb |
| * @param fsb |
| * @param roots |
| * @return |
| */ |
| private static GDFNode constructForPredicateNode(ForProgramBlock fpb, ForStatementBlock fsb, HashMap<String, GDFNode> roots) |
| { |
| HashMap<Long, GDFNode> memo = new HashMap<Long, GDFNode>(); |
| GDFNode from = (fsb.getFromHops()!=null)? constructGDFGraph(fsb.getFromHops(), fpb, memo, roots) : null; |
| GDFNode to = (fsb.getToHops()!=null)? constructGDFGraph(fsb.getToHops(), fpb, memo, roots) : null; |
| GDFNode incr = (fsb.getIncrementHops()!=null)? constructGDFGraph(fsb.getIncrementHops(), fpb, memo, roots) : null; |
| ArrayList<GDFNode> inputs = new ArrayList<GDFNode>(); |
| inputs.add(from); |
| inputs.add(to); |
| inputs.add(incr); |
| //TODO for predicates |
| GDFNode pred = new GDFNode(null, fpb, inputs ); |
| |
| return pred; |
| } |
| |
| /** |
| * |
| * @param fpb |
| * @param fsb |
| * @param roots |
| * @return |
| * @throws DMLRuntimeException |
| */ |
| private static HashMap<String, GDFNode> constructLoopInputNodes( ProgramBlock fpb, StatementBlock fsb, HashMap<String, GDFNode> roots ) |
| throws DMLRuntimeException |
| { |
| HashMap<String, GDFNode> ret = new HashMap<String, GDFNode>(); |
| Set<String> invars = fsb.variablesRead().getVariableNames(); |
| for( String var : invars ) { |
| if( fsb.liveIn().containsVariable(var) ) { |
| GDFNode node = roots.get(var); |
| if( node == null ) |
| throw new DMLRuntimeException("GDFGraphBuilder: Non-existing input node for variable: "+var); |
| ret.put(var, node); |
| } |
| } |
| |
| return ret; |
| } |
| |
| private static HashMap<String, GDFNode> constructLoopOutputNodes( StatementBlock fsb, HashMap<String, GDFNode> roots ) |
| throws HopsException |
| { |
| HashMap<String, GDFNode> ret = new HashMap<String, GDFNode>(); |
| |
| Set<String> outvars = fsb.variablesUpdated().getVariableNames(); |
| for( String var : outvars ) |
| { |
| GDFNode node = roots.get(var); |
| |
| //handle non-existing nodes |
| if( node == null ) { |
| if( !IGNORE_UNBOUND_UPDATED_VARS ) |
| throw new HopsException( "GDFGraphBuilder: failed to constuct loop output for variable: "+var ); |
| else |
| continue; //skip unbound updated variables |
| } |
| |
| //add existing node to loop outputs |
| ret.put(var, node); |
| } |
| |
| return ret; |
| } |
| |
| /** |
| * |
| * @param ifRoots |
| * @param elseRoots |
| * @param roots |
| * @param pb |
| */ |
| private static void reconcileMergeIfProgramBlockOutputs( HashMap<String, GDFNode> ifRoots, HashMap<String, GDFNode> elseRoots, HashMap<String, GDFNode> roots, IfProgramBlock pb ) |
| { |
| //merge same variable names, different data |
| //( incl add new vars from if branch if node2==null) |
| for( Entry<String, GDFNode> e : ifRoots.entrySet() ){ |
| GDFNode node1 = e.getValue(); |
| GDFNode node2 = elseRoots.get(e.getKey()); //original or new |
| if( node1 != node2 ) |
| node1 = new GDFCrossBlockNode(null, pb, node1, node2, e.getKey() ); |
| roots.put(e.getKey(), node1); |
| } |
| |
| //add new vars from else branch |
| for( Entry<String, GDFNode> e : elseRoots.entrySet() ){ |
| if( !ifRoots.containsKey(e.getKey()) ) |
| roots.put(e.getKey(), e.getValue()); |
| } |
| } |
| |
| /** |
| * |
| * @param sb |
| * @param loop |
| * @param loutputs |
| * @param roots |
| * @param pb |
| */ |
| private static void constructLoopOutputCrossBlockNodes(StatementBlock sb, GDFLoopNode loop, HashMap<String, GDFNode> loutputs, HashMap<String, GDFNode> roots, ProgramBlock pb) |
| { |
| //iterate over all output (updated) variables |
| for( Entry<String,GDFNode> e : loutputs.entrySet() ) |
| { |
| //create crossblocknode, if updated variable is also in liveout |
| if( sb.liveOut().containsVariable(e.getKey()) ) { |
| GDFCrossBlockNode node = null; |
| if( roots.containsKey(e.getKey()) ) |
| node = new GDFCrossBlockNode(null, pb, roots.get(e.getKey()), loop, e.getKey()); //MERGE |
| else |
| node = new GDFCrossBlockNode(null, pb, loop, e.getKey()); //PLAIN |
| roots.put(e.getKey(), node); |
| } |
| } |
| } |
| } |