blob: c773319b749950e211bae808e1ea87cfa37000ba [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
* Rule: Simplify program structure by pulling if or else statement body out
* (removing the if statement block ifself) in order to allow intra-procedure
* analysis to propagate exact statistics.
public class RewriteForLoopVectorization extends StatementBlockRewriteRule
private static final OpOp2[] MAP_SCALAR_AGGREGATE_SOURCE_OPS = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT, OpOp2.MIN, OpOp2.MAX};
private static final AggOp[] MAP_SCALAR_AGGREGATE_TARGET_OPS = new AggOp[]{AggOp.SUM, AggOp.PROD, AggOp.MIN, AggOp.MAX};
public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state)
throws HopsException
ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
if( sb instanceof ForStatementBlock )
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fs = (ForStatement) fsb.getStatement(0);
Hop from = fsb.getFromHops();
Hop to = fsb.getToHops();
Hop incr = fsb.getIncrementHops();
String iterVar = fsb.getIterPredicate().getIterVar().getName();
if( fs.getBody()!=null && fs.getBody().size()==1 ) //single child block
StatementBlock csb = (StatementBlock) fs.getBody().get(0);
if( !( csb instanceof WhileStatementBlock //last level block
|| csb instanceof IfStatementBlock
|| csb instanceof ForStatementBlock ) )
//auto vectorization pattern
sb = vectorizeScalarAggregate(sb, csb, from, to, incr, iterVar); //e.g., for(i){s = s + as.scalar(X[i,2])}
sb = vectorizeElementwiseBinary(sb, csb, from, to, incr, iterVar);
sb = vectorizeElementwiseUnary(sb, csb, from, to, incr, iterVar);
//if no rewrite applied sb is the original for loop otherwise a last level statement block
//that includes the equivalent vectorized operations.
ret.add( sb );
return ret;
* Note: unnecessary row or column indexing then later removed via
* dynamic rewrites
* @param sb
* @param csb
* @param from
* @param to
* @param increment
* @param itervar
* @return
* @throws HopsException
private StatementBlock vectorizeScalarAggregate( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
throws HopsException
StatementBlock ret = sb;
//check missing and supported increment values
if( !(increment!=null && increment instanceof LiteralOp
&& ((LiteralOp)increment).getDoubleValue()==1.0) ) {
return ret;
//check for applicability
boolean leftScalar = false;
boolean rightScalar = false;
boolean rowIx = false; //row or col
if( csb.get_hops()!=null && csb.get_hops().size()==1 ){
Hop root = csb.get_hops().get(0);
if( root.getDataType()==DataType.SCALAR && root.getInput().get(0) instanceof BinaryOp ) {
BinaryOp bop = (BinaryOp) root.getInput().get(0);
Hop left = bop.getInput().get(0);
Hop right = bop.getInput().get(1);
//check for left scalar plus
if( HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
&& left instanceof DataOp && left.getDataType() == DataType.SCALAR
&& root.getName().equals(left.getName())
&& right instanceof UnaryOp && ((UnaryOp) right).getOp() == OpOp1.CAST_AS_SCALAR
&& right.getInput().get(0) instanceof IndexingOp )
IndexingOp ix = (IndexingOp)right.getInput().get(0);
if( ix.getRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp
&& ix.getInput().get(1).getName().equals(itervar) ){
leftScalar = true;
rowIx = true;
else if( ix.getColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp
&& ix.getInput().get(3).getName().equals(itervar) ){
leftScalar = true;
rowIx = false;
//check for right scalar plus
else if( HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)
&& right instanceof DataOp && right.getDataType() == DataType.SCALAR
&& root.getName().equals(right.getName())
&& left instanceof UnaryOp && ((UnaryOp) left).getOp() == OpOp1.CAST_AS_SCALAR
&& left.getInput().get(0) instanceof IndexingOp )
IndexingOp ix = (IndexingOp)left.getInput().get(0);
if( ix.getRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp
&& ix.getInput().get(1).getName().equals(itervar) ){
rightScalar = true;
rowIx = true;
else if( ix.getColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp
&& ix.getInput().get(3).getName().equals(itervar) ){
rightScalar = true;
rowIx = false;
//apply rewrite if possible
if( leftScalar || rightScalar )
Hop root = csb.get_hops().get(0);
BinaryOp bop = (BinaryOp) root.getInput().get(0);
Hop cast = bop.getInput().get( leftScalar?1:0 );
Hop ix = cast.getInput().get(0);
int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
//replace cast with sum
AggUnaryOp newSum = new AggUnaryOp(cast.getName(), DataType.SCALAR, ValueType.DOUBLE, aggOp, Direction.RowCol, ix);
HopRewriteUtils.removeChildReference(cast, ix);
HopRewriteUtils.removeChildReference(bop, cast);
HopRewriteUtils.addChildReference(bop, newSum, leftScalar?1:0 );
//modify indexing expression according to loop predicate from-to
//NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
int index1 = rowIx ? 1 : 3;
int index2 = rowIx ? 2 : 4;
HopRewriteUtils.removeChildReferenceByPos(ix, ix.getInput().get(index1), index1);
HopRewriteUtils.addChildReference(ix, from, index1);
HopRewriteUtils.removeChildReferenceByPos(ix, ix.getInput().get(index2), index2);
HopRewriteUtils.addChildReference(ix, to, index2);
ret = csb;
LOG.debug("Applied vectorizeScalarSumForLoop.");
return ret;
* Note: unnecessary row or column indexing then later removed via
* dynamic rewrites
* @param sb
* @param csb
* @param from
* @param to
* @param increment
* @param itervar
* @return
* @throws HopsException
private StatementBlock vectorizeElementwiseBinary( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
throws HopsException
StatementBlock ret = sb;
//check supported increment values
if( !(increment instanceof LiteralOp && ((LiteralOp)increment).getDoubleValue()==1.0) ){
return ret;
//check for applicability
boolean apply = false;
boolean rowIx = false; //row or col
if( csb.get_hops()!=null && csb.get_hops().size()==1 )
Hop root = csb.get_hops().get(0);
if( root.getDataType()==DataType.MATRIX && root.getInput().get(0) instanceof LeftIndexingOp )
LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
Hop lixlhs = lix.getInput().get(0);
Hop lixrhs = lix.getInput().get(1);
if( lixlhs instanceof DataOp && lixrhs instanceof BinaryOp
&& lixrhs.getInput().get(0) instanceof IndexingOp
&& lixrhs.getInput().get(1) instanceof IndexingOp
&& lixrhs.getInput().get(0).getInput().get(0) instanceof DataOp
&& lixrhs.getInput().get(1).getInput().get(0) instanceof DataOp)
IndexingOp rix0 = (IndexingOp) lixrhs.getInput().get(0);
IndexingOp rix1 = (IndexingOp) lixrhs.getInput().get(1);
//check for rowwise
if( lix.getRowLowerEqualsUpper() && rix0.getRowLowerEqualsUpper() && rix1.getRowLowerEqualsUpper()
&& lix.getInput().get(2).getName().equals(itervar)
&& rix0.getInput().get(1).getName().equals(itervar)
&& rix1.getInput().get(1).getName().equals(itervar))
apply = true;
rowIx = true;
//check for colwise
if( lix.getColLowerEqualsUpper() && rix0.getColLowerEqualsUpper() && rix1.getColLowerEqualsUpper()
&& lix.getInput().get(4).getName().equals(itervar)
&& rix0.getInput().get(3).getName().equals(itervar)
&& rix1.getInput().get(3).getName().equals(itervar))
apply = true;
rowIx = false;
//apply rewrite if possible
if( apply )
Hop root = csb.get_hops().get(0);
LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
BinaryOp bop = (BinaryOp) lix.getInput().get(1);
IndexingOp rix0 = (IndexingOp) bop.getInput().get(0);
IndexingOp rix1 = (IndexingOp) bop.getInput().get(1);
int index1 = rowIx ? 2 : 4;
int index2 = rowIx ? 3 : 5;
//modify left indexing bounds
HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index1), index1 );
HopRewriteUtils.addChildReference(lix, from, index1);
HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index2), index2 );
HopRewriteUtils.addChildReference(lix, to, index2);
//modify both right indexing
HopRewriteUtils.removeChildReferenceByPos(rix0, rix0.getInput().get(index1-1), index1-1 );
HopRewriteUtils.addChildReference(rix0, from, index1-1);
HopRewriteUtils.removeChildReferenceByPos(rix0, rix0.getInput().get(index2-1), index2-1 );
HopRewriteUtils.addChildReference(rix0, to, index2-1);
HopRewriteUtils.removeChildReferenceByPos(rix1, rix1.getInput().get(index1-1), index1-1 );
HopRewriteUtils.addChildReference(rix1, from, index1-1);
HopRewriteUtils.removeChildReferenceByPos(rix1, rix1.getInput().get(index2-1), index2-1 );
HopRewriteUtils.addChildReference(rix1, to, index2-1);
ret = csb;
LOG.debug("Applied vectorizeElementwiseBinaryForLoop.");
return ret;
* Note: unnecessary row or column indexing then later removed via
* dynamic rewrites
* @param sb
* @param csb
* @param from
* @param to
* @param increment
* @param itervar
* @return
* @throws HopsException
private StatementBlock vectorizeElementwiseUnary( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
throws HopsException
StatementBlock ret = sb;
//check supported increment values
if( !(increment instanceof LiteralOp && ((LiteralOp)increment).getDoubleValue()==1.0) ){
return ret;
//check for applicability
boolean apply = false;
boolean rowIx = false; //row or col
if( csb.get_hops()!=null && csb.get_hops().size()==1 )
Hop root = csb.get_hops().get(0);
if( root.getDataType()==DataType.MATRIX && root.getInput().get(0) instanceof LeftIndexingOp )
LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
Hop lixlhs = lix.getInput().get(0);
Hop lixrhs = lix.getInput().get(1);
if( lixlhs instanceof DataOp && lixrhs instanceof UnaryOp
&& lixrhs.getInput().get(0) instanceof IndexingOp
&& lixrhs.getInput().get(0).getInput().get(0) instanceof DataOp )
IndexingOp rix = (IndexingOp) lixrhs.getInput().get(0);
//check for rowwise
if( lix.getRowLowerEqualsUpper() && rix.getRowLowerEqualsUpper()
&& lix.getInput().get(2).getName().equals(itervar)
&& rix.getInput().get(1).getName().equals(itervar) )
apply = true;
rowIx = true;
//check for colwise
if( lix.getColLowerEqualsUpper() && rix.getColLowerEqualsUpper()
&& lix.getInput().get(4).getName().equals(itervar)
&& rix.getInput().get(3).getName().equals(itervar) )
apply = true;
rowIx = false;
//apply rewrite if possible
if( apply )
Hop root = csb.get_hops().get(0);
LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
UnaryOp uop = (UnaryOp) lix.getInput().get(1);
IndexingOp rix = (IndexingOp) uop.getInput().get(0);
int index1 = rowIx ? 2 : 4;
int index2 = rowIx ? 3 : 5;
//modify left indexing bounds
HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index1), index1 );
HopRewriteUtils.addChildReference(lix, from, index1);
HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index2), index2 );
HopRewriteUtils.addChildReference(lix, to, index2);
//modify right indexing
HopRewriteUtils.removeChildReferenceByPos(rix, rix.getInput().get(index1-1), index1-1 );
HopRewriteUtils.addChildReference(rix, from, index1-1);
HopRewriteUtils.removeChildReferenceByPos(rix, rix.getInput().get(index2-1), index2-1 );
HopRewriteUtils.addChildReference(rix, to, index2-1);
ret = csb;
LOG.debug("Applied vectorizeElementwiseUnaryForLoop.");
return ret;