blob: 3c10f02732f7eaa8f40853d185c074c223d9d420 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LeftIndexingOp;
import org.apache.sysds.hops.LiteralOp;
/**
* Rule: Indexing vectorization. This rewrite rule set simplifies
* multiple right / left indexing accesses within a DAG into row/column
* index accesses, which is beneficial for two reasons: (1) it is an
* enabler for later row/column partitioning, and (2) it reduces the number
* of operations over potentially large data (i.e., prevents unnecessary MR
* operations and reduces pressure on the buffer pool due to copy on write
* on left indexing).
*
*/
public class RewriteIndexingVectorization extends HopRewriteRule
{
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
if( roots == null )
return roots;
for( Hop h : roots )
rule_IndexingVectorization( h );
return roots;
}
@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
if( root == null )
return root;
rule_IndexingVectorization( root );
return root;
}
private void rule_IndexingVectorization( Hop hop ) {
if(hop.isVisited())
return;
//recursively process children
for( int i=0; i<hop.getInput().size(); i++)
{
Hop hi = hop.getInput().get(i);
//apply indexing vectorization rewrites
//MB: disabled right indexing rewrite because (1) piggybacked in MR anyway, (2) usually
//not too much overhead, and (3) makes literal replacement more difficult
hi = vectorizeRightLeftIndexingChains(hi); //e.g., B[,1]=A[,1]; B[,2]=A[2]; -> B[,1:2] = A[,1:2]
//vectorizeRightIndexing( hi ); //e.g., multiple rightindexing X[i,1], X[i,3] -> X[i,];
hi = vectorizeLeftIndexing( hi ); //e.g., multiple left indexing X[i,1], X[i,3] -> X[i,];
//process childs recursively after rewrites
rule_IndexingVectorization( hi );
}
hop.setVisited();
}
private static Hop vectorizeRightLeftIndexingChains(Hop hi) {
//check for valid root operator
if( !(hi instanceof LeftIndexingOp
&& hi.getInput().get(1) instanceof IndexingOp
&& hi.getInput().get(1).getParent().size()==1) )
return hi;
LeftIndexingOp lix0 = (LeftIndexingOp) hi;
IndexingOp rix0 = (IndexingOp) hi.getInput().get(1);
if( !(lix0.isRowLowerEqualsUpper() || lix0.isColLowerEqualsUpper())
|| lix0.isRowLowerEqualsUpper() != rix0.isRowLowerEqualsUpper()
|| lix0.isColLowerEqualsUpper() != rix0.isColLowerEqualsUpper())
return hi;
boolean row = lix0.isRowLowerEqualsUpper();
if( !( (row ? HopRewriteUtils.isFullRowIndexing(lix0) : HopRewriteUtils.isFullColumnIndexing(lix0))
&& (row ? HopRewriteUtils.isFullRowIndexing(rix0) : HopRewriteUtils.isFullColumnIndexing(rix0))) )
return hi;
//determine consecutive left-right indexing chains for rows/columns
List<LeftIndexingOp> lix = new ArrayList<>(); lix.add(lix0);
List<IndexingOp> rix = new ArrayList<>(); rix.add(rix0);
LeftIndexingOp clix = lix0;
IndexingOp crix = rix0;
while( isConsecutiveLeftRightIndexing(clix, crix, clix.getInput().get(0))
&& clix.getInput().get(0).getParent().size()==1
&& clix.getInput().get(0).getInput().get(1).getParent().size()==1 ) {
clix = (LeftIndexingOp)clix.getInput().get(0);
crix = (IndexingOp)clix.getInput().get(1);
lix.add(clix); rix.add(crix);
}
//rewrite pattern if at least two consecutive pairs
if( lix.size() >= 2 ) {
IndexingOp rixn = rix.get(rix.size()-1);
Hop rlrix = rixn.getInput().get(1);
Hop rurix = row ? HopRewriteUtils.createBinary(rlrix, new LiteralOp(rix.size()-1), OpOp2.PLUS) : rixn.getInput().get(2);
Hop clrix = rixn.getInput().get(3);
Hop curix = row ? rixn.getInput().get(4) : HopRewriteUtils.createBinary(clrix, new LiteralOp(rix.size()-1), OpOp2.PLUS);
IndexingOp rixNew = HopRewriteUtils.createIndexingOp(rixn.getInput().get(0), rlrix, rurix, clrix, curix);
LeftIndexingOp lixn = lix.get(rix.size()-1);
Hop rllix = lixn.getInput().get(2);
Hop rulix = row ? HopRewriteUtils.createBinary(rllix, new LiteralOp(lix.size()-1), OpOp2.PLUS) : lixn.getInput().get(3);
Hop cllix = lixn.getInput().get(4);
Hop culix = row ? lixn.getInput().get(5) : HopRewriteUtils.createBinary(cllix, new LiteralOp(lix.size()-1), OpOp2.PLUS);
LeftIndexingOp lixNew = HopRewriteUtils.createLeftIndexingOp(lixn.getInput().get(0), rixNew, rllix, rulix, cllix, culix);
//rewire parents and childs
HopRewriteUtils.replaceChildReference(hi.getParent().get(0), hi, lixNew);
for( int i=0; i<lix.size(); i++ ) {
HopRewriteUtils.removeAllChildReferences(lix.get(i));
HopRewriteUtils.removeAllChildReferences(rix.get(i));
}
hi = lixNew;
LOG.debug("Applied vectorizeRightLeftIndexingChains (line "+hi.getBeginLine()+")");
}
return hi;
}
private static boolean isConsecutiveLeftRightIndexing(LeftIndexingOp lix, IndexingOp rix, Hop input) {
if( !(input instanceof LeftIndexingOp
&& input.getInput().get(1) instanceof IndexingOp) )
return false;
boolean row = lix.isRowLowerEqualsUpper();
LeftIndexingOp lix2 = (LeftIndexingOp) input;
IndexingOp rix2 = (IndexingOp) input.getInput().get(1);
//check row/column access with full row/column indexing
boolean access = (row ? HopRewriteUtils.isFullRowIndexing(lix2) && HopRewriteUtils.isFullRowIndexing(rix2) :
HopRewriteUtils.isFullColumnIndexing(lix2) && HopRewriteUtils.isFullColumnIndexing(rix2));
//check equivalent right indexing inputs
boolean rixInputs = (rix.getInput().get(0) == rix2.getInput().get(0));
//check consecutive access
boolean consecutive = (row ? HopRewriteUtils.isConsecutiveIndex(lix2.getInput().get(2), lix.getInput().get(2))
&& HopRewriteUtils.isConsecutiveIndex(rix2.getInput().get(1), rix.getInput().get(1)) :
HopRewriteUtils.isConsecutiveIndex(lix2.getInput().get(4), lix.getInput().get(4))
&& HopRewriteUtils.isConsecutiveIndex(rix2.getInput().get(3), rix.getInput().get(3)));
return access && rixInputs && consecutive;
}
/**
* Note: unnecessary row or column indexing then later removed via
* dynamic rewrites
*
* @param hop high-level operator
*/
@SuppressWarnings("unused")
private static void vectorizeRightIndexing( Hop hop )
{
if( hop instanceof IndexingOp ) //right indexing
{
IndexingOp ihop0 = (IndexingOp) hop;
boolean isSingleRow = ihop0.isRowLowerEqualsUpper();
boolean isSingleCol = ihop0.isColLowerEqualsUpper();
boolean appliedRow = false;
//search for multiple indexing in same row
if( isSingleRow && isSingleCol ){
Hop input = ihop0.getInput().get(0);
//find candidate set
//dependence on common subexpression elimination to find equal input / row expression
ArrayList<Hop> ihops = new ArrayList<Hop>();
ihops.add(ihop0);
for( Hop c : input.getParent() ){
if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
&& ((IndexingOp) c).isRowLowerEqualsUpper()
&& c.getInput().get(1)==ihop0.getInput().get(1) )
{
ihops.add( c );
}
}
//apply rewrite if found candidates
if( ihops.size() > 1 ){
//new row indexing operator
IndexingOp newRix = new IndexingOp("tmp", input.getDataType(), input.getValueType(), input,
ihop0.getInput().get(1), ihop0.getInput().get(1), new LiteralOp(1),
HopRewriteUtils.createValueHop(input, false), true, false);
HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getBlocksize(), -1);
newRix.refreshSizeInformation();
//rewire current operator and all candidates
for( Hop c : ihops ) {
HopRewriteUtils.removeChildReference(c, input); //input data
HopRewriteUtils.addChildReference(c, newRix, 0);
HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(1),1); //row lower expr
HopRewriteUtils.addChildReference(c, new LiteralOp(1), 1);
HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(2),2); //row upper expr
HopRewriteUtils.addChildReference(c, new LiteralOp(1), 2);
c.refreshSizeInformation();
}
appliedRow = true;
LOG.debug("Applied vectorizeRightIndexingRow");
}
}
//search for multiple indexing in same col
if( isSingleRow && isSingleCol && !appliedRow ){
Hop input = ihop0.getInput().get(0);
//find candidate set
//dependence on common subexpression elimination to find equal input / row expression
ArrayList<Hop> ihops = new ArrayList<Hop>();
ihops.add(ihop0);
for( Hop c : input.getParent() ){
if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
&& ((IndexingOp) c).isColLowerEqualsUpper()
&& c.getInput().get(3)==ihop0.getInput().get(3) )
{
ihops.add( c );
}
}
//apply rewrite if found candidates
if( ihops.size() > 1 ){
//new row indexing operator
IndexingOp newRix = new IndexingOp("tmp", input.getDataType(), input.getValueType(), input,
new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),
ihop0.getInput().get(3), ihop0.getInput().get(3), false, true);
HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getBlocksize(), -1);
newRix.refreshSizeInformation();
//rewire current operator and all candidates
for( Hop c : ihops ) {
HopRewriteUtils.removeChildReference(c, input); //input data
HopRewriteUtils.addChildReference(c, newRix, 0);
HopRewriteUtils.replaceChildReference(c, c.getInput().get(3), new LiteralOp(1), 3); //col lower expr
HopRewriteUtils.replaceChildReference(c, c.getInput().get(4), new LiteralOp(1), 4); //col upper expr
c.refreshSizeInformation();
}
LOG.debug("Applied vectorizeRightIndexingCol");
}
}
}
}
@SuppressWarnings("unchecked")
private static Hop vectorizeLeftIndexing( Hop hop )
{
Hop ret = hop;
if( hop instanceof LeftIndexingOp ) //left indexing
{
LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
boolean isSingleRow = ihop0.isRowLowerEqualsUpper();
boolean isSingleCol = ihop0.isColLowerEqualsUpper();
boolean appliedRow = false;
if( isSingleRow && isSingleCol )
{
//collect simple chains (w/o multiple consumers) of left indexing ops
ArrayList<Hop> ihops = new ArrayList<>();
ihops.add(ihop0);
Hop current = ihop0;
while( current.getInput().get(0) instanceof LeftIndexingOp ) {
LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
if( tmp.getParent().size()>1 //multiple consumers, i.e., not a simple chain
|| !tmp.isRowLowerEqualsUpper() //row merge not applicable
|| tmp.getInput().get(2) != ihop0.getInput().get(2) //not the same row
|| tmp.getInput().get(0).getDim2() <= 1 ) //target is single column or unknown
{
break;
}
ihops.add( tmp );
current = tmp;
}
//apply rewrite if found candidates
if( ihops.size() > 1 ){
Hop input = current.getInput().get(0);
Hop rowExpr = ihop0.getInput().get(2); //keep before reset
//new row indexing operator
IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(),
input, rowExpr, rowExpr, new LiteralOp(1),
HopRewriteUtils.createValueHop(input, false), true, false);
HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getBlocksize(), -1);
newRix.refreshSizeInformation();
//reset visit status of copied hops (otherwise hidden by left indexing)
for( Hop c : newRix.getInput() )
c.resetVisitStatus();
//rewrite bottom left indexing operator
HopRewriteUtils.removeChildReference(current, input); //input data
HopRewriteUtils.addChildReference(current, newRix, 0);
//reset row index all candidates and refresh sizes (bottom-up)
for( int i=ihops.size()-1; i>=0; i-- ) {
Hop c = ihops.get(i);
HopRewriteUtils.replaceChildReference(c, c.getInput().get(2), new LiteralOp(1), 2); //row lower expr
HopRewriteUtils.replaceChildReference(c, c.getInput().get(3), new LiteralOp(1), 3); //row upper expr
((LeftIndexingOp)c).setRowLowerEqualsUpper(true);
c.refreshSizeInformation();
}
//new row left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
//(note: it's important to clone the parent list before creating newLix on top of ihop0)
ArrayList<Hop> ihop0parents = (ArrayList<Hop>) ihop0.getParent().clone();
ArrayList<Integer> ihop0parentsPos = new ArrayList<>();
for( Hop parent : ihop0parents ) {
int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp); //input data
ihop0parentsPos.add(posp);
}
LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(),
input, ihop0, rowExpr, rowExpr, new LiteralOp(1),
HopRewriteUtils.createValueHop(input, false), true, false);
HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getBlocksize(), -1);
newLix.refreshSizeInformation();
//reset visit status of copied hops (otherwise hidden by left indexing)
for( Hop c : newLix.getInput() )
c.resetVisitStatus();
for( int i=0; i<ihop0parentsPos.size(); i++ ) {
Hop parent = ihop0parents.get(i);
int posp = ihop0parentsPos.get(i);
HopRewriteUtils.addChildReference(parent, newLix, posp);
}
appliedRow = true;
ret = newLix;
LOG.debug("Applied vectorizeLeftIndexingRow for hop "+hop.getHopID());
}
}
if( isSingleRow && isSingleCol && !appliedRow )
{
//collect simple chains (w/o multiple consumers) of left indexing ops
ArrayList<Hop> ihops = new ArrayList<>();
ihops.add(ihop0);
Hop current = ihop0;
while( current.getInput().get(0) instanceof LeftIndexingOp ) {
LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
if( tmp.getParent().size()>1 //multiple consumers, i.e., not a simple chain
|| !tmp.isColLowerEqualsUpper() //row merge not applicable
|| tmp.getInput().get(4) != ihop0.getInput().get(4) //not the same col
|| tmp.getInput().get(0).getDim1() <= 1 ) //target is single row or unknown
{
break;
}
ihops.add( tmp );
current = tmp;
}
//apply rewrite if found candidates
if( ihops.size() > 1 ){
Hop input = current.getInput().get(0);
Hop colExpr = ihop0.getInput().get(4); //keep before reset
//new row indexing operator
IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(),
input, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),
colExpr, colExpr, false, true);
HopRewriteUtils.setOutputParameters(newRix, -1, -1, input.getBlocksize(), -1);
newRix.refreshSizeInformation();
//reset visit status of copied hops (otherwise hidden by left indexing)
for( Hop c : newRix.getInput() )
c.resetVisitStatus();
//rewrite bottom left indexing operator
HopRewriteUtils.removeChildReference(current, input); //input data
HopRewriteUtils.addChildReference(current, newRix, 0);
//reset col index all candidates and refresh sizes (bottom-up)
for( int i=ihops.size()-1; i>=0; i-- ) {
Hop c = ihops.get(i);
HopRewriteUtils.replaceChildReference(c, c.getInput().get(4), new LiteralOp(1), 4); //col lower expr
HopRewriteUtils.replaceChildReference(c, c.getInput().get(5), new LiteralOp(1), 5); //col upper expr
((LeftIndexingOp)c).setColLowerEqualsUpper(true);
c.refreshSizeInformation();
}
//new row left indexing operator (for all parents, only intermediates are guaranteed to have 1 parent)
//(note: it's important to clone the parent list before creating newLix on top of ihop0)
ArrayList<Hop> ihop0parents = (ArrayList<Hop>) ihop0.getParent().clone();
ArrayList<Integer> ihop0parentsPos = new ArrayList<>();
for( Hop parent : ihop0parents ) {
int posp = HopRewriteUtils.getChildReferencePos(parent, ihop0);
HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp); //input data
ihop0parentsPos.add(posp);
}
LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(),
input, ihop0, new LiteralOp(1), HopRewriteUtils.createValueHop(input, true),
colExpr, colExpr, false, true);
HopRewriteUtils.setOutputParameters(newLix, -1, -1, input.getBlocksize(), -1);
newLix.refreshSizeInformation();
//reset visit status of copied hops (otherwise hidden by left indexing)
for( Hop c : newLix.getInput() )
c.resetVisitStatus();
for( int i=0; i<ihop0parentsPos.size(); i++ ) {
Hop parent = ihop0parents.get(i);
int posp = ihop0parentsPos.get(i);
HopRewriteUtils.addChildReference(parent, newLix, posp);
}
ret = newLix;
LOG.debug("Applied vectorizeLeftIndexingCol for hop "+hop.getHopID());
}
}
}
return ret;
}
}