blob: d620b7d10a9bbbbbd13ee735f1069e2e47231b2d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.HashMap;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.util.UtilFunctions;
/**
 * Rule: CommonSubexpressionElimination. For all statement blocks,
 * eliminate common subexpressions within DAGs by merging equivalent
 * operators (same inputs, equal parameters) bottom-up. Optionally, leaf
 * nodes are merged first: data reads with equal names, and literals with
 * equal value type and literal value. Currently, this only applies within
 * a single DAG; later it should be extended across statement blocks
 * (global, inter-procedural).
 */
public class RewriteCommonSubexpressionElimination extends HopRewriteRule
{
private final boolean _mergeLeafs;
public RewriteCommonSubexpressionElimination() {
this( true ); //default full CSE
}
public RewriteCommonSubexpressionElimination( boolean mergeLeafs ) {
_mergeLeafs = mergeLeafs;
}
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
{
if( roots == null )
return null;
//CSE pass 1: merge leaf nodes by name
int cseMerged = 0;
if( _mergeLeafs ) {
HashMap<String, Hop> dataops = new HashMap<>();
HashMap<LiteralKey, Hop> literalops = new HashMap<>();
for (Hop h : roots)
cseMerged += rule_CommonSubexpressionElimination_MergeLeafs(h, dataops, literalops);
Hop.resetVisitStatus(roots);
}
//CSE pass 2: bottom-up merge of inner nodes
for (Hop h : roots)
cseMerged += rule_CommonSubexpressionElimination(h);
if( cseMerged > 0 )
LOG.debug("Common Subexpression Elimination - removed "+cseMerged+" operators.");
return roots;
}
@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
{
if( root == null )
return null;
//CSE pass 1: merge leaf nodes by name
int cseMerged = 0;
if( _mergeLeafs ) {
HashMap<String, Hop> dataops = new HashMap<>();
HashMap<LiteralKey, Hop> literalops = new HashMap<>();
cseMerged += rule_CommonSubexpressionElimination_MergeLeafs(root, dataops, literalops);
root.resetVisitStatus();
}
//CSE pass 2: bottom-up merge of inner nodes
cseMerged += rule_CommonSubexpressionElimination(root);
if( cseMerged > 0 )
LOG.debug("Common Subexpression Elimination - removed "+cseMerged+" operators.");
return root;
}
private int rule_CommonSubexpressionElimination_MergeLeafs( Hop hop,
HashMap<String, Hop> dataops, HashMap<LiteralKey, Hop> literalops )
{
if( hop.isVisited() )
return 0;
int ret = 0;
if( hop.getInput().isEmpty() //LEAF NODE
|| HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) )
{
if( hop instanceof LiteralOp ) {
LiteralKey key = new LiteralKey(hop.getValueType(), hop.getName());
if( !literalops.containsKey(key) )
literalops.put(key, hop);
}
else if( hop instanceof DataOp && ((DataOp)hop).isRead()
&& !dataops.containsKey(hop.getName())) {
dataops.put(hop.getName(), hop);
}
}
else //INNER NODE
{
//merge leaf nodes (data, literal)
for( int i=0; i<hop.getInput().size(); i++ )
{
Hop hi = hop.getInput().get(i);
LiteralKey litKey = new LiteralKey(hi.getValueType(), hi.getName());
if( hi instanceof DataOp && ((DataOp)hi).isRead() && dataops.containsKey(hi.getName()) ) {
//replace child node ref
Hop tmp = dataops.get(hi.getName());
if( tmp != hi ) { //if required
tmp.getParent().add(hop);
tmp.setVisited();
hop.getInput().set(i, tmp);
ret++;
}
}
else if( hi instanceof LiteralOp && literalops.containsKey(litKey) ) {
Hop tmp = literalops.get(litKey);
//replace child node ref
if( tmp != hi ){ //if required
tmp.getParent().add(hop);
tmp.setVisited();
hop.getInput().set(i, tmp);
ret++;
}
}
//recursive invocation (direct return on merged nodes)
ret += rule_CommonSubexpressionElimination_MergeLeafs(hi, dataops, literalops);
}
}
hop.setVisited();
return ret;
}
private int rule_CommonSubexpressionElimination( Hop hop )
{
if( hop.isVisited() )
return 0;
//step 1: merge childs recursively first
int ret = 0;
for(Hop hi : hop.getInput())
ret += rule_CommonSubexpressionElimination(hi);
//step 2: merge parent nodes
if( hop.getParent().size()>1 ) //multiple consumers
{
//for all pairs
for( int i=0; i<hop.getParent().size()-1; i++ )
for( int j=i+1; j<hop.getParent().size(); j++ ) {
Hop h1 = hop.getParent().get(i);
Hop h2 = hop.getParent().get(j);
if( h1==h2 ) {
//do nothing, note: we should not remove redundant parent links
//(otherwise rewrites would need to take this property into account)
//remove redundant h2 from parent list
//hop.getParent().remove(j);
//j--;
}
else if( h1.compare(h2) ) { //merge h2 into h1
//remove h2 from parent list
hop.getParent().remove(j);
//replace h2 w/ h1 in h2-parent inputs
ArrayList<Hop> parent = h2.getParent();
for( Hop p : parent )
for( int k=0; k<p.getInput().size(); k++ )
if( p.getInput().get(k)==h2 ) {
p.getInput().set(k, h1);
h1.getParent().add(p);
h1.setVisited();
}
//replace h2 w/ h1 in h2-input parents
for( Hop in : h2.getInput() )
in.getParent().remove(h2);
ret++;
j--;
}
}
}
hop.setVisited();
return ret;
}
protected static class LiteralKey {
private final int _vtType;
private final String _name;
public LiteralKey(ValueType vt, String name) {
_vtType = vt.ordinal();
_name = name;
}
@Override
public int hashCode() {
return UtilFunctions.longHashCode(_vtType, _name.hashCode());
}
@Override
public boolean equals(Object o) {
return (o instanceof LiteralKey
&& _vtType == ((LiteralKey)o)._vtType
&& _name.equals(((LiteralKey)o)._name));
}
}
}