| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysml.hops.rewrite; |
| |
| import java.util.ArrayList; |
| |
| import org.apache.sysml.hops.AggUnaryOp; |
| import org.apache.sysml.hops.BinaryOp; |
| import org.apache.sysml.hops.DataOp; |
| import org.apache.sysml.hops.Hop; |
| import org.apache.sysml.hops.Hop.AggOp; |
| import org.apache.sysml.hops.Hop.Direction; |
| import org.apache.sysml.hops.Hop.OpOp1; |
| import org.apache.sysml.hops.HopsException; |
| import org.apache.sysml.hops.IndexingOp; |
| import org.apache.sysml.hops.LeftIndexingOp; |
| import org.apache.sysml.hops.LiteralOp; |
| import org.apache.sysml.hops.Hop.OpOp2; |
| import org.apache.sysml.hops.UnaryOp; |
| import org.apache.sysml.parser.ForStatement; |
| import org.apache.sysml.parser.ForStatementBlock; |
| import org.apache.sysml.parser.IfStatementBlock; |
| import org.apache.sysml.parser.StatementBlock; |
| import org.apache.sysml.parser.WhileStatementBlock; |
| import org.apache.sysml.parser.Expression.DataType; |
| import org.apache.sysml.parser.Expression.ValueType; |
| |
| /** |
| * Rule: Simplify program structure by pulling if or else statement body out |
| * (removing the if statement block ifself) in order to allow intra-procedure |
| * analysis to propagate exact statistics. |
| * |
| */ |
| public class RewriteForLoopVectorization extends StatementBlockRewriteRule |
| { |
| |
| private static final OpOp2[] MAP_SCALAR_AGGREGATE_SOURCE_OPS = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT, OpOp2.MIN, OpOp2.MAX}; |
| private static final AggOp[] MAP_SCALAR_AGGREGATE_TARGET_OPS = new AggOp[]{AggOp.SUM, AggOp.PROD, AggOp.MIN, AggOp.MAX}; |
| |
| |
| @Override |
| public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) |
| throws HopsException |
| { |
| ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>(); |
| |
| if( sb instanceof ForStatementBlock ) |
| { |
| ForStatementBlock fsb = (ForStatementBlock) sb; |
| ForStatement fs = (ForStatement) fsb.getStatement(0); |
| Hop from = fsb.getFromHops(); |
| Hop to = fsb.getToHops(); |
| Hop incr = fsb.getIncrementHops(); |
| String iterVar = fsb.getIterPredicate().getIterVar().getName(); |
| |
| if( fs.getBody()!=null && fs.getBody().size()==1 ) //single child block |
| { |
| StatementBlock csb = (StatementBlock) fs.getBody().get(0); |
| if( !( csb instanceof WhileStatementBlock //last level block |
| || csb instanceof IfStatementBlock |
| || csb instanceof ForStatementBlock ) ) |
| { |
| //auto vectorization pattern |
| sb = vectorizeScalarAggregate(sb, csb, from, to, incr, iterVar); //e.g., for(i){s = s + as.scalar(X[i,2])} |
| sb = vectorizeElementwiseBinary(sb, csb, from, to, incr, iterVar); |
| sb = vectorizeElementwiseUnary(sb, csb, from, to, incr, iterVar); |
| } |
| } |
| } |
| |
| //if no rewrite applied sb is the original for loop otherwise a last level statement block |
| //that includes the equivalent vectorized operations. |
| ret.add( sb ); |
| |
| return ret; |
| } |
| |
| /** |
| * Note: unnecessary row or column indexing then later removed via |
| * dynamic rewrites |
| * |
| * @param sb |
| * @param csb |
| * @param from |
| * @param to |
| * @param increment |
| * @param itervar |
| * @return |
| * @throws HopsException |
| */ |
| private StatementBlock vectorizeScalarAggregate( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar ) |
| throws HopsException |
| { |
| StatementBlock ret = sb; |
| |
| //check missing and supported increment values |
| if( !(increment!=null && increment instanceof LiteralOp |
| && ((LiteralOp)increment).getDoubleValue()==1.0) ) { |
| return ret; |
| } |
| |
| //check for applicability |
| boolean leftScalar = false; |
| boolean rightScalar = false; |
| boolean rowIx = false; //row or col |
| |
| if( csb.get_hops()!=null && csb.get_hops().size()==1 ){ |
| Hop root = csb.get_hops().get(0); |
| |
| if( root.getDataType()==DataType.SCALAR && root.getInput().get(0) instanceof BinaryOp ) { |
| BinaryOp bop = (BinaryOp) root.getInput().get(0); |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| |
| //check for left scalar plus |
| if( HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS) |
| && left instanceof DataOp && left.getDataType() == DataType.SCALAR |
| && root.getName().equals(left.getName()) |
| && right instanceof UnaryOp && ((UnaryOp) right).getOp() == OpOp1.CAST_AS_SCALAR |
| && right.getInput().get(0) instanceof IndexingOp ) |
| { |
| IndexingOp ix = (IndexingOp)right.getInput().get(0); |
| if( ix.getRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp |
| && ix.getInput().get(1).getName().equals(itervar) ){ |
| leftScalar = true; |
| rowIx = true; |
| } |
| else if( ix.getColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp |
| && ix.getInput().get(3).getName().equals(itervar) ){ |
| leftScalar = true; |
| rowIx = false; |
| } |
| } |
| //check for right scalar plus |
| else if( HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS) |
| && right instanceof DataOp && right.getDataType() == DataType.SCALAR |
| && root.getName().equals(right.getName()) |
| && left instanceof UnaryOp && ((UnaryOp) left).getOp() == OpOp1.CAST_AS_SCALAR |
| && left.getInput().get(0) instanceof IndexingOp ) |
| { |
| IndexingOp ix = (IndexingOp)left.getInput().get(0); |
| if( ix.getRowLowerEqualsUpper() && ix.getInput().get(1) instanceof DataOp |
| && ix.getInput().get(1).getName().equals(itervar) ){ |
| rightScalar = true; |
| rowIx = true; |
| } |
| else if( ix.getColLowerEqualsUpper() && ix.getInput().get(3) instanceof DataOp |
| && ix.getInput().get(3).getName().equals(itervar) ){ |
| rightScalar = true; |
| rowIx = false; |
| } |
| } |
| } |
| } |
| |
| //apply rewrite if possible |
| if( leftScalar || rightScalar ) |
| { |
| Hop root = csb.get_hops().get(0); |
| BinaryOp bop = (BinaryOp) root.getInput().get(0); |
| Hop cast = bop.getInput().get( leftScalar?1:0 ); |
| Hop ix = cast.getInput().get(0); |
| int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS); |
| AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos]; |
| //replace cast with sum |
| AggUnaryOp newSum = new AggUnaryOp(cast.getName(), DataType.SCALAR, ValueType.DOUBLE, aggOp, Direction.RowCol, ix); |
| HopRewriteUtils.removeChildReference(cast, ix); |
| HopRewriteUtils.removeChildReference(bop, cast); |
| HopRewriteUtils.addChildReference(bop, newSum, leftScalar?1:0 ); |
| //modify indexing expression according to loop predicate from-to |
| //NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites |
| int index1 = rowIx ? 1 : 3; |
| int index2 = rowIx ? 2 : 4; |
| HopRewriteUtils.removeChildReferenceByPos(ix, ix.getInput().get(index1), index1); |
| HopRewriteUtils.addChildReference(ix, from, index1); |
| HopRewriteUtils.removeChildReferenceByPos(ix, ix.getInput().get(index2), index2); |
| HopRewriteUtils.addChildReference(ix, to, index2); |
| |
| ret = csb; |
| //ret.liveIn().removeVariable(itervar); |
| LOG.debug("Applied vectorizeScalarSumForLoop."); |
| } |
| |
| return ret; |
| } |
| |
| /** |
| * Note: unnecessary row or column indexing then later removed via |
| * dynamic rewrites |
| * |
| * @param sb |
| * @param csb |
| * @param from |
| * @param to |
| * @param increment |
| * @param itervar |
| * @return |
| * @throws HopsException |
| */ |
| private StatementBlock vectorizeElementwiseBinary( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar ) |
| throws HopsException |
| { |
| StatementBlock ret = sb; |
| |
| //check supported increment values |
| if( !(increment instanceof LiteralOp && ((LiteralOp)increment).getDoubleValue()==1.0) ){ |
| return ret; |
| } |
| |
| //check for applicability |
| boolean apply = false; |
| boolean rowIx = false; //row or col |
| if( csb.get_hops()!=null && csb.get_hops().size()==1 ) |
| { |
| Hop root = csb.get_hops().get(0); |
| |
| if( root.getDataType()==DataType.MATRIX && root.getInput().get(0) instanceof LeftIndexingOp ) |
| { |
| LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0); |
| Hop lixlhs = lix.getInput().get(0); |
| Hop lixrhs = lix.getInput().get(1); |
| |
| if( lixlhs instanceof DataOp && lixrhs instanceof BinaryOp |
| && lixrhs.getInput().get(0) instanceof IndexingOp |
| && lixrhs.getInput().get(1) instanceof IndexingOp |
| && lixrhs.getInput().get(0).getInput().get(0) instanceof DataOp |
| && lixrhs.getInput().get(1).getInput().get(0) instanceof DataOp) |
| { |
| IndexingOp rix0 = (IndexingOp) lixrhs.getInput().get(0); |
| IndexingOp rix1 = (IndexingOp) lixrhs.getInput().get(1); |
| |
| //check for rowwise |
| if( lix.getRowLowerEqualsUpper() && rix0.getRowLowerEqualsUpper() && rix1.getRowLowerEqualsUpper() |
| && lix.getInput().get(2).getName().equals(itervar) |
| && rix0.getInput().get(1).getName().equals(itervar) |
| && rix1.getInput().get(1).getName().equals(itervar)) |
| { |
| apply = true; |
| rowIx = true; |
| } |
| //check for colwise |
| if( lix.getColLowerEqualsUpper() && rix0.getColLowerEqualsUpper() && rix1.getColLowerEqualsUpper() |
| && lix.getInput().get(4).getName().equals(itervar) |
| && rix0.getInput().get(3).getName().equals(itervar) |
| && rix1.getInput().get(3).getName().equals(itervar)) |
| { |
| apply = true; |
| rowIx = false; |
| } |
| } |
| } |
| } |
| |
| //apply rewrite if possible |
| if( apply ) |
| { |
| Hop root = csb.get_hops().get(0); |
| LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0); |
| BinaryOp bop = (BinaryOp) lix.getInput().get(1); |
| IndexingOp rix0 = (IndexingOp) bop.getInput().get(0); |
| IndexingOp rix1 = (IndexingOp) bop.getInput().get(1); |
| int index1 = rowIx ? 2 : 4; |
| int index2 = rowIx ? 3 : 5; |
| //modify left indexing bounds |
| HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index1), index1 ); |
| HopRewriteUtils.addChildReference(lix, from, index1); |
| HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index2), index2 ); |
| HopRewriteUtils.addChildReference(lix, to, index2); |
| //modify both right indexing |
| HopRewriteUtils.removeChildReferenceByPos(rix0, rix0.getInput().get(index1-1), index1-1 ); |
| HopRewriteUtils.addChildReference(rix0, from, index1-1); |
| HopRewriteUtils.removeChildReferenceByPos(rix0, rix0.getInput().get(index2-1), index2-1 ); |
| HopRewriteUtils.addChildReference(rix0, to, index2-1); |
| HopRewriteUtils.removeChildReferenceByPos(rix1, rix1.getInput().get(index1-1), index1-1 ); |
| HopRewriteUtils.addChildReference(rix1, from, index1-1); |
| HopRewriteUtils.removeChildReferenceByPos(rix1, rix1.getInput().get(index2-1), index2-1 ); |
| HopRewriteUtils.addChildReference(rix1, to, index2-1); |
| rix0.refreshSizeInformation(); |
| rix1.refreshSizeInformation(); |
| bop.refreshSizeInformation(); |
| lix.refreshSizeInformation(); |
| |
| ret = csb; |
| //ret.liveIn().removeVariable(itervar); |
| LOG.debug("Applied vectorizeElementwiseBinaryForLoop."); |
| } |
| |
| return ret; |
| } |
| |
| /** |
| * Note: unnecessary row or column indexing then later removed via |
| * dynamic rewrites |
| * |
| * @param sb |
| * @param csb |
| * @param from |
| * @param to |
| * @param increment |
| * @param itervar |
| * @return |
| * @throws HopsException |
| */ |
| private StatementBlock vectorizeElementwiseUnary( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar ) |
| throws HopsException |
| { |
| StatementBlock ret = sb; |
| |
| //check supported increment values |
| if( !(increment instanceof LiteralOp && ((LiteralOp)increment).getDoubleValue()==1.0) ){ |
| return ret; |
| } |
| |
| //check for applicability |
| boolean apply = false; |
| boolean rowIx = false; //row or col |
| if( csb.get_hops()!=null && csb.get_hops().size()==1 ) |
| { |
| Hop root = csb.get_hops().get(0); |
| |
| if( root.getDataType()==DataType.MATRIX && root.getInput().get(0) instanceof LeftIndexingOp ) |
| { |
| LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0); |
| Hop lixlhs = lix.getInput().get(0); |
| Hop lixrhs = lix.getInput().get(1); |
| |
| if( lixlhs instanceof DataOp && lixrhs instanceof UnaryOp |
| && lixrhs.getInput().get(0) instanceof IndexingOp |
| && lixrhs.getInput().get(0).getInput().get(0) instanceof DataOp ) |
| { |
| IndexingOp rix = (IndexingOp) lixrhs.getInput().get(0); |
| //check for rowwise |
| if( lix.getRowLowerEqualsUpper() && rix.getRowLowerEqualsUpper() |
| && lix.getInput().get(2).getName().equals(itervar) |
| && rix.getInput().get(1).getName().equals(itervar) ) |
| { |
| apply = true; |
| rowIx = true; |
| } |
| //check for colwise |
| if( lix.getColLowerEqualsUpper() && rix.getColLowerEqualsUpper() |
| && lix.getInput().get(4).getName().equals(itervar) |
| && rix.getInput().get(3).getName().equals(itervar) ) |
| { |
| apply = true; |
| rowIx = false; |
| } |
| } |
| } |
| } |
| |
| //apply rewrite if possible |
| if( apply ) |
| { |
| Hop root = csb.get_hops().get(0); |
| LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0); |
| UnaryOp uop = (UnaryOp) lix.getInput().get(1); |
| IndexingOp rix = (IndexingOp) uop.getInput().get(0); |
| int index1 = rowIx ? 2 : 4; |
| int index2 = rowIx ? 3 : 5; |
| //modify left indexing bounds |
| HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index1), index1 ); |
| HopRewriteUtils.addChildReference(lix, from, index1); |
| HopRewriteUtils.removeChildReferenceByPos(lix, lix.getInput().get(index2), index2 ); |
| HopRewriteUtils.addChildReference(lix, to, index2); |
| //modify right indexing |
| HopRewriteUtils.removeChildReferenceByPos(rix, rix.getInput().get(index1-1), index1-1 ); |
| HopRewriteUtils.addChildReference(rix, from, index1-1); |
| HopRewriteUtils.removeChildReferenceByPos(rix, rix.getInput().get(index2-1), index2-1 ); |
| HopRewriteUtils.addChildReference(rix, to, index2-1); |
| rix.refreshSizeInformation(); |
| uop.refreshSizeInformation(); |
| lix.refreshSizeInformation(); |
| |
| ret = csb; |
| //ret.liveIn().removeVariable(itervar); |
| LOG.debug("Applied vectorizeElementwiseUnaryForLoop."); |
| } |
| |
| return ret; |
| } |
| |
| |
| } |