blob: 97205e4761b2e67e6cf3d58b25d51d84a8ac7f39 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.LiteralOp;
/**
 * Rule: Indexing vectorization. This rewrite rule set simplifies
 * multiple right / left indexing accesses within a DAG into row/column
 * index accesses, which is beneficial for two reasons: (1) it is an
 * enabler for later row/column partitioning, and (2) it reduces the number
 * of operations over potentially large data (i.e., prevents unnecessary MR
 * operations and reduces pressure on the buffer pool due to copy on write
 * on left indexing).
 * 
 * Currently only the left indexing vectorization is applied; the right
 * indexing variant is implemented but its invocation is disabled.
 */
public class RewriteIndexingVectorization extends HopRewriteRule
{
	private static final Log LOG = LogFactory.getLog(RewriteIndexingVectorization.class.getName());
	
	/**
	 * Applies the indexing vectorization rewrites to the DAGs below all given roots.
	 * The DAGs are modified in place and the same list is returned.
	 * 
	 * @param roots DAG roots; a null list is returned as is
	 * @param state program rewrite status (not used by this rule)
	 * @return the given list of roots
	 * @throws HopsException if a rewrite fails
	 */
	@Override
	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) 
		throws HopsException
	{
		if( roots == null )
			return roots;
		
		for( Hop h : roots )
			rule_IndexingVectorization( h );
		
		return roots;
	}
	
	/**
	 * Applies the indexing vectorization rewrites to the DAG below the given root.
	 * The DAG is modified in place and the same root is returned.
	 * 
	 * @param root DAG root; null is returned as is
	 * @param state program rewrite status (not used by this rule)
	 * @return the given root
	 * @throws HopsException if a rewrite fails
	 */
	@Override
	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) 
		throws HopsException
	{
		if( root == null )
			return root;
		
		rule_IndexingVectorization( root );
		
		return root;
	}
	
	/**
	 * Recursively applies the indexing vectorization rewrites to all inputs of the
	 * given hop: each input is first rewritten and then descended into. The given hop
	 * itself is only visited, never rewritten. Hops already marked DONE are skipped,
	 * so shared sub-DAGs are processed once.
	 * 
	 * @param hop current hop
	 * @throws HopsException if a rewrite fails
	 */
	private void rule_IndexingVectorization( Hop hop ) 
		throws HopsException 
	{
		if(hop.getVisited() == Hop.VisitStatus.DONE)
			return;
		
		//recursively process children
		for( int i=0; i<hop.getInput().size(); i++)
		{
			Hop hi = hop.getInput().get(i);
			
			//apply indexing vectorization rewrites
			//MB: disabled right indexing rewrite because (1) piggybacked in MR anyway, (2) usually
			//not too much overhead, and (3) makes literal replacement more difficult
			//vectorizeRightIndexing( hi ); //e.g., multiple rightindexing X[i,1], X[i,3] -> X[i,];
			vectorizeLeftIndexing( hi );  //e.g., multiple left indexing X[i,1], X[i,3] -> X[i,]; 
			
			//process childs recursively after rewrites 
			//(note: hi still refers to the original input; if it was rewritten, position i
			//of hop now holds the new left indexing op on top of hi)
			rule_IndexingVectorization( hi );
		}

		hop.setVisited(Hop.VisitStatus.DONE);
	}
	
	/**
	 * Vectorizes multiple single-cell right indexing operations over a common input
	 * that read the same row (or, if that is not applicable, the same column), e.g.,
	 * {@code X[i,1], X[i,3] -> tmp=X[i,]; tmp[1,1], tmp[1,3]}. Candidates are the
	 * right indexing parents of the common input whose row (column) expression is the
	 * identical hop, i.e., this relies on common subexpression elimination.
	 * 
	 * Note: unnecessary row or column indexing is later removed via dynamic rewrites.
	 * This rewrite is currently not invoked.
	 * 
	 * @param hop candidate hop; only a single-row, single-column IndexingOp is rewritten
	 * @throws HopsException if a rewrite fails
	 */
	@SuppressWarnings("unused")
	private void vectorizeRightIndexing( Hop hop )
		throws HopsException
	{
		if( hop instanceof IndexingOp ) //right indexing
		{
			IndexingOp ihop0 = (IndexingOp) hop;
			boolean isSingleRow = ihop0.getRowLowerEqualsUpper();
			boolean isSingleCol = ihop0.getColLowerEqualsUpper();
			boolean appliedRow = false;
			
			//search for multiple indexing in same row
			if( isSingleRow && isSingleCol ){
				Hop input = ihop0.getInput().get(0);
				//find candidate set
				//relies on common subexpression elimination to find equal input / row expression
				ArrayList<Hop> ihops = new ArrayList<Hop>();
				ihops.add(ihop0);
				for( Hop c : input.getParent() ){
					if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
						&& ((IndexingOp) c).getRowLowerEqualsUpper() 
						&& c.getInput().get(1)==ihop0.getInput().get(1) )
					{
						ihops.add( c );
					}
				}
				//apply rewrite if found candidates
				if( ihops.size() > 1 ){
					//new row indexing operator
					IndexingOp newRix = new IndexingOp("tmp", input.getDataType(), input.getValueType(), input, 
						ihop0.getInput().get(1), ihop0.getInput().get(1), new LiteralOp(1), 
						HopRewriteUtils.createValueHop(input, false), true, false); 
					HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
					newRix.refreshSizeInformation();
					//rewire current operator and all candidates
					for( Hop c : ihops ) {
						HopRewriteUtils.removeChildReference(c, input); //input data
						HopRewriteUtils.addChildReference(c, newRix, 0);
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(1),1); //row lower expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 1);
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(2),2); //row upper expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 2);
						c.refreshSizeInformation();
					}
					appliedRow = true;
					LOG.debug("Applied vectorizeRightIndexingRow");
				}
			}
			
			//search for multiple indexing in same col
			if( isSingleRow && isSingleCol && !appliedRow ){
				Hop input = ihop0.getInput().get(0);
				//find candidate set
				//relies on common subexpression elimination to find equal input / col expression
				ArrayList<Hop> ihops = new ArrayList<Hop>();
				ihops.add(ihop0);
				for( Hop c : input.getParent() ){
					if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
						&& ((IndexingOp) c).getColLowerEqualsUpper() 
						&& c.getInput().get(3)==ihop0.getInput().get(3) )
					{
						ihops.add( c );
					}
				}
				//apply rewrite if found candidates
				if( ihops.size() > 1 ){
					//new col indexing operator
					IndexingOp newRix = new IndexingOp("tmp", input.getDataType(), input.getValueType(), input, 
						new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),
						ihop0.getInput().get(3), ihop0.getInput().get(3), false, true); 
					HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
					newRix.refreshSizeInformation();
					//rewire current operator and all candidates
					for( Hop c : ihops ) {
						HopRewriteUtils.removeChildReference(c, input); //input data
						HopRewriteUtils.addChildReference(c, newRix, 0);
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(3),3); //col lower expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 3);
						HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(4),4); //col upper expr
						HopRewriteUtils.addChildReference(c, new LiteralOp(1), 4);
						c.refreshSizeInformation();
					}
					LOG.debug("Applied vectorizeRightIndexingCol");
				}
			}
		}
	}
	
	/**
	 * Vectorizes simple chains of single-cell left indexing operations that write into
	 * the same row (or, if that is not applicable, the same column) of a common target,
	 * e.g., {@code X[i,1]=a; X[i,3]=b; -> tmp=X[i,]; tmp[1,1]=a; tmp[1,3]=b; X[i,]=tmp}.
	 * The row merge is tried first; the column merge only if the row merge was not applied.
	 * 
	 * @param hop candidate hop; only a single-row, single-column LeftIndexingOp is a
	 *   valid chain head
	 * @throws HopsException if a rewrite fails
	 */
	private void vectorizeLeftIndexing( Hop hop )
		throws HopsException
	{
		if( !(hop instanceof LeftIndexingOp) ) //left indexing only
			return;
		
		LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
		//both merges require a single-cell left indexing op as chain head
		if( !ihop0.getRowLowerEqualsUpper() || !ihop0.getColLowerEqualsUpper() )
			return;
		
		if( !vectorizeLeftIndexingChain(ihop0, true) ) //same row
			vectorizeLeftIndexingChain(ihop0, false);  //same col
	}
	
	/**
	 * Merges the simple chain of left indexing operations below {@code ihop0} (following
	 * the target input at position 0) that all write into the same row ({@code row=true})
	 * or the same column ({@code row=false}). The bottom-most target is reduced to that
	 * row/column by a new right indexing, all chain members are rewritten to index
	 * row/column 1 of it, and a new left indexing on top of ihop0 writes the result back
	 * into the original target; this new op replaces ihop0 in all of ihop0's parents.
	 * 
	 * Chain members below ihop0 must have a single consumer, a single row (column) index
	 * expression that is the identical hop as ihop0's, and a target whose number of
	 * columns (rows) is known and greater than one.
	 * 
	 * @param ihop0 head of the chain, a single-row, single-column left indexing op
	 * @param row true to merge by row, false to merge by column
	 * @return true if the rewrite was applied (the chain had at least two members)
	 * @throws HopsException if a rewrite fails
	 */
	private boolean vectorizeLeftIndexingChain( LeftIndexingOp ihop0, boolean row )
		throws HopsException
	{
		//input positions of the lower/upper index expressions of the merged dimension
		//(left indexing inputs: 0 target, 1 source, 2/3 row lower/upper, 4/5 col lower/upper)
		final int posLower = row ? 2 : 4;
		final int posUpper = posLower + 1;
		
		//collect simple chains (w/o multiple consumers) of left indexing ops
		ArrayList<LeftIndexingOp> ihops = new ArrayList<LeftIndexingOp>();
		ihops.add(ihop0);
		LeftIndexingOp current = ihop0;
		while( current.getInput().get(0) instanceof LeftIndexingOp ) {
			LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
			Hop target = tmp.getInput().get(0);
			if(    tmp.getParent().size()>1 //multiple consumers, i.e., not a simple chain
				|| !(row ? tmp.getRowLowerEqualsUpper() : tmp.getColLowerEqualsUpper()) //merge not applicable
				|| tmp.getInput().get(posLower) != ihop0.getInput().get(posLower) //not the same row/col
				|| (row ? target.getDim2() : target.getDim1()) <= 1 ) //target is single col/row or unknown
			{
				break;
			}
			ihops.add( tmp );
			current = tmp;
		}
		
		//apply rewrite only if found candidates
		if( ihops.size() <= 1 )
			return false;
		
		Hop input = current.getInput().get(0);
		Hop indexExpr = ihop0.getInput().get(posLower); //keep before reset
		
		//new row/col right indexing operator: input[indexExpr,] or input[,indexExpr]
		IndexingOp newRix = row ?
			new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input,
				indexExpr, indexExpr, new LiteralOp(1),
				HopRewriteUtils.createValueHop(input, false), true, false) :
			new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input,
				new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),
				indexExpr, indexExpr, false, true);
		HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
		newRix.refreshSizeInformation();
		
		//rewrite bottom left indexing operator to update the extracted row/col
		HopRewriteUtils.removeChildReference(current, input); //input data
		HopRewriteUtils.addChildReference(current, newRix, 0);
		
		//reset row/col index of all candidates to 1 and refresh sizes (bottom-up)
		for( int i=ihops.size()-1; i>=0; i-- ) {
			LeftIndexingOp c = ihops.get(i);
			HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(posLower), posLower); //lower expr
			HopRewriteUtils.addChildReference(c, new LiteralOp(1), posLower);
			HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(posUpper), posUpper); //upper expr
			HopRewriteUtils.addChildReference(c, new LiteralOp(1), posUpper);
			if( row )
				c.setRowLowerEqualsUpper(true);
			else
				c.setColLowerEqualsUpper(true);
			c.refreshSizeInformation();
		}
		
		//new row/col left indexing operator on top of ihop0: input[indexExpr,]=ihop0 or
		//input[,indexExpr]=ihop0 (for all parents, only intermediates are guaranteed to have 1 parent)
		//(note: it's important to copy the parent list before creating newLix on top of ihop0,
		//presumably because newLix itself becomes an additional parent of ihop0)
		ArrayList<Hop> ihop0parents = new ArrayList<Hop>(ihop0.getParent());
		LeftIndexingOp newLix = row ?
			new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0,
				indexExpr, indexExpr, new LiteralOp(1),
				HopRewriteUtils.createValueHop(input, false), true, false) :
			new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0,
				new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),
				indexExpr, indexExpr, false, true);
		HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getRowsInBlock(), input.getColsInBlock(), -1);
		newLix.refreshSizeInformation();
		
		//replace ihop0 by newLix in all former parents, one reference at a time and in place;
		//this keeps all input positions intact even if a parent references ihop0 multiple
		//times (detaching all references first and re-adding them in forward order does not)
		for( Hop parent : ihop0parents ) {
			int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
			HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp); //input data
			HopRewriteUtils.addChildReference(parent, newLix, posp);
		}
		
		LOG.debug(row ? "Applied vectorizeLeftIndexingRow" : "Applied vectorizeLeftIndexingCol");
		return true;
	}
}