| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops.rewrite; |
| |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.LinkedHashMap; |
| import java.util.List; |
| import java.util.Set; |
| |
| import org.apache.sysds.hops.AggBinaryOp; |
| import org.apache.sysds.hops.AggUnaryOp; |
| import org.apache.sysds.hops.BinaryOp; |
| import org.apache.sysds.hops.DataGenOp; |
| import org.apache.sysds.hops.Hop; |
| import org.apache.sysds.hops.IndexingOp; |
| import org.apache.sysds.hops.LiteralOp; |
| import org.apache.sysds.hops.NaryOp; |
| import org.apache.sysds.hops.OptimizerUtils; |
| import org.apache.sysds.hops.ParameterizedBuiltinOp; |
| import org.apache.sysds.hops.ReorgOp; |
| import org.apache.sysds.hops.TernaryOp; |
| import org.apache.sysds.hops.UnaryOp; |
| import org.apache.sysds.common.Types.AggOp; |
| import org.apache.sysds.common.Types.Direction; |
| import org.apache.sysds.common.Types.OpOp1; |
| import org.apache.sysds.common.Types.OpOp2; |
| import org.apache.sysds.common.Types.OpOp3; |
| import org.apache.sysds.common.Types.OpOpDG; |
| import org.apache.sysds.common.Types.OpOpN; |
| import org.apache.sysds.common.Types.ParamBuiltinOp; |
| import org.apache.sysds.common.Types.ReOrgOp; |
| import org.apache.sysds.parser.DataExpression; |
| import org.apache.sysds.parser.Statement; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.common.Types.ValueType; |
| |
| /** |
| * Rule: Algebraic Simplifications. Simplifies binary expressions |
| * in terms of two major purposes: (1) rewrite binary operations |
| * to unary operations when possible (in CP this reduces the memory |
| * estimate, in MR this allows map-only operations and hence prevents |
| * unnecessary shuffle and sort) and (2) remove binary operations that |
| * are in themselves unnecessary (e.g., *1 and /1). |
| * |
| */ |
| public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule |
| { |
| //valid aggregation operation types for rowOp to colOp conversions and vice versa |
| private static final AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[] { |
| AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR}; |
| |
| //valid binary operations for distributive and associative reorderings |
| private static final OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new OpOp2[] {OpOp2.PLUS, OpOp2.MINUS}; |
| private static final OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new OpOp2[] {OpOp2.PLUS, OpOp2.MULT}; |
| |
| //valid binary operations for scalar operations |
| private static final OpOp2[] LOOKUP_VALID_SCALAR_BINARY = new OpOp2[] {OpOp2.AND, OpOp2.DIV, |
| OpOp2.EQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.INTDIV, OpOp2.LESS, OpOp2.LESSEQUAL, |
| OpOp2.LOG, OpOp2.MAX, OpOp2.MIN, OpOp2.MINUS, OpOp2.MODULUS, OpOp2.MULT, OpOp2.NOTEQUAL, |
| OpOp2.OR, OpOp2.PLUS, OpOp2.POW}; |
| |
| @Override |
| public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) |
| { |
| if( roots == null ) |
| return roots; |
| |
| //one pass rewrite-descend (rewrite created pattern) |
| for( Hop h : roots ) |
| rule_AlgebraicSimplification( h, false ); |
| Hop.resetVisitStatus(roots, true); |
| |
| //one pass descend-rewrite (for rollup) |
| for( Hop h : roots ) |
| rule_AlgebraicSimplification( h, true ); |
| Hop.resetVisitStatus(roots, true); |
| |
| return roots; |
| } |
| |
| @Override |
| public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) |
| { |
| if( root == null ) |
| return root; |
| |
| //one pass rewrite-descend (rewrite created pattern) |
| rule_AlgebraicSimplification( root, false ); |
| |
| root.resetVisitStatus(); |
| |
| //one pass descend-rewrite (for rollup) |
| rule_AlgebraicSimplification( root, true ); |
| |
| return root; |
| } |
| |
| |
| /** |
| * Note: X/y -> X * 1/y would be useful because * cheaper than / and sparsesafe; however, |
| * (1) the results would be not exactly the same (2 rounds instead of 1) and (2) it should |
| * come before constant folding while the other simplifications should come after constant |
| * folding. Hence, not applied yet. |
| * |
| * @param hop high-level operator |
| * @param descendFirst if process children recursively first |
| */ |
| private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) |
| { |
| if(hop.isVisited()) |
| return; |
| |
| //recursively process children |
| for( int i=0; i<hop.getInput().size(); i++) |
| { |
| Hop hi = hop.getInput().get(i); |
| |
| //process childs recursively first (to allow roll-up) |
| if( descendFirst ) |
| rule_AlgebraicSimplification(hi, descendFirst); //see below |
| |
| //apply actual simplification rewrites (of childs incl checks) |
| hi = removeUnnecessaryVectorizeOperation(hi); //e.g., matrix(1,nrow(X),ncol(X))/X -> 1/X |
| hi = removeUnnecessaryBinaryOperation(hop, hi, i); //e.g., X*1 -> X (dep: should come after rm unnecessary vectorize) |
| hi = fuseDatagenAndBinaryOperation(hi); //e.g., rand(min=-1,max=1)*7 -> rand(min=-7,max=7) |
| hi = fuseDatagenAndMinusOperation(hi); //e.g., -(rand(min=-2,max=1)) -> rand(min=-1,max=2) |
| hi = foldMultipleAppendOperations(hi); //e.g., cbind(X,cbind(Y,Z)) -> cbind(X,Y,Z) |
| hi = simplifyBinaryToUnaryOperation(hop, hi, i); //e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X) |
| hi = canonicalizeMatrixMultScalarAdd(hi); //e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps) |
| hi = simplifyCTableWithConstMatrixInputs(hi); //e.g., table(X, matrix(1,...)) -> table(X, 1) |
| hi = removeUnnecessaryCTable(hop, hi, i); //e.g., sum(table(X, 1)) -> nrow(X) and sum(table(1, Y)) -> nrow(Y) and sum(table(X, Y)) -> nrow(X) |
| hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X) |
| if(OptimizerUtils.ALLOW_OPERATOR_FUSION) |
| hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y |
| hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X |
| hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v) |
| hi = simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X) |
| hi = removeUnnecessaryAggregates(hi); //e.g., sum(rowSums(X)) -> sum(X) |
| hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s; |
| hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X)) |
| hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X) |
| hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lamda*X) -> lamda*sum(X) |
| hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor |
| hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B); |
| if(OptimizerUtils.ALLOW_OPERATOR_FUSION) |
| hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X) |
| hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y)); |
| hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]; |
| hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq; |
| hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq; |
| hi = fuseOrderOperationChain(hi); //e.g., order(order(X,2),1) -> order(X,(12)) |
| hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites |
| hi = removeUnnecessaryRemoveEmpty(hop, hi, i); //e.g., nrow(removeEmpty(A)) -> nnz(A) iff col vector |
| hi = simplifyTransposeAggBinBinaryChains(hop, hi, i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C) |
| hi = simplifyReplaceZeroOperation(hop, hi, i); //e.g., X + (X==0) * s -> replace(X, 0, s) |
| hi = removeUnnecessaryMinus(hop, hi, i); //e.g., -(-X)->X; potentially introduced by simplify binary or dyn rewrites |
| hi = simplifyGroupedAggregate(hi); //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count") |
| if(OptimizerUtils.ALLOW_OPERATOR_FUSION) { |
| hi = fuseMinusNzBinaryOperation(hop, hi, i); //e.g., X-mean*ppred(X,0,!=) -> X -nz mean |
| hi = fuseLogNzUnaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X) -> log_nz(X) |
| hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) |
| } |
| hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) |
| hi = simplifyBinaryComparisonChain(hop, hi, i); //e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> outer(v1,v2,"!="), |
| hi = simplifyCumsumColOrFullAggregates(hi); //e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1)) |
| hi = simplifyCumsumReverse(hop, hi, i); //e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X) |
| |
| //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) |
| |
| //process childs recursively after rewrites (to investigate pattern newly created by rewrites) |
| if( !descendFirst ) |
| rule_AlgebraicSimplification(hi, descendFirst); |
| } |
| |
| hop.setVisited(); |
| } |
| |
| private static Hop removeUnnecessaryVectorizeOperation(Hop hi) |
| { |
| //applies to all binary matrix operations, if one input is unnecessarily vectorized |
| if( hi instanceof BinaryOp && hi.getDataType()==DataType.MATRIX |
| && ((BinaryOp)hi).supportsMatrixScalarOperations() ) |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| |
| //NOTE: these rewrites of binary cell operations need to be aware that right is |
| //potentially a vector but the result is of the size of left |
| //TODO move to dynamic rewrites (since size dependent to account for mv binary cell and outer operations) |
| |
| if( !(left.getDim1()>1 && left.getDim2()==1 && right.getDim1()==1 && right.getDim2()>1) ) // no outer |
| { |
| //check and remove right vectorized scalar |
| if( left.getDataType() == DataType.MATRIX && right instanceof DataGenOp ) |
| { |
| DataGenOp dright = (DataGenOp) right; |
| if( dright.getOp()==OpOpDG.RAND && dright.hasConstantValue() ) |
| { |
| Hop drightIn = dright.getInput().get(dright.getParamIndex(DataExpression.RAND_MIN)); |
| HopRewriteUtils.replaceChildReference(bop, dright, drightIn, 1); |
| HopRewriteUtils.cleanupUnreferenced(dright); |
| |
| LOG.debug("Applied removeUnnecessaryVectorizeOperation1"); |
| } |
| } |
| //check and remove left vectorized scalar |
| else if( right.getDataType() == DataType.MATRIX && left instanceof DataGenOp ) |
| { |
| DataGenOp dleft = (DataGenOp) left; |
| if( dleft.getOp()==OpOpDG.RAND && dleft.hasConstantValue() |
| && (left.getDim2()==1 || right.getDim2()>1) |
| && (left.getDim1()==1 || right.getDim1()>1)) |
| { |
| Hop dleftIn = dleft.getInput().get(dleft.getParamIndex(DataExpression.RAND_MIN)); |
| HopRewriteUtils.replaceChildReference(bop, dleft, dleftIn, 0); |
| HopRewriteUtils.cleanupUnreferenced(dleft); |
| |
| LOG.debug("Applied removeUnnecessaryVectorizeOperation2"); |
| } |
| } |
| |
| //Note: we applied this rewrite to at most one side in order to keep the |
| //output semantically equivalent. However, future extensions might consider |
| //to remove vectors from both side, compute the binary op on scalars and |
| //finally feed it into a datagenop of the original dimensions. |
| } |
| } |
| |
| return hi; |
| } |
| |
| |
| /** |
| * handle removal of unnecessary binary operations |
| * |
| * X/1 or X*1 or 1*X or X-0 -> X |
| * -1*X or X*-1-> -X |
| * |
| * @param parent parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| private static Hop removeUnnecessaryBinaryOperation( Hop parent, Hop hi, int pos ) |
| { |
| if( hi instanceof BinaryOp ) |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| //X/1 or X*1 -> X |
| if( left.getDataType()==DataType.MATRIX |
| && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==1.0 ) |
| { |
| if( bop.getOp()==OpOp2.DIV || bop.getOp()==OpOp2.MULT ) |
| { |
| HopRewriteUtils.replaceChildReference(parent, bop, left, pos); |
| hi = left; |
| |
| LOG.debug("Applied removeUnnecessaryBinaryOperation1 (line "+bop.getBeginLine()+")"); |
| } |
| } |
| //X-0 -> X |
| else if( left.getDataType()==DataType.MATRIX |
| && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==0.0 ) |
| { |
| if( bop.getOp()==OpOp2.MINUS ) |
| { |
| HopRewriteUtils.replaceChildReference(parent, bop, left, pos); |
| hi = left; |
| |
| LOG.debug("Applied removeUnnecessaryBinaryOperation2 (line "+bop.getBeginLine()+")"); |
| } |
| } |
| //1*X -> X |
| else if( right.getDataType()==DataType.MATRIX |
| && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==1.0 ) |
| { |
| if( bop.getOp()==OpOp2.MULT ) |
| { |
| HopRewriteUtils.replaceChildReference(parent, bop, right, pos); |
| hi = right; |
| |
| LOG.debug("Applied removeUnnecessaryBinaryOperation3 (line "+bop.getBeginLine()+")"); |
| } |
| } |
| //-1*X -> -X |
| //note: this rewrite is necessary since the new antlr parser always converts |
| //-X to -1*X due to mechanical reasons |
| else if( right.getDataType()==DataType.MATRIX |
| && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==-1.0 ) |
| { |
| if( bop.getOp()==OpOp2.MULT ) |
| { |
| bop.setOp(OpOp2.MINUS); |
| HopRewriteUtils.replaceChildReference(bop, left, new LiteralOp(0), 0); |
| hi = bop; |
| |
| LOG.debug("Applied removeUnnecessaryBinaryOperation4 (line "+bop.getBeginLine()+")"); |
| } |
| } |
| //X*-1 -> -X (see comment above) |
| else if( left.getDataType()==DataType.MATRIX |
| && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==-1.0 ) |
| { |
| if( bop.getOp()==OpOp2.MULT ) |
| { |
| bop.setOp(OpOp2.MINUS); |
| HopRewriteUtils.removeChildReferenceByPos(bop, right, 1); |
| HopRewriteUtils.addChildReference(bop, new LiteralOp(0), 0); |
| hi = bop; |
| |
| LOG.debug("Applied removeUnnecessaryBinaryOperation5 (line "+bop.getBeginLine()+")"); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| /** |
| * Handle removal of unnecessary binary operations over rand data |
| * |
| * rand*7 -> rand(min*7,max*7); rand+7 -> rand(min+7,max+7); rand-7 -> rand(min+(-7),max+(-7)) |
| * 7*rand -> rand(min*7,max*7); 7+rand -> rand(min+7,max+7); |
| * |
| * @param hi high-order operation |
| * @return high-level operator |
| */ |
| @SuppressWarnings("incomplete-switch") |
| private static Hop fuseDatagenAndBinaryOperation( Hop hi ) |
| { |
| if( hi instanceof BinaryOp ) |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| |
| //NOTE: rewrite not applied if more than one datagen consumer because this would lead to |
| //the creation of multiple datagen ops and thus potentially different results if seed not specified) |
| |
| //left input rand and hence output matrix double, right scalar literal |
| if( HopRewriteUtils.isDataGenOp(left, OpOpDG.RAND) && |
| right instanceof LiteralOp && left.getParent().size()==1 ) |
| { |
| DataGenOp inputGen = (DataGenOp)left; |
| Hop pdf = inputGen.getInput(DataExpression.RAND_PDF); |
| Hop min = inputGen.getInput(DataExpression.RAND_MIN); |
| Hop max = inputGen.getInput(DataExpression.RAND_MAX); |
| double sval = ((LiteralOp)right).getDoubleValue(); |
| boolean pdfUniform = pdf instanceof LiteralOp |
| && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); |
| |
| if( HopRewriteUtils.isBinary(bop, OpOp2.MULT, OpOp2.PLUS, OpOp2.MINUS, OpOp2.DIV) |
| && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform ) |
| { |
| //create fused data gen operator |
| DataGenOp gen = null; |
| switch( bop.getOp() ) { //fuse via scale and shift |
| case MULT: gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0); break; |
| case PLUS: |
| case MINUS: gen = HopRewriteUtils.copyDataGenOp(inputGen, |
| 1, sval * ((bop.getOp()==OpOp2.MINUS)?-1:1)); break; |
| case DIV: gen = HopRewriteUtils.copyDataGenOp(inputGen, 1/sval, 0); break; |
| } |
| |
| //rewire all parents (avoid anomalies with replicated datagen) |
| List<Hop> parents = new ArrayList<>(bop.getParent()); |
| for( Hop p : parents ) |
| HopRewriteUtils.replaceChildReference(p, bop, gen); |
| |
| hi = gen; |
| LOG.debug("Applied fuseDatagenAndBinaryOperation1 " |
| + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); |
| } |
| } |
| //right input rand and hence output matrix double, left scalar literal |
| else if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==OpOpDG.RAND && |
| left instanceof LiteralOp && right.getParent().size()==1 ) |
| { |
| DataGenOp inputGen = (DataGenOp)right; |
| Hop pdf = inputGen.getInput(DataExpression.RAND_PDF); |
| Hop min = inputGen.getInput(DataExpression.RAND_MIN); |
| Hop max = inputGen.getInput(DataExpression.RAND_MAX); |
| double sval = ((LiteralOp)left).getDoubleValue(); |
| boolean pdfUniform = pdf instanceof LiteralOp |
| && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); |
| |
| if( (bop.getOp()==OpOp2.MULT || bop.getOp()==OpOp2.PLUS) |
| && min instanceof LiteralOp && max instanceof LiteralOp && pdfUniform ) |
| { |
| //create fused data gen operator |
| DataGenOp gen = null; |
| if( bop.getOp()==OpOp2.MULT ) |
| gen = HopRewriteUtils.copyDataGenOp(inputGen, sval, 0); |
| else { //OpOp2.PLUS |
| gen = HopRewriteUtils.copyDataGenOp(inputGen, 1, sval); |
| } |
| |
| //rewire all parents (avoid anomalies with replicated datagen) |
| List<Hop> parents = new ArrayList<>(bop.getParent()); |
| for( Hop p : parents ) |
| HopRewriteUtils.replaceChildReference(p, bop, gen); |
| |
| hi = gen; |
| LOG.debug("Applied fuseDatagenAndBinaryOperation2 " |
| + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); |
| } |
| } |
| //left input rand and hence output matrix double, right scalar variable |
| else if( HopRewriteUtils.isDataGenOp(left, OpOpDG.RAND) |
| && right.getDataType().isScalar() && left.getParent().size()==1 ) |
| { |
| DataGenOp gen = (DataGenOp)left; |
| Hop min = gen.getInput(DataExpression.RAND_MIN); |
| Hop max = gen.getInput(DataExpression.RAND_MAX); |
| Hop pdf = gen.getInput(DataExpression.RAND_PDF); |
| boolean pdfUniform = pdf instanceof LiteralOp |
| && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()); |
| |
| |
| if( HopRewriteUtils.isBinary(bop, OpOp2.PLUS) |
| && HopRewriteUtils.isLiteralOfValue(min, 0) |
| && HopRewriteUtils.isLiteralOfValue(max, 0) ) |
| { |
| gen.setInput(DataExpression.RAND_MIN, right, true); |
| gen.setInput(DataExpression.RAND_MAX, right, true); |
| //rewire all parents (avoid anomalies with replicated datagen) |
| List<Hop> parents = new ArrayList<>(bop.getParent()); |
| for( Hop p : parents ) |
| HopRewriteUtils.replaceChildReference(p, bop, gen); |
| hi = gen; |
| LOG.debug("Applied fuseDatagenAndBinaryOperation3a " |
| + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); |
| } |
| else if( HopRewriteUtils.isBinary(bop, OpOp2.MULT) |
| && ((HopRewriteUtils.isLiteralOfValue(min, 0) && pdfUniform) |
| || HopRewriteUtils.isLiteralOfValue(min, 1)) |
| && HopRewriteUtils.isLiteralOfValue(max, 1) ) |
| { |
| if( HopRewriteUtils.isLiteralOfValue(min, 1) ) |
| gen.setInput(DataExpression.RAND_MIN, right, true); |
| gen.setInput(DataExpression.RAND_MAX, right, true); |
| //rewire all parents (avoid anomalies with replicated datagen) |
| List<Hop> parents = new ArrayList<>(bop.getParent()); |
| for( Hop p : parents ) |
| HopRewriteUtils.replaceChildReference(p, bop, gen); |
| hi = gen; |
| LOG.debug("Applied fuseDatagenAndBinaryOperation3b " |
| + "("+bop.getFilename()+", line "+bop.getBeginLine()+")."); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop fuseDatagenAndMinusOperation( Hop hi ) |
| { |
| if( hi instanceof BinaryOp ) |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| |
| if( right instanceof DataGenOp && ((DataGenOp)right).getOp()==OpOpDG.RAND && |
| left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==0.0 ) |
| { |
| DataGenOp inputGen = (DataGenOp)right; |
| HashMap<String,Integer> params = inputGen.getParamIndexMap(); |
| Hop pdf = right.getInput().get(params.get(DataExpression.RAND_PDF)); |
| int ixMin = params.get(DataExpression.RAND_MIN); |
| int ixMax = params.get(DataExpression.RAND_MAX); |
| Hop min = right.getInput().get(ixMin); |
| Hop max = right.getInput().get(ixMax); |
| |
| //apply rewrite under additional conditions (for simplicity) |
| if( inputGen.getParent().size()==1 |
| && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp |
| && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) ) |
| { |
| //exchange and *-1 (special case 0 stays 0 instead of -0 for consistency) |
| double newMinVal = (((LiteralOp)max).getDoubleValue()==0)?0:(-1 * ((LiteralOp)max).getDoubleValue()); |
| double newMaxVal = (((LiteralOp)min).getDoubleValue()==0)?0:(-1 * ((LiteralOp)min).getDoubleValue()); |
| Hop newMin = new LiteralOp(newMinVal); |
| Hop newMax = new LiteralOp(newMaxVal); |
| |
| HopRewriteUtils.removeChildReferenceByPos(inputGen, min, ixMin); |
| HopRewriteUtils.addChildReference(inputGen, newMin, ixMin); |
| HopRewriteUtils.removeChildReferenceByPos(inputGen, max, ixMax); |
| HopRewriteUtils.addChildReference(inputGen, newMax, ixMax); |
| |
| //rewire all parents (avoid anomalies with replicated datagen) |
| List<Hop> parents = new ArrayList<>(bop.getParent()); |
| for( Hop p : parents ) |
| HopRewriteUtils.replaceChildReference(p, bop, inputGen); |
| |
| hi = inputGen; |
| LOG.debug("Applied fuseDatagenAndMinusOperation (line "+bop.getBeginLine()+")."); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop foldMultipleAppendOperations(Hop hi) |
| { |
| if( hi.getDataType().isMatrix() //no string appends or frames |
| && (HopRewriteUtils.isBinary(hi, OpOp2.CBIND, OpOp2.RBIND) |
| || HopRewriteUtils.isNary(hi, OpOpN.CBIND, OpOpN.RBIND)) ) |
| { |
| OpOp2 bop = (hi instanceof BinaryOp) ? ((BinaryOp)hi).getOp() : |
| OpOp2.valueOf(((NaryOp)hi).getOp().name()); |
| OpOpN nop = (hi instanceof NaryOp) ? ((NaryOp)hi).getOp() : |
| OpOpN.valueOf(((BinaryOp)hi).getOp().name()); |
| |
| boolean converged = false; |
| while( !converged ) { |
| //get first matching cbind or rbind |
| Hop first = hi.getInput().stream() |
| .filter(h -> HopRewriteUtils.isBinary(h, bop) || HopRewriteUtils.isNary(h, nop)) |
| .findFirst().orElse(null); |
| |
| //replace current op with new nary cbind/rbind |
| if( first != null && first.getParent().size()==1 ) { |
| //construct new list of inputs (in original order) |
| ArrayList<Hop> linputs = new ArrayList<>(); |
| for(Hop in : hi.getInput()) |
| if( in == first ) |
| linputs.addAll(first.getInput()); |
| else |
| linputs.add(in); |
| Hop hnew = HopRewriteUtils.createNary(nop, linputs.toArray(new Hop[0])); |
| //clear dangling references |
| HopRewriteUtils.removeAllChildReferences(hi); |
| HopRewriteUtils.removeAllChildReferences(first); |
| //rewire all parents (avoid anomalies with refs to hi) |
| List<Hop> parents = new ArrayList<>(hi.getParent()); |
| for( Hop p : parents ) |
| HopRewriteUtils.replaceChildReference(p, hi, hnew); |
| hi = hnew; |
| LOG.debug("Applied foldMultipleAppendOperations (line "+hi.getBeginLine()+")."); |
| } |
| else { |
| converged = true; |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| /** |
| * Handle simplification of binary operations (relies on previous common subexpression elimination). |
| * At the same time this servers as a canonicalization for more complex rewrites. |
| * |
| * X+X -> X*2, X*X -> X^2, (X>0)-(X<0) -> sign(X) |
| * |
| * @param parent parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| private static Hop simplifyBinaryToUnaryOperation( Hop parent, Hop hi, int pos ) |
| { |
| if( hi instanceof BinaryOp ) |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| |
| //patterns: X+X -> X*2, X*X -> X^2, |
| if( left == right && left.getDataType()==DataType.MATRIX ) |
| { |
| //note: we simplify this to unary operations first (less mem and better MR plan), |
| //however, we later compile specific LOPS for X*2 and X^2 |
| if( bop.getOp()==OpOp2.PLUS ) //X+X -> X*2 |
| { |
| bop.setOp(OpOp2.MULT); |
| HopRewriteUtils.replaceChildReference(hi, right, new LiteralOp(2), 1); |
| |
| LOG.debug("Applied simplifyBinaryToUnaryOperation1 (line "+hi.getBeginLine()+")."); |
| } |
| else if ( bop.getOp()==OpOp2.MULT ) //X*X -> X^2 |
| { |
| bop.setOp(OpOp2.POW); |
| HopRewriteUtils.replaceChildReference(hi, right, new LiteralOp(2), 1); |
| |
| LOG.debug("Applied simplifyBinaryToUnaryOperation2 (line "+hi.getBeginLine()+")."); |
| } |
| } |
| //patterns: (X>0)-(X<0) -> sign(X) |
| else if( bop.getOp() == OpOp2.MINUS |
| && HopRewriteUtils.isBinary(left, OpOp2.GREATER) |
| && HopRewriteUtils.isBinary(right, OpOp2.LESS) |
| && left.getInput().get(0) == right.getInput().get(0) |
| && left.getInput().get(1) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValue((LiteralOp)left.getInput().get(1))==0 |
| && right.getInput().get(1) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValue((LiteralOp)right.getInput().get(1))==0 ) |
| { |
| UnaryOp uop = HopRewriteUtils.createUnary(left.getInput().get(0), OpOp1.SIGN); |
| HopRewriteUtils.replaceChildReference(parent, hi, uop, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, left, right); |
| hi = uop; |
| |
| LOG.debug("Applied simplifyBinaryToUnaryOperation3 (line "+hi.getBeginLine()+")."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| /** |
| * Rewrite to canonicalize all patterns like U%*%V+eps, eps+U%*%V, and |
| * U%*%V-eps into the common representation U%*%V+s which simplifies |
| * subsequent rewrites (e.g., wdivmm or wcemm with epsilon). |
| * |
| * @param hi high-level operator |
| * @return high-level operator |
| */ |
| private static Hop canonicalizeMatrixMultScalarAdd( Hop hi ) |
| { |
| //pattern: binary operation (+ or -) of matrix mult and scalar |
| if( hi instanceof BinaryOp ) |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| |
| //pattern: (eps + U%*%V) -> (U%*%V+eps) |
| if( left.getDataType().isScalar() && right instanceof AggBinaryOp |
| && bop.getOp()==OpOp2.PLUS ) |
| { |
| HopRewriteUtils.removeAllChildReferences(bop); |
| HopRewriteUtils.addChildReference(bop, right, 0); |
| HopRewriteUtils.addChildReference(bop, left, 1); |
| LOG.debug("Applied canonicalizeMatrixMultScalarAdd1 (line "+hi.getBeginLine()+")."); |
| } |
| //pattern: (U%*%V - eps) -> (U%*%V + (-eps)) |
| else if( right.getDataType().isScalar() && left instanceof AggBinaryOp |
| && bop.getOp() == OpOp2.MINUS ) |
| { |
| bop.setOp(OpOp2.PLUS); |
| HopRewriteUtils.replaceChildReference(bop, right, |
| HopRewriteUtils.createBinaryMinus(right), 1); |
| LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line "+hi.getBeginLine()+")."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyCTableWithConstMatrixInputs( Hop hi ) |
| { |
| //pattern: table(X, matrix(1,...), matrix(7, ...)) -> table(X, 1, 7) |
| if( HopRewriteUtils.isTernary(hi, OpOp3.CTABLE) ) { |
| //note: the first input always expected to be a matrix |
| for( int i=1; i<hi.getInput().size(); i++ ) { |
| Hop inCurr = hi.getInput().get(i); |
| if( HopRewriteUtils.isDataGenOpWithConstantValue(inCurr) ) { |
| Hop inNew = ((DataGenOp)inCurr).getInput(DataExpression.RAND_MIN); |
| HopRewriteUtils.replaceChildReference(hi, inCurr, inNew, i); |
| LOG.debug("Applied simplifyCTableWithConstMatrixInputs" |
| + i + " (line "+hi.getBeginLine()+")."); |
| } |
| } |
| } |
| return hi; |
| } |
| |
| private static Hop removeUnnecessaryCTable( Hop parent, Hop hi, int pos ) { |
| if ( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol) |
| && HopRewriteUtils.isTernary(hi.getInput().get(0), OpOp3.CTABLE) |
| && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0).getInput().get(2), 1.0)) |
| { |
| Hop matrixInput = hi.getInput().get(0).getInput().get(0); |
| OpOp1 opcode = matrixInput.getDim2() == 1 ? OpOp1.NROW : OpOp1.LENGTH; |
| Hop newOpLength = new UnaryOp("tmp", DataType.SCALAR, ValueType.INT64, opcode, matrixInput); |
| HopRewriteUtils.replaceChildReference(parent, hi, newOpLength, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, hi.getInput().get(0)); |
| hi = newOpLength; |
| } |
| return hi; |
| } |
| |
| /** |
| * NOTE: this would be by definition a dynamic rewrite; however, we apply it as a static |
| * rewrite in order to apply it before splitting dags which would hide the table information |
| * if dimensions are not specified. |
| * |
| * @param parent parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| private static Hop simplifyReverseOperation( Hop parent, Hop hi, int pos ) |
| { |
| if( hi instanceof AggBinaryOp |
| && hi.getInput().get(0) instanceof TernaryOp ) |
| { |
| TernaryOp top = (TernaryOp) hi.getInput().get(0); |
| |
| if( top.getOp()==OpOp3.CTABLE |
| && HopRewriteUtils.isBasic1NSequence(top.getInput().get(0)) |
| && HopRewriteUtils.isBasicN1Sequence(top.getInput().get(1)) |
| && top.getInput().get(0).getDim1()==top.getInput().get(1).getDim1()) |
| { |
| ReorgOp rop = HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV); |
| HopRewriteUtils.replaceChildReference(parent, hi, rop, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, top); |
| hi = rop; |
| |
| LOG.debug("Applied simplifyReverseOperation."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi ) |
| { |
| //pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate) |
| if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS) |
| && hi.getDataType() == DataType.MATRIX |
| && hi.getInput().get(0) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput().get(0))==1 |
| && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) |
| && hi.getInput().get(1).getParent().size() == 1 ) //single consumer |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = hi.getInput().get(1).getInput().get(0); |
| Hop right = hi.getInput().get(1).getInput().get(1); |
| |
| //set new binaryop type and rewire inputs |
| bop.setOp(OpOp2.MINUS1_MULT); |
| HopRewriteUtils.removeAllChildReferences(hi); |
| HopRewriteUtils.addChildReference(bop, left); |
| HopRewriteUtils.addChildReference(bop, right); |
| |
| LOG.debug("Applied simplifyMultiBinaryToBinaryOperation."); |
| } |
| |
| return hi; |
| } |
| |
| /** |
| * (X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X |
| * (X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X |
| * |
| * |
| * @param parent parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int pos ) |
| { |
| |
| if( hi instanceof BinaryOp ) |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| |
| //(X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X |
| //(X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X |
| boolean applied = false; |
| if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX |
| && HopRewriteUtils.isValidOp(bop.getOp(), LOOKUP_VALID_DISTRIBUTIVE_BINARY) ) |
| { |
| Hop X = null; Hop Y = null; |
| if( HopRewriteUtils.isBinary(left, OpOp2.MULT) ) //(Y*X-X) -> (Y-1)*X |
| { |
| Hop leftC1 = left.getInput().get(0); |
| Hop leftC2 = left.getInput().get(1); |
| |
| if( leftC1.getDataType()==DataType.MATRIX && leftC2.getDataType()==DataType.MATRIX && |
| (right == leftC1 || right == leftC2) && leftC1 !=leftC2 ){ //any mult order |
| X = right; |
| Y = ( right == leftC1 ) ? leftC2 : leftC1; |
| } |
| if( X != null ){ //rewrite 'binary +/-' |
| LiteralOp literal = new LiteralOp(1); |
| BinaryOp plus = HopRewriteUtils.createBinary(Y, literal, bop.getOp()); |
| BinaryOp mult = HopRewriteUtils.createBinary(plus, X, OpOp2.MULT); |
| HopRewriteUtils.replaceChildReference(parent, hi, mult, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, left); |
| hi = mult; |
| applied = true; |
| |
| LOG.debug("Applied simplifyDistributiveBinaryOperation1"); |
| } |
| } |
| |
| if( !applied && HopRewriteUtils.isBinary(right, OpOp2.MULT) ) //(X-Y*X) -> (1-Y)*X |
| { |
| Hop rightC1 = right.getInput().get(0); |
| Hop rightC2 = right.getInput().get(1); |
| if( rightC1.getDataType()==DataType.MATRIX && rightC2.getDataType()==DataType.MATRIX && |
| (left == rightC1 || left == rightC2) && rightC1 !=rightC2 ){ //any mult order |
| X = left; |
| Y = ( left == rightC1 ) ? rightC2 : rightC1; |
| } |
| if( X != null ){ //rewrite '+/- binary' |
| LiteralOp literal = new LiteralOp(1); |
| BinaryOp plus = HopRewriteUtils.createBinary(literal, Y, bop.getOp()); |
| BinaryOp mult = HopRewriteUtils.createBinary(plus, X, OpOp2.MULT); |
| HopRewriteUtils.replaceChildReference(parent, hi, mult, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, right); |
| hi = mult; |
| |
| LOG.debug("Applied simplifyDistributiveBinaryOperation2"); |
| } |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| /** |
| * (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v) |
| * (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v) |
| * |
| * Note: Restriction ba() at leaf and root instead of data at leaf to not reorganize too |
| * eagerly, which would loose additional rewrite potential. This rewrite has two goals |
| * (1) enable XtwXv, and increase piggybacking potential by creating bushy trees. |
| * |
| * @param parent parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| private static Hop simplifyBushyBinaryOperation( Hop parent, Hop hi, int pos ) |
| { |
| if( hi instanceof BinaryOp && parent instanceof AggBinaryOp ) |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| OpOp2 op = bop.getOp(); |
| |
| if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX && |
| HopRewriteUtils.isValidOp(op, LOOKUP_VALID_ASSOCIATIVE_BINARY) ) |
| { |
| boolean applied = false; |
| |
| if( right instanceof BinaryOp ) |
| { |
| BinaryOp bop2 = (BinaryOp)right; |
| Hop left2 = bop2.getInput().get(0); |
| Hop right2 = bop2.getInput().get(1); |
| OpOp2 op2 = bop2.getOp(); |
| |
| if( op==op2 && right2.getDataType()==DataType.MATRIX |
| && (right2 instanceof AggBinaryOp) ) |
| { |
| //(X*(Y*op()) -> (X*Y)*op() |
| BinaryOp bop3 = HopRewriteUtils.createBinary(left, left2, op); |
| BinaryOp bop4 = HopRewriteUtils.createBinary(bop3, right2, op); |
| HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos); |
| HopRewriteUtils.cleanupUnreferenced(bop, bop2); |
| hi = bop4; |
| |
| applied = true; |
| |
| LOG.debug("Applied simplifyBushyBinaryOperation1"); |
| } |
| } |
| |
| if( !applied && left instanceof BinaryOp ) |
| { |
| BinaryOp bop2 = (BinaryOp)left; |
| Hop left2 = bop2.getInput().get(0); |
| Hop right2 = bop2.getInput().get(1); |
| OpOp2 op2 = bop2.getOp(); |
| |
| if( op==op2 && left2.getDataType()==DataType.MATRIX |
| && (left2 instanceof AggBinaryOp) |
| && (right2.getDim2() > 1 || right.getDim2() == 1) //X not vector, or Y vector |
| && (right2.getDim1() > 1 || right.getDim1() == 1) ) //X not vector, or Y vector |
| { |
| //((op()*X)*Y) -> op()*(X*Y) |
| BinaryOp bop3 = HopRewriteUtils.createBinary(right2, right, op); |
| BinaryOp bop4 = HopRewriteUtils.createBinary(left2, bop3, op); |
| HopRewriteUtils.replaceChildReference(parent, bop, bop4, pos); |
| HopRewriteUtils.cleanupUnreferenced(bop, bop2); |
| hi = bop4; |
| |
| LOG.debug("Applied simplifyBushyBinaryOperation2"); |
| } |
| } |
| } |
| |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyUnaryAggReorgOperation( Hop parent, Hop hi, int pos ) |
| { |
| if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full uagg |
| && hi.getInput().get(0) instanceof ReorgOp ) //reorg operation |
| { |
| ReorgOp rop = (ReorgOp)hi.getInput().get(0); |
| if( (rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE |
| || rop.getOp() == ReOrgOp.REV ) //valid reorg |
| && rop.getParent().size()==1 ) //uagg only reorg consumer |
| { |
| Hop input = rop.getInput().get(0); |
| HopRewriteUtils.removeAllChildReferences(hi); |
| HopRewriteUtils.removeAllChildReferences(rop); |
| HopRewriteUtils.addChildReference(hi, input); |
| |
| LOG.debug("Applied simplifyUnaryAggReorgOperation"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop removeUnnecessaryAggregates(Hop hi) |
| { |
| //sum(rowSums(X)) -> sum(X), sum(colSums(X)) -> sum(X) |
| //min(rowMins(X)) -> min(X), min(colMins(X)) -> min(X) |
| //max(rowMaxs(X)) -> max(X), max(colMaxs(X)) -> max(X) |
| //sum(rowSums(X^2)) -> sum(X), sum(colSums(X^2)) -> sum(X) |
| if( hi instanceof AggUnaryOp && hi.getInput().get(0) instanceof AggUnaryOp |
| && ((AggUnaryOp)hi).getDirection()==Direction.RowCol |
| && hi.getInput().get(0).getParent().size()==1 ) |
| { |
| AggUnaryOp au1 = (AggUnaryOp) hi; |
| AggUnaryOp au2 = (AggUnaryOp) hi.getInput().get(0); |
| if( (au1.getOp()==AggOp.SUM && (au2.getOp()==AggOp.SUM || au2.getOp()==AggOp.SUM_SQ)) |
| || (au1.getOp()==AggOp.MIN && au2.getOp()==AggOp.MIN) |
| || (au1.getOp()==AggOp.MAX && au2.getOp()==AggOp.MAX) ) |
| { |
| Hop input = au2.getInput().get(0); |
| HopRewriteUtils.removeAllChildReferences(au2); |
| HopRewriteUtils.replaceChildReference(au1, au2, input); |
| if( au2.getOp() == AggOp.SUM_SQ ) |
| au1.setOp(AggOp.SUM_SQ); |
| |
| LOG.debug("Applied removeUnnecessaryAggregates (line "+hi.getBeginLine()+")."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos ) |
| { |
| // Note: This rewrite is not applicable for all binary operations because some of them |
| // are undefined over scalars. We explicitly exclude potential conflicting matrix-scalar binary |
| // operations; other operations like cbind/rbind will never occur as matrix-scalar operations. |
| |
| if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR) |
| && hi.getInput().get(0) instanceof BinaryOp |
| && HopRewriteUtils.isBinary(hi.getInput().get(0), LOOKUP_VALID_SCALAR_BINARY)) |
| { |
| BinaryOp bin = (BinaryOp) hi.getInput().get(0); |
| BinaryOp bout = null; |
| |
| //as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y) |
| if( bin.getInput().get(0).getDataType()==DataType.MATRIX |
| && bin.getInput().get(1).getDataType()==DataType.MATRIX ) { |
| UnaryOp cast1 = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR); |
| UnaryOp cast2 = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR); |
| bout = HopRewriteUtils.createBinary(cast1, cast2, bin.getOp()); |
| } |
| //as.scalar(X*s) -> as.scalar(X) * s |
| else if( bin.getInput().get(0).getDataType()==DataType.MATRIX ) { |
| UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR); |
| bout = HopRewriteUtils.createBinary(cast, bin.getInput().get(1), bin.getOp()); |
| } |
| //as.scalar(s*X) -> s * as.scalar(X) |
| else if ( bin.getInput().get(1).getDataType()==DataType.MATRIX ) { |
| UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR); |
| bout = HopRewriteUtils.createBinary(bin.getInput().get(0), cast, bin.getOp()); |
| } |
| |
| if( bout != null ) { |
| HopRewriteUtils.replaceChildReference(parent, hi, bout, pos); |
| |
| LOG.debug("Applied simplifyBinaryMatrixScalarOperation."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop pushdownUnaryAggTransposeOperation( Hop parent, Hop hi, int pos ) |
| { |
| if( hi instanceof AggUnaryOp && hi.getParent().size()==1 |
| && (((AggUnaryOp) hi).getDirection()==Direction.Row || ((AggUnaryOp) hi).getDirection()==Direction.Col) |
| && HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1) |
| && HopRewriteUtils.isValidOp(((AggUnaryOp) hi).getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) |
| { |
| AggUnaryOp uagg = (AggUnaryOp) hi; |
| |
| //get input rewire existing operators (remove inner transpose) |
| Hop input = uagg.getInput().get(0).getInput().get(0); |
| HopRewriteUtils.removeAllChildReferences(hi.getInput().get(0)); |
| HopRewriteUtils.removeAllChildReferences(hi); |
| HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); |
| |
| //pattern 1: row-aggregate to col aggregate, e.g., rowSums(t(X))->t(colSums(X)) |
| if( uagg.getDirection()==Direction.Row ) { |
| uagg.setDirection(Direction.Col); |
| LOG.debug("Applied pushdownUnaryAggTransposeOperation1 (line "+hi.getBeginLine()+")."); |
| } |
| //pattern 2: col-aggregate to row aggregate, e.g., colSums(t(X))->t(rowSums(X)) |
| else if( uagg.getDirection()==Direction.Col ) { |
| uagg.setDirection(Direction.Row); |
| LOG.debug("Applied pushdownUnaryAggTransposeOperation2 (line "+hi.getBeginLine()+")."); |
| } |
| |
| //create outer transpose operation and rewire operators |
| HopRewriteUtils.addChildReference(uagg, input); uagg.refreshSizeInformation(); |
| Hop trans = HopRewriteUtils.createTranspose(uagg); //incl refresh size |
| HopRewriteUtils.addChildReference(parent, trans, pos); //by def, same size |
| |
| hi = trans; |
| } |
| |
| return hi; |
| } |
| |
| private static Hop pushdownCSETransposeScalarOperation( Hop parent, Hop hi, int pos ) |
| { |
| // a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X) |
| // probed at root node of b in above example |
| // (with support for left or right scalar operations) |
| if( HopRewriteUtils.isTransposeOperation(hi, 1) |
| && HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0)) |
| && hi.getInput().get(0).getParent().size()==1) |
| { |
| int Xpos = hi.getInput().get(0).getInput().get(0).getDataType().isMatrix() ? 0 : 1; |
| Hop X = hi.getInput().get(0).getInput().get(Xpos); |
| BinaryOp binary = (BinaryOp) hi.getInput().get(0); |
| |
| if( HopRewriteUtils.containsTransposeOperation(X.getParent()) |
| && !HopRewriteUtils.isValidOp(binary.getOp(), new OpOp2[]{OpOp2.MOMENT, OpOp2.QUANTILE})) |
| { |
| //clear existing wiring |
| HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); |
| HopRewriteUtils.removeChildReference(hi, binary); |
| HopRewriteUtils.removeChildReference(binary, X); |
| |
| //re-wire operators |
| HopRewriteUtils.addChildReference(parent, binary, pos); |
| HopRewriteUtils.addChildReference(binary, hi, Xpos); |
| HopRewriteUtils.addChildReference(hi, X); |
| //note: common subexpression later eliminated by dedicated rewrite |
| |
| hi = binary; |
| LOG.debug("Applied pushdownCSETransposeScalarOperation (line "+hi.getBeginLine()+")."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) { |
| //pattern: sum(lamda*X) -> lamda*sum(X) |
| if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol |
| && ((AggUnaryOp)hi).getOp()==AggOp.SUM // only one parent which is the sum |
| && HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.MULT, 1) |
| && ((hi.getInput().get(0).getInput().get(0).getDataType()==DataType.SCALAR && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.MATRIX) |
| ||(hi.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(0).getInput().get(1).getDataType()==DataType.SCALAR))) |
| { |
| Hop operand1 = hi.getInput().get(0).getInput().get(0); |
| Hop operand2 = hi.getInput().get(0).getInput().get(1); |
| |
| //check which operand is the Scalar and which is the matrix |
| Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? operand1 : operand2; |
| Hop matrix = (operand1.getDataType()==DataType.MATRIX) ? operand1 : operand2; |
| |
| AggUnaryOp aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.RowCol); |
| Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, OpOp2.MULT); |
| |
| HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); |
| |
| LOG.debug("Applied pushdownSumBinaryMult."); |
| return bop; |
| } |
| return hi; |
| } |
| |
| private static Hop simplifyUnaryPPredOperation( Hop parent, Hop hi, int pos ) |
| { |
| if( hi instanceof UnaryOp && hi.getDataType()==DataType.MATRIX //unaryop |
| && hi.getInput().get(0) instanceof BinaryOp //binaryop - ppred |
| && ((BinaryOp)hi.getInput().get(0)).isPPredOperation() ) |
| { |
| UnaryOp uop = (UnaryOp) hi; //valid unary op |
| if( uop.getOp()==OpOp1.ABS || uop.getOp()==OpOp1.SIGN |
| || uop.getOp()==OpOp1.CEIL || uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND ) |
| { |
| //clear link unary-binary |
| Hop input = uop.getInput().get(0); |
| HopRewriteUtils.replaceChildReference(parent, hi, input, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = input; |
| |
| LOG.debug("Applied simplifyUnaryPPredOperation."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyTransposedAppend( Hop parent, Hop hi, int pos ) |
| { |
| //e.g., t(cbind(t(A),t(B))) --> rbind(A,B), t(rbind(t(A),t(B))) --> cbind(A,B) |
| if( HopRewriteUtils.isTransposeOperation(hi) //t() rooted |
| && hi.getInput().get(0) instanceof BinaryOp |
| && (((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.CBIND //append (cbind/rbind) |
| || ((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.RBIND) |
| && hi.getInput().get(0).getParent().size() == 1 ) //single consumer of append |
| { |
| BinaryOp bop = (BinaryOp)hi.getInput().get(0); |
| //both inputs transpose ops, where transpose is single consumer |
| if( HopRewriteUtils.isTransposeOperation(bop.getInput().get(0), 1) |
| && HopRewriteUtils.isTransposeOperation(bop.getInput().get(1), 1) ) |
| { |
| Hop left = bop.getInput().get(0).getInput().get(0); |
| Hop right = bop.getInput().get(1).getInput().get(0); |
| |
| //create new subdag (no in-place dag update to prevent anomalies with |
| //multiple consumers during rewrite process) |
| OpOp2 binop = (bop.getOp()==OpOp2.CBIND) ? OpOp2.RBIND : OpOp2.CBIND; |
| BinaryOp bopnew = HopRewriteUtils.createBinary(left, right, binop); |
| HopRewriteUtils.replaceChildReference(parent, hi, bopnew, pos); |
| |
| hi = bopnew; |
| LOG.debug("Applied simplifyTransposedAppend (line "+hi.getBeginLine()+")."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| /** |
| * handle simplification of more complex sub DAG to unary operation. |
| * |
| * X*(1-X) -> sprop(X) |
| * (1-X)*X -> sprop(X) |
| * 1/(1+exp(-X)) -> sigmoid(X) |
| * |
| * @param parent parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| */ |
| private static Hop fuseBinarySubDAGToUnaryOperation( Hop parent, Hop hi, int pos ) |
| { |
| if( hi instanceof BinaryOp ) |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| boolean applied = false; |
| |
| //sample proportion (sprop) operator |
| if( bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) |
| { |
| //by definition, either left or right or none applies. |
| //note: if there are multiple consumers on the intermediate, |
| //we follow the heuristic that redundant computation is more beneficial, |
| //i.e., we still fuse but leave the intermediate for the other consumers |
| |
| if( left instanceof BinaryOp ) //(1-X)*X |
| { |
| BinaryOp bleft = (BinaryOp)left; |
| Hop left1 = bleft.getInput().get(0); |
| Hop left2 = bleft.getInput().get(1); |
| |
| if( left1 instanceof LiteralOp && |
| HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 && |
| left2 == right && bleft.getOp() == OpOp2.MINUS ) |
| { |
| UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SPROP); |
| HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); |
| HopRewriteUtils.cleanupUnreferenced(bop, left); |
| hi = unary; |
| applied = true; |
| |
| LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1"); |
| } |
| } |
| if( !applied && right instanceof BinaryOp ) //X*(1-X) |
| { |
| BinaryOp bright = (BinaryOp)right; |
| Hop right1 = bright.getInput().get(0); |
| Hop right2 = bright.getInput().get(1); |
| |
| if( right1 instanceof LiteralOp && |
| HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 && |
| right2 == left && bright.getOp() == OpOp2.MINUS ) |
| { |
| UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SPROP); |
| HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); |
| HopRewriteUtils.cleanupUnreferenced(bop, left); |
| hi = unary; |
| applied = true; |
| |
| LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2"); |
| } |
| } |
| } |
| |
| //sigmoid operator |
| if( !applied && bop.getOp() == OpOp2.DIV && left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX |
| && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp) |
| { |
| //note: if there are multiple consumers on the intermediate, |
| //we follow the heuristic that redundant computation is more beneficial, |
| //i.e., we still fuse but leave the intermediate for the other consumers |
| |
| BinaryOp bop2 = (BinaryOp)right; |
| Hop left2 = bop2.getInput().get(0); |
| Hop right2 = bop2.getInput().get(1); |
| |
| if( bop2.getOp() == OpOp2.PLUS && left2.getDataType()==DataType.SCALAR && right2.getDataType()==DataType.MATRIX |
| && left2 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof UnaryOp) |
| { |
| UnaryOp uop = (UnaryOp) right2; |
| Hop uopin = uop.getInput().get(0); |
| |
| if( uop.getOp()==OpOp1.EXP ) |
| { |
| UnaryOp unary = null; |
| |
| //Pattern 1: (1/(1 + exp(-X)) |
| if( HopRewriteUtils.isBinary(uopin, OpOp2.MINUS) ) { |
| BinaryOp bop3 = (BinaryOp) uopin; |
| Hop left3 = bop3.getInput().get(0); |
| Hop right3 = bop3.getInput().get(1); |
| |
| if( left3 instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 ) |
| unary = HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID); |
| } |
| //Pattern 2: (1/(1 + exp(X)), e.g., where -(-X) has been removed by |
| //the 'remove unnecessary minus' rewrite --> reintroduce the minus |
| else { |
| BinaryOp minus = HopRewriteUtils.createBinaryMinus(uopin); |
| unary = HopRewriteUtils.createUnary(minus, OpOp1.SIGMOID); |
| } |
| |
| if( unary != null ) { |
| HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); |
| HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop); |
| hi = unary; |
| applied = true; |
| |
| LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1"); |
| } |
| } |
| } |
| } |
| |
| //select positive (selp) operator (note: same initial pattern as sprop) |
| if( !applied && bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) |
| { |
| //by definition, either left or right or none applies. |
| //note: if there are multiple consumers on the intermediate tmp=(X>0), it's still beneficial |
| //to replace the X*tmp with selp(X) due to lower memory requirements and simply sparsity propagation |
| if( left instanceof BinaryOp ) //(X>0)*X |
| { |
| BinaryOp bleft = (BinaryOp)left; |
| Hop left1 = bleft.getInput().get(0); |
| Hop left2 = bleft.getInput().get(1); |
| |
| if( left2 instanceof LiteralOp && |
| HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 && |
| left1 == right && (bleft.getOp() == OpOp2.GREATER ) ) |
| { |
| BinaryOp binary = HopRewriteUtils.createBinary(right, new LiteralOp(0), OpOp2.MAX); |
| HopRewriteUtils.replaceChildReference(parent, bop, binary, pos); |
| HopRewriteUtils.cleanupUnreferenced(bop, left); |
| hi = binary; |
| applied = true; |
| |
| LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0a"); |
| } |
| } |
| if( !applied && right instanceof BinaryOp ) //X*(X>0) |
| { |
| BinaryOp bright = (BinaryOp)right; |
| Hop right1 = bright.getInput().get(0); |
| Hop right2 = bright.getInput().get(1); |
| |
| if( right2 instanceof LiteralOp && |
| HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 && |
| right1 == left && bright.getOp() == OpOp2.GREATER ) |
| { |
| BinaryOp binary = HopRewriteUtils.createBinary(left, new LiteralOp(0), OpOp2.MAX); |
| HopRewriteUtils.replaceChildReference(parent, bop, binary, pos); |
| HopRewriteUtils.cleanupUnreferenced(bop, left); |
| hi = binary; |
| applied= true; |
| |
| LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-max0b"); |
| } |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.TRACE ) //trace() |
| { |
| Hop hi2 = hi.getInput().get(0); |
| if( HopRewriteUtils.isMatrixMultiply(hi2) ) //X%*%Y |
| { |
| Hop left = hi2.getInput().get(0); |
| Hop right = hi2.getInput().get(1); |
| |
| //create new operators (incl refresh size inside for transpose) |
| ReorgOp trans = HopRewriteUtils.createTranspose(right); |
| BinaryOp mult = HopRewriteUtils.createBinary(left, trans, OpOp2.MULT); |
| AggUnaryOp sum = HopRewriteUtils.createSum(mult); |
| |
| //rehang new subdag under parent node |
| HopRewriteUtils.replaceChildReference(parent, hi, sum, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, hi2); |
| hi = sum; |
| |
| LOG.debug("Applied simplifyTraceMatrixMult"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) |
| { |
| //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1] |
| if( hi instanceof IndexingOp |
| && ((IndexingOp)hi).isRowLowerEqualsUpper() |
| && ((IndexingOp)hi).isColLowerEqualsUpper() |
| && hi.getInput().get(0).getParent().size()==1 //rix is single mm consumer |
| && HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) ) |
| { |
| Hop mm = hi.getInput().get(0); |
| Hop X = mm.getInput().get(0); |
| Hop Y = mm.getInput().get(1); |
| Hop rowExpr = hi.getInput().get(1); //rl==ru |
| Hop colExpr = hi.getInput().get(3); //cl==cu |
| |
| HopRewriteUtils.removeAllChildReferences(mm); |
| |
| //create new indexing operations |
| IndexingOp ix1 = new IndexingOp("tmp1", DataType.MATRIX, ValueType.FP64, X, |
| rowExpr, rowExpr, new LiteralOp(1), HopRewriteUtils.createValueHop(X, false), true, false); |
| ix1.setBlocksize(X.getBlocksize()); |
| ix1.refreshSizeInformation(); |
| IndexingOp ix2 = new IndexingOp("tmp2", DataType.MATRIX, ValueType.FP64, Y, |
| new LiteralOp(1), HopRewriteUtils.createValueHop(Y, true), colExpr, colExpr, false, true); |
| ix2.setBlocksize(Y.getBlocksize()); |
| ix2.refreshSizeInformation(); |
| |
| //rewire matrix mult over ix1 and ix2 |
| HopRewriteUtils.addChildReference(mm, ix1, 0); |
| HopRewriteUtils.addChildReference(mm, ix2, 1); |
| mm.refreshSizeInformation(); |
| |
| hi = mm; |
| |
| LOG.debug("Applied simplifySlicedMatrixMult"); |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos) |
| { |
| //order(matrix(7), indexreturn=FALSE) -> matrix(7) |
| //order(matrix(7), indexreturn=TRUE) -> seq(1,nrow(X),1) |
| if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.SORT ) //order |
| { |
| Hop hi2 = hi.getInput().get(0); |
| |
| if( hi2 instanceof DataGenOp && ((DataGenOp)hi2).getOp()==OpOpDG.RAND |
| && ((DataGenOp)hi2).hasConstantValue() |
| && hi.getInput().get(3) instanceof LiteralOp ) //known indexreturn |
| { |
| if( HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) ) |
| { |
| //order(matrix(7), indexreturn=TRUE) -> seq(1,nrow(X),1) |
| Hop seq = HopRewriteUtils.createSeqDataGenOp(hi2); |
| seq.refreshSizeInformation(); |
| HopRewriteUtils.replaceChildReference(parent, hi, seq, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = seq; |
| |
| LOG.debug("Applied simplifyConstantSort1."); |
| } |
| else |
| { |
| //order(matrix(7), indexreturn=FALSE) -> matrix(7) |
| HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = hi2; |
| |
| LOG.debug("Applied simplifyConstantSort2."); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyOrderedSort(Hop parent, Hop hi, int pos) |
| { |
| //order(seq(2,N+1,1), indexreturn=FALSE) -> matrix(7) |
| //order(seq(2,N+1,1), indexreturn=TRUE) -> seq(1,N,1)/seq(N,1,-1) |
| if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.SORT ) //order |
| { |
| Hop hi2 = hi.getInput().get(0); |
| |
| if( hi2 instanceof DataGenOp && ((DataGenOp)hi2).getOp()==OpOpDG.SEQ ) |
| { |
| Hop incr = hi2.getInput().get(((DataGenOp)hi2).getParamIndex(Statement.SEQ_INCR)); |
| //check for known ascending ordering and known indexreturn |
| if( incr instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)incr)==1 |
| && hi.getInput().get(2) instanceof LiteralOp //decreasing |
| && hi.getInput().get(3) instanceof LiteralOp ) //indexreturn |
| { |
| if( HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) ) //IXRET, ASC/DESC |
| { |
| //order(seq(2,N+1,1), indexreturn=TRUE) -> seq(1,N,1)/seq(N,1,-1) |
| boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)); |
| Hop seq = HopRewriteUtils.createSeqDataGenOp(hi2, !desc); |
| seq.refreshSizeInformation(); |
| HopRewriteUtils.replaceChildReference(parent, hi, seq, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = seq; |
| |
| LOG.debug("Applied simplifyOrderedSort1."); |
| } |
| else if( !HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)) ) //DATA, ASC |
| { |
| //order(seq(2,N+1,1), indexreturn=FALSE) -> seq(2,N+1,1) |
| HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = hi2; |
| |
| LOG.debug("Applied simplifyOrderedSort2."); |
| } |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop fuseOrderOperationChain(Hop hi) |
| { |
| //order(order(X,2),1) -> order(X, (12)), |
| if( HopRewriteUtils.isReorg(hi, ReOrgOp.SORT) |
| && hi.getInput().get(1) instanceof LiteralOp //scalar by |
| && hi.getInput().get(2) instanceof LiteralOp //scalar desc |
| && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) ) //not ixret |
| { |
| LiteralOp by = (LiteralOp) hi.getInput().get(1); |
| boolean desc = HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)); |
| |
| //find chain of order operations with same desc/ixret configuration and single consumers |
| Set<String> probe = new HashSet<>(); |
| ArrayList<LiteralOp> byList = new ArrayList<>(); |
| byList.add(by); probe.add(by.getStringValue()); |
| Hop input = hi.getInput().get(0); |
| while( HopRewriteUtils.isReorg(input, ReOrgOp.SORT) |
| && input.getInput().get(1) instanceof LiteralOp //scalar by |
| && !probe.contains(input.getInput().get(1).getName()) |
| && HopRewriteUtils.isLiteralOfValue(input.getInput().get(2), desc) |
| && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) |
| && input.getParent().size() == 1 ) |
| { |
| byList.add((LiteralOp)input.getInput().get(1)); |
| probe.add(input.getInput().get(1).getName()); |
| input = input.getInput().get(0); |
| } |
| |
| //merge order chain if at least two instances |
| if( byList.size() >= 2 ) { |
| //create new order operations |
| ArrayList<Hop> inputs = new ArrayList<>(); |
| inputs.add(input); |
| inputs.add(HopRewriteUtils.createDataGenOpByVal(byList, 1, byList.size())); |
| inputs.add(new LiteralOp(desc)); |
| inputs.add(new LiteralOp(false)); |
| Hop hnew = HopRewriteUtils.createReorg(inputs, ReOrgOp.SORT); |
| |
| //cleanup references recursively |
| Hop current = hi; |
| while(current != input ) { |
| Hop tmp = current.getInput().get(0); |
| HopRewriteUtils.removeAllChildReferences(current); |
| current = tmp; |
| } |
| |
| //rewire all parents (avoid anomalies with replicated datagen) |
| List<Hop> parents = new ArrayList<>(hi.getParent()); |
| for( Hop p : parents ) |
| HopRewriteUtils.replaceChildReference(p, hi, hnew); |
| |
| hi = hnew; |
| LOG.debug("Applied fuseOrderOperationChain (line "+hi.getBeginLine()+")."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| /** |
| * Patterns: t(t(A)%*%t(B)+C) -> B%*%A+t(C) |
| * |
| * @param parent parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| private static Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop hi, int pos) |
| { |
| if( HopRewriteUtils.isTransposeOperation(hi) |
| && hi.getInput().get(0) instanceof BinaryOp //basic binary |
| && ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()) |
| { |
| Hop left = hi.getInput().get(0).getInput().get(0); |
| Hop C = hi.getInput().get(0).getInput().get(1); |
| |
| //check matrix mult and both inputs transposes w/ single consumer |
| if( left instanceof AggBinaryOp && C.getDataType().isMatrix() |
| && HopRewriteUtils.isTransposeOperation(left.getInput().get(0)) |
| && left.getInput().get(0).getParent().size()==1 |
| && HopRewriteUtils.isTransposeOperation(left.getInput().get(1)) |
| && left.getInput().get(1).getParent().size()==1 ) |
| { |
| Hop A = left.getInput().get(0).getInput().get(0); |
| Hop B = left.getInput().get(1).getInput().get(0); |
| |
| AggBinaryOp abop = HopRewriteUtils.createMatrixMultiply(B, A); |
| ReorgOp rop = HopRewriteUtils.createTranspose(C); |
| BinaryOp bop = HopRewriteUtils.createBinary(abop, rop, OpOp2.PLUS); |
| |
| HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); |
| |
| hi = bop; |
| LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+")."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| // Patterns: X + (X==0) * s -> replace(X, 0, s) |
| private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int pos) |
| { |
| if( HopRewriteUtils.isBinary(hi, OpOp2.PLUS) && hi.getInput().get(0).isMatrix() |
| && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) |
| && hi.getInput().get(1).getInput().get(1).isScalar() |
| && HopRewriteUtils.isBinaryMatrixScalar(hi.getInput().get(1).getInput().get(0), OpOp2.EQUAL, 0) |
| && hi.getInput().get(1).getInput().get(0).getInput().contains(hi.getInput().get(0)) ) |
| { |
| LinkedHashMap<String, Hop> args = new LinkedHashMap<>(); |
| args.put("target", hi.getInput().get(0)); |
| args.put("pattern", new LiteralOp(0)); |
| args.put("replacement", hi.getInput().get(1).getInput().get(1)); |
| Hop replace = HopRewriteUtils.createParameterizedBuiltinOp( |
| hi.getInput().get(0), args, ParamBuiltinOp.REPLACE); |
| HopRewriteUtils.replaceChildReference(parent, hi, replace, pos); |
| hi = replace; |
| LOG.debug("Applied simplifyReplaceZeroOperation (line "+hi.getBeginLine()+")."); |
| } |
| return hi; |
| } |
| |
| /** |
| * Pattners: t(t(X)) -> X, rev(rev(X)) -> X |
| * |
| * @param parent parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) |
| { |
| ReOrgOp[] lookup = new ReOrgOp[]{ReOrgOp.TRANS, ReOrgOp.REV}; |
| |
| if( hi instanceof ReorgOp && HopRewriteUtils.isValidOp(((ReorgOp)hi).getOp(), lookup) ) //first reorg |
| { |
| ReOrgOp firstOp = ((ReorgOp)hi).getOp(); |
| Hop hi2 = hi.getInput().get(0); |
| if( hi2 instanceof ReorgOp && ((ReorgOp)hi2).getOp()==firstOp ) //second reorg w/ same type |
| { |
| Hop hi3 = hi2.getInput().get(0); |
| //remove unnecessary chain of t(t()) |
| HopRewriteUtils.replaceChildReference(parent, hi, hi3, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, hi2); |
| hi = hi3; |
| |
| LOG.debug("Applied removeUnecessaryReorgOperation."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| /* |
| * Eliminate RemoveEmpty for SUM, SUM_SQ, and NNZ (number of non-zeros) |
| */ |
| private static Hop removeUnnecessaryRemoveEmpty(Hop parent, Hop hi, int pos) |
| { |
| //check if SUM or SUM_SQ is computed with input rmEmpty without select vector |
| //rewrite pattern: |
| //sum(removeEmpty(target=X)) -> sum(X) |
| //rowSums(removeEmpty(target=X,margin="cols")) -> rowSums(X) |
| //colSums(removeEmpty(target=X,margin="rows")) -> colSums(X) |
| if( (HopRewriteUtils.isSum(hi) || HopRewriteUtils.isSumSq(hi)) |
| && HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0)) |
| && hi.getInput().get(0).getParent().size() == 1 ) |
| { |
| AggUnaryOp agg = (AggUnaryOp)hi; |
| ParameterizedBuiltinOp rmEmpty = (ParameterizedBuiltinOp) hi.getInput().get(0); |
| boolean needRmEmpty = (agg.getDirection() == Direction.Row && HopRewriteUtils.isRemoveEmpty(rmEmpty, true)) |
| || (agg.getDirection() == Direction.Col && HopRewriteUtils.isRemoveEmpty(rmEmpty, false)); |
| |
| if (rmEmpty.getParameterHop("select") == null && !needRmEmpty) { |
| Hop input = rmEmpty.getTargetHop(); |
| if( input != null ) { |
| HopRewriteUtils.replaceChildReference(hi, rmEmpty, input); |
| return hi; //eliminate rmEmpty |
| } |
| } |
| } |
| |
| //check if nrow is called on the output of removeEmpty |
| if( HopRewriteUtils.isUnary(hi, OpOp1.NROW) |
| && HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0), true) |
| && hi.getInput().get(0).getParent().size() == 1 ) |
| { |
| ParameterizedBuiltinOp rm = (ParameterizedBuiltinOp) hi.getInput().get(0); |
| //obtain optional select vector or input if col vector |
| //(nnz will be the same as the select vector if |
| // the select vector is provided and it will be the same |
| // as the input if the select vector is not provided) |
| //NOTE: part of static rewrites despite size dependence for phase |
| //ordering before rewrite for DAG splits after table/removeEmpty |
| Hop input = (rm.getParameterHop("select") != null) ? |
| rm.getParameterHop("select") : |
| (rm.getDim2() == 1) ? rm.getTargetHop() : null; |
| |
| //create new expression w/o rmEmpty if applicable |
| if( input != null ) { |
| HopRewriteUtils.removeAllChildReferences(rm); |
| Hop hnew = HopRewriteUtils.createComputeNnz(input); |
| |
| //modify dag if nnz is called on the output of removeEmpty |
| if( hnew != null ){ |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| LOG.debug("Applied removeUnnecessaryRemoveEmpty (line " + hi.getBeginLine() + ")"); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos) |
| { |
| if( hi.getDataType() == DataType.MATRIX && hi instanceof BinaryOp |
| && ((BinaryOp)hi).getOp()==OpOp2.MINUS //first minus |
| && hi.getInput().get(0) instanceof LiteralOp && ((LiteralOp)hi.getInput().get(0)).getDoubleValue()==0 ) |
| { |
| Hop hi2 = hi.getInput().get(1); |
| if( hi2.getDataType() == DataType.MATRIX && hi2 instanceof BinaryOp |
| && ((BinaryOp)hi2).getOp()==OpOp2.MINUS //second minus |
| && hi2.getInput().get(0) instanceof LiteralOp && ((LiteralOp)hi2.getInput().get(0)).getDoubleValue()==0 ) |
| |
| { |
| Hop hi3 = hi2.getInput().get(1); |
| //remove unnecessary chain of -(-()) |
| HopRewriteUtils.replaceChildReference(parent, hi, hi3, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, hi2); |
| hi = hi3; |
| |
| LOG.debug("Applied removeUnecessaryMinus"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyGroupedAggregate(Hop hi) |
| { |
| if( hi instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hi).getOp()==ParamBuiltinOp.GROUPEDAGG ) //aggregate |
| { |
| ParameterizedBuiltinOp phi = (ParameterizedBuiltinOp)hi; |
| |
| if( phi.isCountFunction() //aggregate(fn="count") |
| && phi.getTargetHop().getDim2()==1 ) //only for vector |
| { |
| HashMap<String, Integer> params = phi.getParamIndexMap(); |
| int ix1 = params.get(Statement.GAGG_TARGET); |
| int ix2 = params.get(Statement.GAGG_GROUPS); |
| |
| //check for unnecessary memory consumption for "count" |
| if( ix1 != ix2 && phi.getInput().get(ix1)!=phi.getInput().get(ix2) ) |
| { |
| Hop th = phi.getInput().get(ix1); |
| Hop gh = phi.getInput().get(ix2); |
| |
| HopRewriteUtils.replaceChildReference(hi, th, gh, ix1); |
| |
| LOG.debug("Applied simplifyGroupedAggregateCount"); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop fuseMinusNzBinaryOperation(Hop parent, Hop hi, int pos) |
| { |
| //pattern X - (s * ppred(X,0,!=)) -> X -nz s |
| //note: this is done as a hop rewrite in order to significantly reduce the |
| //memory estimate for X - tmp if X is sparse |
| if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS) |
| && hi.getInput().get(0).getDataType()==DataType.MATRIX |
| && hi.getInput().get(1).getDataType()==DataType.MATRIX |
| && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) ) |
| { |
| Hop X = hi.getInput().get(0); |
| Hop s = hi.getInput().get(1).getInput().get(0); |
| Hop pred = hi.getInput().get(1).getInput().get(1); |
| |
| if( s.getDataType()==DataType.SCALAR && pred.getDataType()==DataType.MATRIX |
| && HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) |
| && pred.getInput().get(0) == X //depend on common subexpression elimination |
| && pred.getInput().get(1) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) |
| { |
| Hop hnew = HopRewriteUtils.createBinary(X, s, OpOp2.MINUS_NZ); |
| |
| //relink new hop into original position |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| |
| LOG.debug("Applied fuseMinusNzBinaryOperation (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop fuseLogNzUnaryOperation(Hop parent, Hop hi, int pos) |
| { |
| //pattern ppred(X,0,"!=")*log(X) -> log_nz(X) |
| //note: this is done as a hop rewrite in order to significantly reduce the |
| //memory estimate and to prevent dense intermediates if X is ultra sparse |
| if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) |
| && hi.getInput().get(0).getDataType()==DataType.MATRIX |
| && hi.getInput().get(1).getDataType()==DataType.MATRIX |
| && HopRewriteUtils.isUnary(hi.getInput().get(1), OpOp1.LOG) ) |
| { |
| Hop pred = hi.getInput().get(0); |
| Hop X = hi.getInput().get(1).getInput().get(0); |
| |
| if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) |
| && pred.getInput().get(0) == X //depend on common subexpression elimination |
| && pred.getInput().get(1) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) |
| { |
| Hop hnew = HopRewriteUtils.createUnary(X, OpOp1.LOG_NZ); |
| |
| //relink new hop into original position |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| |
| LOG.debug("Applied fuseLogNzUnaryOperation (line "+hi.getBeginLine()+")."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop fuseLogNzBinaryOperation(Hop parent, Hop hi, int pos) |
| { |
| //pattern ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) |
| //note: this is done as a hop rewrite in order to significantly reduce the |
| //memory estimate and to prevent dense intermediates if X is ultra sparse |
| if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) |
| && hi.getInput().get(0).getDataType()==DataType.MATRIX |
| && hi.getInput().get(1).getDataType()==DataType.MATRIX |
| && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.LOG) ) |
| { |
| Hop pred = hi.getInput().get(0); |
| Hop X = hi.getInput().get(1).getInput().get(0); |
| Hop log = hi.getInput().get(1).getInput().get(1); |
| |
| if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL) |
| && pred.getInput().get(0) == X //depend on common subexpression elimination |
| && pred.getInput().get(1) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 ) |
| { |
| Hop hnew = HopRewriteUtils.createBinary(X, log, OpOp2.LOG_NZ); |
| |
| //relink new hop into original position |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| |
| LOG.debug("Applied fuseLogNzBinaryOperation (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyOuterSeqExpand(Hop parent, Hop hi, int pos) |
| { |
| //pattern: outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) |
| //note: this rewrite supports both left/right sequence |
| |
| if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) && ((BinaryOp)hi).isOuter() ) |
| { |
| if( ( HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)) //pattern a: outer(v, t(seq(1,m)), "==") |
| && HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0))) |
| || HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0))) //pattern b: outer(seq(1,m), t(v) "==") |
| { |
| //determine variable parameters for pattern a/b |
| boolean isPatternB = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)); |
| boolean isTransposeRight = HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)); |
| Hop trgt = isPatternB ? (isTransposeRight ? |
| hi.getInput().get(1).getInput().get(0) : //get v from t(v) |
| HopRewriteUtils.createTranspose(hi.getInput().get(1)) ) : //create v via t(v') |
| hi.getInput().get(0); //get v directly |
| Hop seq = isPatternB ? |
| hi.getInput().get(0) : hi.getInput().get(1).getInput().get(0); |
| String direction = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? "rows" : "cols"; |
| |
| //setup input parameter hops |
| LinkedHashMap<String,Hop> inputargs = new LinkedHashMap<>(); |
| inputargs.put("target", trgt); |
| inputargs.put("max", HopRewriteUtils.getBasic1NSequenceMax(seq)); |
| inputargs.put("dir", new LiteralOp(direction)); |
| inputargs.put("ignore", new LiteralOp(true)); |
| inputargs.put("cast", new LiteralOp(false)); |
| |
| //create new hop |
| ParameterizedBuiltinOp pbop = HopRewriteUtils |
| .createParameterizedBuiltinOp(trgt, inputargs, ParamBuiltinOp.REXPAND); |
| |
| //relink new hop into original position |
| HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos); |
| hi = pbop; |
| |
| LOG.debug("Applied simplifyOuterSeqExpand (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyBinaryComparisonChain(Hop parent, Hop hi, int pos) { |
| if( HopRewriteUtils.isBinaryPPred(hi) |
| && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 0d, 1d) |
| && HopRewriteUtils.isBinaryPPred(hi.getInput().get(0)) ) |
| { |
| BinaryOp bop = (BinaryOp) hi; |
| BinaryOp bop2 = (BinaryOp) hi.getInput().get(0); |
| boolean one = HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 1); |
| |
| //pattern: outer(v1,v2,"!=") == 1 -> outer(v1,v2,"!=") |
| if( (one && bop.getOp() == OpOp2.EQUAL) |
| || (!one && bop.getOp() == OpOp2.NOTEQUAL) ) |
| { |
| HopRewriteUtils.replaceChildReference(parent, bop, bop2, pos); |
| HopRewriteUtils.cleanupUnreferenced(bop); |
| hi = bop2; |
| LOG.debug("Applied simplifyBinaryComparisonChain1 (line "+hi.getBeginLine()+")"); |
| } |
| //pattern: outer(v1,v2,"!=") == 0 -> outer(v1,v2,"==") |
| else if( !one && bop.getOp() == OpOp2.EQUAL ) { |
| OpOp2 optr = bop2.getComplementPPredOperation(); |
| BinaryOp tmp = HopRewriteUtils.createBinary(bop2.getInput().get(0), |
| bop2.getInput().get(1), optr, bop2.isOuter()); |
| HopRewriteUtils.replaceChildReference(parent, bop, tmp, pos); |
| HopRewriteUtils.cleanupUnreferenced(bop, bop2); |
| hi = tmp; |
| LOG.debug("Applied simplifyBinaryComparisonChain0 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyCumsumColOrFullAggregates(Hop hi) { |
| //pattern: colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1)) |
| if( (HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.Col) |
| || HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol)) |
| && HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM) |
| && hi.getInput().get(0).getParent().size()==1) |
| { |
| Hop cumsumX = hi.getInput().get(0); |
| Hop X = cumsumX.getInput().get(0); |
| Hop mult = HopRewriteUtils.createBinary(X, |
| HopRewriteUtils.createSeqDataGenOp(X, false), OpOp2.MULT); |
| HopRewriteUtils.replaceChildReference(hi, cumsumX, mult); |
| HopRewriteUtils.removeAllChildReferences(cumsumX); |
| LOG.debug("Applied simplifyCumsumColOrFullAggregates (line "+hi.getBeginLine()+")"); |
| } |
| return hi; |
| } |
| |
| private static Hop simplifyCumsumReverse(Hop parent, Hop hi, int pos) { |
| //pattern: rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X) |
| if( HopRewriteUtils.isReorg(hi, ReOrgOp.REV) |
| && HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM) |
| && hi.getInput().get(0).getParent().size()==1 |
| && HopRewriteUtils.isReorg(hi.getInput().get(0).getInput().get(0), ReOrgOp.REV) |
| && hi.getInput().get(0).getInput().get(0).getParent().size()==1) |
| { |
| Hop cumsumX = hi.getInput().get(0); |
| Hop revX = cumsumX.getInput().get(0); |
| Hop X = revX.getInput().get(0); |
| Hop plus = HopRewriteUtils.createBinary(X, HopRewriteUtils |
| .createAggUnaryOp(X, AggOp.SUM, Direction.Col), OpOp2.PLUS); |
| Hop minus = HopRewriteUtils.createBinary(plus, |
| HopRewriteUtils.createUnary(X, OpOp1.CUMSUM), OpOp2.MINUS); |
| HopRewriteUtils.replaceChildReference(parent, hi, minus, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, cumsumX, revX); |
| |
| hi = minus; |
| LOG.debug("Applied simplifyCumsumReverse (line "+hi.getBeginLine()+")"); |
| } |
| return hi; |
| } |
| |
| /** |
| * NOTE: currently disabled since this rewrite is INVALID in the |
| * presence of NaNs (because (NaN!=NaN) is true). |
| * |
| * @param parent parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| @SuppressWarnings("unused") |
| private static Hop removeUnecessaryPPred(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof BinaryOp ) |
| { |
| BinaryOp bop = (BinaryOp)hi; |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| |
| Hop datagen = null; |
| |
| //ppred(X,X,"==") -> matrix(1, rows=nrow(X),cols=nrow(Y)) |
| if( left==right && bop.getOp()==OpOp2.EQUAL || bop.getOp()==OpOp2.GREATEREQUAL || bop.getOp()==OpOp2.LESSEQUAL ) |
| datagen = HopRewriteUtils.createDataGenOp(left, 1); |
| |
| //ppred(X,X,"!=") -> matrix(0, rows=nrow(X),cols=nrow(Y)) |
| if( left==right && bop.getOp()==OpOp2.NOTEQUAL || bop.getOp()==OpOp2.GREATER || bop.getOp()==OpOp2.LESS ) |
| datagen = HopRewriteUtils.createDataGenOp(left, 0); |
| |
| if( datagen != null ) { |
| HopRewriteUtils.replaceChildReference(parent, hi, datagen, pos); |
| hi = datagen; |
| } |
| } |
| |
| return hi; |
| } |
| } |