| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysml.hops.rewrite; |
| |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| |
| import org.apache.sysml.hops.DataOp; |
| import org.apache.sysml.hops.Hop; |
| import org.apache.sysml.hops.HopsException; |
| import org.apache.sysml.hops.LiteralOp; |
| |
| /** |
| * Rule: CommonSubexpressionElimination. For all statement blocks, |
| * eliminate common subexpressions within dags by merging equivalent |
| * operators (same input, equal parameters) bottom-up. For the moment, |
| * this only applies within a dag, later this should be extended across |
| * statements block (global, inter-procedure). |
| */ |
| public class RewriteCommonSubexpressionElimination extends HopRewriteRule |
| { |
| |
| private boolean _mergeLeafs = true; |
| |
| public RewriteCommonSubexpressionElimination() |
| { |
| this( true ); //default full CSE |
| } |
| |
| public RewriteCommonSubexpressionElimination( boolean mergeLeafs ) |
| { |
| _mergeLeafs = mergeLeafs; |
| } |
| |
| @Override |
| public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) |
| throws HopsException |
| { |
| if( roots == null ) |
| return null; |
| |
| HashMap<String, Hop> dataops = new HashMap<String, Hop>(); |
| HashMap<String, Hop> literalops = new HashMap<String, Hop>(); //key: <VALUETYPE>_<LITERAL> |
| for (Hop h : roots) |
| { |
| int cseMerged = 0; |
| if( _mergeLeafs ) { |
| cseMerged += rule_CommonSubexpressionElimination_MergeLeafs(h, dataops, literalops); |
| h.resetVisitStatus(); |
| } |
| cseMerged += rule_CommonSubexpressionElimination(h); |
| |
| if( cseMerged > 0 ) |
| LOG.debug("Common Subexpression Elimination - removed "+cseMerged+" operators."); |
| } |
| |
| return roots; |
| } |
| |
| @Override |
| public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) |
| throws HopsException |
| { |
| if( root == null ) |
| return null; |
| |
| HashMap<String, Hop> dataops = new HashMap<String, Hop>(); |
| HashMap<String, Hop> literalops = new HashMap<String, Hop>(); //key: <VALUETYPE>_<LITERAL> |
| int cseMerged = 0; |
| if( _mergeLeafs ) { |
| cseMerged += rule_CommonSubexpressionElimination_MergeLeafs(root, dataops, literalops); |
| root.resetVisitStatus(); |
| } |
| cseMerged += rule_CommonSubexpressionElimination(root); |
| |
| if( cseMerged > 0 ) |
| LOG.debug("Common Subexpression Elimination - removed "+cseMerged+" operators."); |
| |
| return root; |
| } |
| |
| /** |
| * |
| * @param dataops |
| * @param literalops |
| * @return |
| * @throws HopsException |
| */ |
| private int rule_CommonSubexpressionElimination_MergeLeafs( Hop hop, HashMap<String, Hop> dataops, HashMap<String, Hop> literalops ) |
| throws HopsException |
| { |
| int ret = 0; |
| if( hop.getVisited() == Hop.VisitStatus.DONE ) |
| return ret; |
| |
| if( hop.getInput().isEmpty() ) //LEAF NODE |
| { |
| if( hop instanceof LiteralOp ) |
| { |
| String key = hop.getValueType()+"_"+hop.getName(); |
| if( !literalops.containsKey(key) ) |
| literalops.put(key, hop); |
| } |
| else if( hop instanceof DataOp && ((DataOp)hop).isRead()) |
| { |
| if(!dataops.containsKey(hop.getName()) ) |
| dataops.put(hop.getName(), hop); |
| } |
| } |
| else //INNER NODE |
| { |
| //merge leaf nodes (data, literal) |
| for( int i=0; i<hop.getInput().size(); i++ ) |
| { |
| Hop hi = hop.getInput().get(i); |
| String litKey = hi.getValueType()+"_"+hi.getName(); |
| if( hi instanceof DataOp && ((DataOp)hi).isRead() && dataops.containsKey(hi.getName()) ) |
| { |
| |
| //replace child node ref |
| Hop tmp = dataops.get(hi.getName()); |
| if( tmp != hi ) { //if required |
| tmp.getParent().add(hop); |
| hop.getInput().set(i, tmp); |
| ret++; |
| } |
| } |
| else if( hi instanceof LiteralOp && literalops.containsKey(litKey) ) |
| { |
| Hop tmp = literalops.get(litKey); |
| |
| //replace child node ref |
| if( tmp != hi ){ //if required |
| tmp.getParent().add(hop); |
| hop.getInput().set(i, tmp); |
| ret++; |
| } |
| } |
| |
| //recursive invocation (direct return on merged nodes) |
| ret += rule_CommonSubexpressionElimination_MergeLeafs(hi, dataops, literalops); |
| } |
| } |
| |
| hop.setVisited(Hop.VisitStatus.DONE); |
| return ret; |
| } |
| |
| /** |
| * |
| * @param dataops |
| * @param literalops |
| * @return |
| * @throws HopsException |
| */ |
| private int rule_CommonSubexpressionElimination( Hop hop ) |
| throws HopsException |
| { |
| int ret = 0; |
| if( hop.getVisited() == Hop.VisitStatus.DONE ) |
| return ret; |
| |
| //step 1: merge childs recursively first |
| for(Hop hi : hop.getInput()) |
| ret += rule_CommonSubexpressionElimination(hi); |
| |
| |
| //step 2: merge parent nodes |
| if( hop.getParent().size()>1 ) //multiple consumers |
| { |
| //for all pairs |
| for( int i=0; i<hop.getParent().size()-1; i++ ) |
| for( int j=i+1; j<hop.getParent().size(); j++ ) |
| { |
| Hop h1 = hop.getParent().get(i); |
| Hop h2 = hop.getParent().get(j); |
| |
| if( h1==h2 ) |
| { |
| //do nothing, note: we should not remove redundant parent links |
| //(otherwise rewrites would need to take this property into account) |
| |
| //remove redundant h2 from parent list |
| //hop.getParent().remove(j); |
| //j--; |
| } |
| else if( h1.compare(h2) ) //merge h2 into h1 |
| { |
| //remove h2 from parent list |
| hop.getParent().remove(j); |
| |
| //replace h2 w/ h1 in h2-parent inputs |
| ArrayList<Hop> parent = h2.getParent(); |
| for( Hop p : parent ) |
| for( int k=0; k<p.getInput().size(); k++ ) |
| if( p.getInput().get(k)==h2 ) |
| { |
| p.getInput().set(k, h1); |
| h1.getParent().add(p); |
| } |
| |
| //replace h2 w/ h1 in h2-input parents |
| for( Hop in : h2.getInput() ) |
| in.getParent().remove(h2); |
| |
| ret++; |
| j--; |
| } |
| } |
| } |
| |
| hop.setVisited(Hop.VisitStatus.DONE); |
| |
| return ret; |
| } |
| |
| } |