| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops.rewrite; |
| |
| import static org.apache.sysds.hops.OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; |
| |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.LinkedHashMap; |
| import java.util.List; |
| |
| import org.apache.sysds.hops.AggBinaryOp; |
| import org.apache.sysds.hops.AggUnaryOp; |
| import org.apache.sysds.hops.BinaryOp; |
| import org.apache.sysds.hops.DataGenOp; |
| import org.apache.sysds.hops.Hop; |
| import org.apache.sysds.hops.IndexingOp; |
| import org.apache.sysds.hops.LeftIndexingOp; |
| import org.apache.sysds.hops.LiteralOp; |
| import org.apache.sysds.hops.NaryOp; |
| import org.apache.sysds.hops.OptimizerUtils; |
| import org.apache.sysds.hops.ParameterizedBuiltinOp; |
| import org.apache.sysds.hops.QuaternaryOp; |
| import org.apache.sysds.hops.ReorgOp; |
| import org.apache.sysds.hops.TernaryOp; |
| import org.apache.sysds.hops.UnaryOp; |
| import org.apache.sysds.common.Types.AggOp; |
| import org.apache.sysds.common.Types.Direction; |
| import org.apache.sysds.common.Types.OpOp1; |
| import org.apache.sysds.common.Types.OpOp2; |
| import org.apache.sysds.common.Types.OpOp3; |
| import org.apache.sysds.common.Types.OpOp4; |
| import org.apache.sysds.common.Types.OpOpDG; |
| import org.apache.sysds.common.Types.OpOpN; |
| import org.apache.sysds.common.Types.ParamBuiltinOp; |
| import org.apache.sysds.common.Types.ReOrgOp; |
| import org.apache.sysds.lops.MapMultChain.ChainType; |
| import org.apache.sysds.parser.DataExpression; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.common.Types.ValueType; |
| |
| /** |
| * Rule: Algebraic Simplifications. Simplifies binary expressions |
| * in terms of two major purposes: (1) rewrite binary operations |
| * to unary operations when possible (in CP this reduces the memory |
| * estimate, in MR this allows map-only operations and hence prevents |
| * unnecessary shuffle and sort) and (2) remove binary operations that |
| * are in themselves unnecessary (e.g., *1 and /1). |
| * |
| */ |
| public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule |
| { |
| //valid aggregation operation types for rowOp to Op conversions (not all operations apply) |
| private static AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR}; |
| |
| //valid aggregation operation types for empty (sparse-safe) operations (not all operations apply) |
| //AggOp.MEAN currently not due to missing count/corrections |
| private static AggOp[] LOOKUP_VALID_EMPTY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE}; |
| private static AggOp[] LOOKUP_VALID_UNNECESSARY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE}; |
| |
| //valid unary operation types for empty (sparse-safe) operations (not all operations apply) |
| private static OpOp1[] LOOKUP_VALID_EMPTY_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.SIN, OpOp1.TAN, OpOp1.SQRT, OpOp1.ROUND, OpOp1.CUMSUM}; |
| |
| //valid pseudo-sparse-safe binary operators for wdivmm |
| private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.DIV}; |
| |
| //valid unary and binary operators for wumm |
| private static OpOp1[] LOOKUP_VALID_WUMM_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.EXP, OpOp1.LOG, OpOp1.SQRT, OpOp1.SIGMOID, OpOp1.SPROP}; |
| private static OpOp2[] LOOKUP_VALID_WUMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.POW}; |
| |
| @Override |
| public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) { |
| if( roots == null ) |
| return roots; |
| |
| //one pass rewrite-descend (rewrite created pattern) |
| for( Hop h : roots ) |
| rule_AlgebraicSimplification( h, false ); |
| Hop.resetVisitStatus(roots, true); |
| |
| //one pass descend-rewrite (for rollup) |
| for( Hop h : roots ) |
| rule_AlgebraicSimplification( h, true ); |
| Hop.resetVisitStatus(roots, true); |
| |
| return roots; |
| } |
| |
| @Override |
| public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { |
| if( root == null ) |
| return root; |
| |
| //one pass rewrite-descend (rewrite created pattern) |
| rule_AlgebraicSimplification( root, false ); |
| |
| root.resetVisitStatus(); |
| |
| //one pass descend-rewrite (for rollup) |
| rule_AlgebraicSimplification( root, true ); |
| |
| return root; |
| } |
| |
| |
| /** |
| * Note: X/y -> X * 1/y would be useful because * cheaper than / and sparsesafe; however, |
| * (1) the results would be not exactly the same (2 rounds instead of 1) and (2) it should |
| * come before constant folding while the other simplifications should come after constant |
| * folding. Hence, not applied yet. |
| * |
| * @param hop high-level operator |
| * @param descendFirst true if recursively process children first |
| */ |
| private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst) |
| { |
| if(hop.isVisited()) |
| return; |
| |
| //recursively process children |
| for( int i=0; i<hop.getInput().size(); i++) |
| { |
| Hop hi = hop.getInput().get(i); |
| |
| //process childs recursively first (to allow roll-up) |
| if( descendFirst ) |
| rule_AlgebraicSimplification(hi, descendFirst); //see below |
| |
| //apply actual simplification rewrites (of childs incl checks) |
| hi = removeEmptyRightIndexing(hop, hi, i); //e.g., X[,1] -> matrix(0,ru-rl+1,cu-cl+1), if nnz(X)==0 |
| hi = removeUnnecessaryRightIndexing(hop, hi, i); //e.g., X[,1] -> X, if output == input size |
| hi = removeEmptyLeftIndexing(hop, hi, i); //e.g., X[,1]=Y -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 and nnz(Y)==0 |
| hi = removeUnnecessaryLeftIndexing(hop, hi, i); //e.g., X[,1]=Y -> Y, if output == input dims |
| if(OptimizerUtils.ALLOW_OPERATOR_FUSION) |
| hi = fuseLeftIndexingChainToAppend(hop, hi, i); //e.g., X[,1]=A; X[,2]=B -> X=cbind(A,B), iff ncol(X)==2 and col1/2 lix |
| hi = removeUnnecessaryCumulativeOp(hop, hi, i); //e.g., cumsum(X) -> X, if nrow(X)==1; |
| hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., matrix(X) -> X, if dims(in)==dims(out); r(X)->X, if 1x1 dims |
| hi = removeUnnecessaryOuterProduct(hop, hi, i); //e.g., X*(Y%*%matrix(1,...) -> X*Y, if Y col vector |
| hi = removeUnnecessaryIfElseOperation(hop, hi, i);//e.g., ifelse(E, A, B) -> A, if E==TRUE or nnz(E)==length(E) |
| hi = removeUnnecessaryAppendTSMM(hop, hi, i); //e.g., X = t(rbind(A,B,C)) %*% rbind(A,B,C) -> t(A)%*%A + t(B)%*%B + t(C)%*%C |
| if(OptimizerUtils.ALLOW_OPERATOR_FUSION) |
| hi = fuseDatagenAndReorgOperation(hop, hi, i); //e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1 |
| hi = simplifyColwiseAggregate(hop, hi, i); //e.g., colsums(X) -> sum(X) or X, if col/row vector |
| hi = simplifyRowwiseAggregate(hop, hi, i); //e.g., rowsums(X) -> sum(X) or X, if row/col vector |
| hi = simplifyColSumsMVMult(hop, hi, i); //e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector |
| hi = simplifyRowSumsMVMult(hop, hi, i); //e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector |
| hi = simplifyUnnecessaryAggregate(hop, hi, i); //e.g., sum(X) -> as.scalar(X), if 1x1 dims |
| hi = simplifyEmptyAggregate(hop, hi, i); //e.g., sum(X) -> 0, if nnz(X)==0 |
| hi = simplifyEmptyUnaryOperation(hop, hi, i); //e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 |
| hi = simplifyEmptyReorgOperation(hop, hi, i); //e.g., t(X) -> matrix(0, ncol(X), nrow(X)) |
| hi = simplifyEmptySortOperation(hop, hi, i); //e.g., order(X) -> seq(1, nrow(X)), if nnz(X)==0 |
| hi = simplifyEmptyMatrixMult(hop, hi, i); //e.g., X%*%Y -> matrix(0,...), if nnz(Y)==0 | X if Y==matrix(1,1,1) |
| hi = simplifyIdentityRepMatrixMult(hop, hi, i); //e.g., X%*%y -> X if y matrix(1,1,1); |
| hi = simplifyScalarMatrixMult(hop, hi, i); //e.g., X%*%y -> X*as.scalar(y), if y is a 1-1 matrix |
| hi = simplifyMatrixMultDiag(hop, hi, i); //e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1 |
| hi = simplifyDiagMatrixMult(hop, hi, i); //e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector |
| hi = simplifySumDiagToTrace(hi); //e.g., sum(diag(X)) -> trace(X); if col vector |
| hi = simplifyLowerTriExtraction(hop, hi, i); //e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri |
| hi = pushdownBinaryOperationOnDiag(hop, hi, i); //e.g., diag(X)*7 -> diag(X*7); if col vector |
| hi = pushdownSumOnAdditiveBinary(hop, hi, i); //e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B) |
| if(OptimizerUtils.ALLOW_OPERATOR_FUSION) { |
| hi = simplifyWeightedSquaredLoss(hop, hi, i); //e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true), |
| hi = simplifyWeightedSigmoidMMChains(hop, hi, i); //e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type) |
| hi = simplifyWeightedDivMM(hop, hi, i); //e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left) |
| hi = simplifyWeightedCrossEntropy(hop, hi, i); //e.g., sum(X*log(U%*%t(V))) -> wcemm(X, U, t(V)) |
| hi = simplifyWeightedUnaryMM(hop, hi, i); //e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp) |
| hi = simplifyDotProductSum(hop, hi, i); //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1 |
| hi = fuseSumSquared(hop, hi, i); //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1 |
| hi = fuseAxpyBinaryOperationChain(hop, hi, i); //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y) |
| } |
| hi = reorderMinusMatrixMult(hop, hi, i); //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size |
| hi = simplifySumMatrixMult(hop, hi, i); //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss |
| hi = simplifyEmptyBinaryOperation(hop, hi, i); //e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X |
| hi = simplifyScalarMVBinaryOperation(hi); //e.g., X*y -> X*as.scalar(y), if y is a 1-1 matrix |
| hi = simplifyNnzComputation(hop, hi, i); //e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known |
| hi = simplifyNrowNcolComputation(hop, hi, i); //e.g., nrow(X) -> literal(nrow(X)), if nrow known to remove data dependency |
| hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) |
| if( OptimizerUtils.ALLOW_OPERATOR_FUSION ) |
| foldMultipleMinMaxOperations(hi); //e.g., min(X,min(min(3,7),Y)) -> min(X,3,7,Y) |
| |
| //process childs recursively after rewrites (to investigate pattern newly created by rewrites) |
| if( !descendFirst ) |
| rule_AlgebraicSimplification(hi, descendFirst); |
| } |
| |
| hop.setVisited(); |
| } |
| |
| private static Hop removeEmptyRightIndexing(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof IndexingOp && hi.getDataType()==DataType.MATRIX ) //indexing op |
| { |
| Hop input = hi.getInput().get(0); |
| if( input.getNnz()==0 && //nnz input known and empty |
| HopRewriteUtils.isDimsKnown(hi)) //output dims known |
| { |
| //remove unnecessary right indexing |
| Hop hnew = HopRewriteUtils.createDataGenOpByVal( new LiteralOp(hi.getDim1()), |
| new LiteralOp(hi.getDim2()), null, DataType.MATRIX, ValueType.FP64, 0); |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, input); |
| hi = hnew; |
| |
| LOG.debug("Applied removeEmptyRightIndexing"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi, int pos) |
| { |
| if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) ) { |
| //remove unnecessary right indexing |
| Hop input = hi.getInput().get(0); |
| HopRewriteUtils.replaceChildReference(parent, hi, input, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = input; |
| |
| LOG.debug("Applied removeUnnecessaryRightIndexing"); |
| } |
| |
| return hi; |
| } |
| |
| private static Hop removeEmptyLeftIndexing(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof LeftIndexingOp && hi.getDataType() == DataType.MATRIX ) //left indexing op |
| { |
| Hop input1 = hi.getInput().get(0); //lhs matrix |
| Hop input2 = hi.getInput().get(1); //rhs matrix |
| |
| if( input1.getNnz()==0 //nnz original known and empty |
| && input2.getNnz()==0 ) //nnz input known and empty |
| { |
| //remove unnecessary right indexing |
| Hop hnew = HopRewriteUtils.createDataGenOp( input1, 0); |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, input2); |
| hi = hnew; |
| |
| LOG.debug("Applied removeEmptyLeftIndexing"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop removeUnnecessaryLeftIndexing(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof LeftIndexingOp ) //left indexing op |
| { |
| Hop input = hi.getInput().get(1); //rhs matrix/frame |
| |
| if( HopRewriteUtils.isEqualSize(hi, input) ) //equal dims |
| { |
| //equal dims of left indexing input and output -> no need for indexing |
| |
| //remove unnecessary right indexing |
| HopRewriteUtils.replaceChildReference(parent, hi, input, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = input; |
| |
| LOG.debug("Applied removeUnnecessaryLeftIndexing"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos) |
| { |
| boolean applied = false; |
| |
| //pattern1: X[,1]=A; X[,2]=B -> X=cbind(A,B); matrix / frame |
| if( hi instanceof LeftIndexingOp //first lix |
| && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi) |
| && hi.getInput().get(0) instanceof LeftIndexingOp //second lix |
| && HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi.getInput().get(0)) |
| && hi.getInput().get(0).getParent().size()==1 //first lix is single consumer |
| && hi.getInput().get(0).getInput().get(0).getDim2() == 2 ) //two column matrix |
| { |
| Hop input2 = hi.getInput().get(1); //rhs matrix |
| Hop pred2 = hi.getInput().get(4); //cl=cu |
| Hop input1 = hi.getInput().get(0).getInput().get(1); //lhs matrix |
| Hop pred1 = hi.getInput().get(0).getInput().get(4); //cl=cu |
| |
| if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1 |
| && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2 |
| && input1.getDataType()!=DataType.SCALAR && input2.getDataType()!=DataType.SCALAR ) |
| { |
| //create new cbind operation and rewrite inputs |
| BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.CBIND); |
| HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); |
| |
| hi = bop; |
| applied = true; |
| } |
| } |
| |
| //pattern1: X[1,]=A; X[2,]=B -> X=rbind(A,B) |
| if( !applied && hi instanceof LeftIndexingOp //first lix |
| && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi) |
| && hi.getInput().get(0) instanceof LeftIndexingOp //second lix |
| && HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi.getInput().get(0)) |
| && hi.getInput().get(0).getParent().size()==1 //first lix is single consumer |
| && hi.getInput().get(0).getInput().get(0).getDim1() == 2 ) //two column matrix |
| { |
| Hop input2 = hi.getInput().get(1); //rhs matrix |
| Hop pred2 = hi.getInput().get(2); //rl=ru |
| Hop input1 = hi.getInput().get(0).getInput().get(1); //lhs matrix |
| Hop pred1 = hi.getInput().get(0).getInput().get(2); //rl=ru |
| |
| if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1 |
| && pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2 |
| && input1.getDataType()!=DataType.SCALAR && input2.getDataType()!=DataType.SCALAR ) |
| { |
| //create new cbind operation and rewrite inputs |
| BinaryOp bop = HopRewriteUtils.createBinary(input1, input2, OpOp2.RBIND); |
| HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); |
| |
| hi = bop; |
| applied = true; |
| |
| LOG.debug("Applied fuseLeftIndexingChainToAppend2 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof UnaryOp && ((UnaryOp)hi).isCumulativeUnaryOperation() ) |
| { |
| Hop input = hi.getInput().get(0); //input matrix |
| |
| if( HopRewriteUtils.isDimsKnown(input) //dims input known |
| && input.getDim1()==1 ) //1 row |
| { |
| OpOp1 op = ((UnaryOp)hi).getOp(); |
| |
| //remove unnecessary unary cumsum operator |
| HopRewriteUtils.replaceChildReference(parent, hi, input, pos); |
| hi = input; |
| |
| LOG.debug("Applied removeUnnecessaryCumulativeOp: "+op); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof ReorgOp ) |
| { |
| ReorgOp rop = (ReorgOp) hi; |
| Hop input = hi.getInput().get(0); |
| boolean apply = false; |
| |
| //equal dims of reshape input and output -> no need for reshape because |
| //byrow always refers to both input/output and hence gives the same result |
| apply |= (rop.getOp()==ReOrgOp.RESHAPE && HopRewriteUtils.isEqualSize(hi, input)); |
| |
| //1x1 dimensions of transpose/reshape -> no need for reorg |
| apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE) |
| && rop.getDim1()==1 && rop.getDim2()==1); |
| |
| if( apply ) { |
| HopRewriteUtils.replaceChildReference(parent, hi, input, pos); |
| hi = input; |
| LOG.debug("Applied removeUnnecessaryReorg."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop removeUnnecessaryOuterProduct(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof BinaryOp ) //binary cell operation |
| { |
| OpOp2 bop = ((BinaryOp)hi).getOp(); |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| |
| //check for matrix-vector column replication: (A + b %*% ones) -> (A + b) |
| if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen |
| && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(1), 1) |
| && right.getInput().get(0).getDim2() == 1 ) //column vector for mv binary |
| { |
| //remove unnecessary outer product |
| HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(0), 1 ); |
| HopRewriteUtils.cleanupUnreferenced(right); |
| |
| LOG.debug("Applied removeUnnecessaryOuterProduct1 (line "+right.getBeginLine()+")"); |
| } |
| //check for matrix-vector row replication: (A + ones %*% b) -> (A + b) |
| else if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen |
| && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0), 1) |
| && right.getInput().get(1).getDim1() == 1 ) //row vector for mv binary |
| { |
| //remove unnecessary outer product |
| HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(1), 1 ); |
| HopRewriteUtils.cleanupUnreferenced(right); |
| |
| LOG.debug("Applied removeUnnecessaryOuterProduct2 (line "+right.getBeginLine()+")"); |
| } |
| //check for vector-vector column replication: (a %*% ones) == b) -> outer(a, b, "==") |
| else if(HopRewriteUtils.isValidOuterBinaryOp(bop) |
| && HopRewriteUtils.isMatrixMultiply(left) |
| && HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1) |
| && (left.getInput().get(0).getDim2() == 1 //outer product |
| || left.getInput().get(1).getDim1() == 1) |
| && left.getDim1() != 1 && right.getDim1() == 1 ) //outer vector binary |
| { |
| Hop hnew = HopRewriteUtils.createBinary(left.getInput().get(0), right, bop, true); |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| |
| hi = hnew; |
| LOG.debug("Applied removeUnnecessaryOuterProduct3 (line "+right.getBeginLine()+")"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop removeUnnecessaryIfElseOperation(Hop parent, Hop hi, int pos) |
| { |
| if( !HopRewriteUtils.isTernary(hi, OpOp3.IFELSE) ) |
| return hi; |
| |
| Hop expr = hi.getInput().get(0); |
| Hop first = hi.getInput().get(1); |
| Hop second = hi.getInput().get(2); |
| boolean applied = false; |
| |
| //pattern 1: ifelse(TRUE/FALSE, A, B) -> A/B (constant scalar predicate) |
| if( expr instanceof LiteralOp ) { |
| Hop hnew = ((LiteralOp)expr).getBooleanValue() ? first : second; |
| if( HopRewriteUtils.isEqualSize(hnew, hi) ) { |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos ); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| LOG.debug("Applied removeUnnecessaryIfElse1 (line "+hi.getBeginLine()+")"); |
| hi = hnew; applied = true; |
| } |
| } |
| //pattern 2: ifelse(E, A, B) -> A/B if nnz(E)==length(E) or nnz(E)==0 (constant matrix predicate) |
| if( !applied && expr.getNnz()==expr.getLength() || expr.getNnz()==0 ) { |
| Hop hnew = expr.getNnz()==0 ? second : first; |
| if( HopRewriteUtils.isEqualSize(hnew, hi) ) { |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos ); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| LOG.debug("Applied removeUnnecessaryIfElse2 (line "+hi.getBeginLine()+")"); |
| hi = hnew; applied = true; |
| } |
| } |
| //pattern 3: ifelse(E, A, A) -> A (same input) |
| if( !applied && first == second //dep CSE |
| && HopRewriteUtils.isEqualSize(first, hi) ){ |
| HopRewriteUtils.replaceChildReference(parent, hi, first, pos ); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| LOG.debug("Applied removeUnnecessaryIfElse3 (line "+hi.getBeginLine()+")"); |
| hi = first; |
| } |
| |
| return hi; |
| } |
| |
| private static Hop removeUnnecessaryAppendTSMM(Hop parent, Hop hi, int pos) |
| { |
| Hop hnew = null; |
| //pattern 1: X = t(rbind(A,B,C)) %*% rbind(A,B,C) -> t(A)%*%A + t(B)%*%B + t(C)%*%C |
| int branch = -1; |
| if( HopRewriteUtils.isTsmm(hi) |
| && HopRewriteUtils.isTransposeOperation(hi.getInput().get(0)) |
| && HopRewriteUtils.isNary(hi.getInput().get(1), OpOpN.RBIND) ) |
| { |
| List<Hop> inputs = hi.getInput().get(1).getInput(); |
| if( HopRewriteUtils.checkAvgRowsGteCols(inputs) ) { |
| Hop[] tsmms = inputs.stream() |
| .map(h -> HopRewriteUtils.createTsmm(h, true)).toArray(Hop[]::new); |
| hnew = HopRewriteUtils.createNary(OpOpN.PLUS, tsmms); |
| //cleanup parent references from rbind |
| //HopRewriteUtils.removeAllChildReferences(hi.getInput().get(1)); |
| branch = 1; |
| } |
| } |
| //pattern 2: X = t(rbind(A,B,C)) %*% rbind(D,E,F) -> t(A)%*%D + t(B)%*%E + t(C)%*%F |
| else if( HopRewriteUtils.isMatrixMultiply(hi) |
| && HopRewriteUtils.isTransposeOperation(hi.getInput().get(0)) |
| && HopRewriteUtils.isNary(hi.getInput().get(0).getInput().get(0), OpOpN.RBIND) |
| && HopRewriteUtils.isNary(hi.getInput().get(1), OpOpN.RBIND) ) |
| { |
| List<Hop> inputs1 = hi.getInput().get(0).getInput().get(0).getInput(); |
| List<Hop> inputs2 = hi.getInput().get(1).getInput(); |
| if( HopRewriteUtils.checkAvgRowsGteCols(inputs1) |
| && HopRewriteUtils.checkAvgRowsGteCols(inputs2) |
| && HopRewriteUtils.checkConsistentRows(inputs1, inputs2) ) |
| { |
| Hop[] mms = new Hop[inputs1.size()]; |
| for( int i=0; i<inputs1.size(); i++ ) |
| mms[i] = HopRewriteUtils.createMatrixMultiply( |
| HopRewriteUtils.createTranspose(inputs1.get(i)), inputs2.get(i)); |
| hnew = HopRewriteUtils.createNary(OpOpN.PLUS, mms); |
| //cleanup parent references from rbind left/right |
| //HopRewriteUtils.removeAllChildReferences(hi.getInput().get(0).getInput().get(0)); |
| //HopRewriteUtils.removeAllChildReferences(hi.getInput().get(1)); |
| branch = 2; |
| } |
| } |
| //pattern 3: X = t(cbind(A, B)) %*% cbind(A, B), w/ one cbind consumer (twice in tsmm) |
| else if( HopRewriteUtils.isTsmm(hi) && hi.getInput().get(1).getParent().size()==2 |
| && HopRewriteUtils.isTransposeOperation(hi.getInput().get(0)) |
| && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.CBIND) ) |
| { |
| Hop input1 = hi.getInput().get(1).getInput().get(0); |
| Hop input2 = hi.getInput().get(1).getInput().get(1); |
| if( input1.getDim1() > input1.getDim2() && input2.getDim2() == 1 ) { |
| hnew = HopRewriteUtils.createPartialTsmmCbind( |
| input1, input2, HopRewriteUtils.createTsmm(input1, true)); |
| branch = 3; |
| } |
| } |
| |
| //modify dag if one of the above rules applied |
| if( hnew != null ){ |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| HopRewriteUtils.removeAllChildReferences(hi); |
| hi = hnew; |
| LOG.debug("Applied removeUnnecessaryAppendTSMM" |
| + branch + " (line " + hi.getBeginLine() + ")"); |
| } |
| |
| return hi; |
| } |
| |
| @SuppressWarnings("unchecked") |
| private static Hop fuseDatagenAndReorgOperation(Hop parent, Hop hi, int pos) |
| { |
| if( HopRewriteUtils.isTransposeOperation(hi) |
| && hi.getInput().get(0) instanceof DataGenOp //datagen |
| && hi.getInput().get(0).getParent().size()==1 ) //transpose only consumer |
| { |
| DataGenOp dop = (DataGenOp)hi.getInput().get(0); |
| if( (dop.getOp() == OpOpDG.RAND || dop.getOp() == OpOpDG.SINIT) |
| && (dop.getDim1()==1 || dop.getDim2()==1 )) |
| { |
| //relink all parents and dataop (remove transpose) |
| HopRewriteUtils.removeAllChildReferences(hi); |
| ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); |
| for( int i=0; i<parents.size(); i++ ) { |
| Hop lparent = parents.get(i); |
| int ppos = HopRewriteUtils.getChildReferencePos(lparent, hi); |
| HopRewriteUtils.removeChildReferenceByPos(lparent, hi, ppos); |
| HopRewriteUtils.addChildReference(lparent, dop, pos); |
| } |
| |
| //flip rows/cols attributes in datagen |
| HashMap<String, Integer> rparams = dop.getParamIndexMap(); |
| int pos1 = rparams.get(DataExpression.RAND_ROWS); |
| int pos2 = rparams.get(DataExpression.RAND_COLS); |
| rparams.put(DataExpression.RAND_ROWS, pos2); |
| rparams.put(DataExpression.RAND_COLS, pos1); |
| dop.refreshSizeInformation(); |
| |
| hi = dop; |
| |
| LOG.debug("Applied fuseDatagenReorgOperation."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| @SuppressWarnings("unchecked") |
| private static Hop simplifyColwiseAggregate( Hop parent, Hop hi, int pos ) { |
| if( hi instanceof AggUnaryOp ) |
| { |
| AggUnaryOp uhi = (AggUnaryOp)hi; |
| Hop input = uhi.getInput().get(0); |
| |
| if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) { |
| if( uhi.getDirection() == Direction.Col ) |
| { |
| if( input.getDim1() == 1 ) |
| { |
| if (uhi.getOp() == AggOp.VAR) { |
| // For the column variance aggregation, if the input is a row vector, |
| // the column variances will each be zero. |
| // Therefore, perform a rewrite from COLVAR(X) to a row vector of zeros. |
| Hop emptyRow = HopRewriteUtils.createDataGenOp(uhi, input, 0); |
| HopRewriteUtils.replaceChildReference(parent, hi, emptyRow, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, input); |
| hi = emptyRow; |
| |
| LOG.debug("Applied simplifyColwiseAggregate for colVars"); |
| } else { |
| // All other valid column aggregations over a row vector will result |
| // in the row vector itself. |
| // Therefore, remove unnecessary col aggregation for 1 row. |
| HopRewriteUtils.replaceChildReference(parent, hi, input, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = input; |
| |
| LOG.debug("Applied simplifyColwiseAggregate1"); |
| } |
| } |
| else if( input.getDim2() == 1 ) |
| { |
| //get old parents (before creating cast over aggregate) |
| ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); |
| |
| //simplify col-aggregate to full aggregate |
| uhi.setDirection(Direction.RowCol); |
| uhi.setDataType(DataType.SCALAR); |
| |
| //create cast to keep same output datatype |
| UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX); |
| |
| //rehang cast under all parents |
| for( Hop p : parents ) { |
| int ix = HopRewriteUtils.getChildReferencePos(p, hi); |
| HopRewriteUtils.replaceChildReference(p, hi, cast, ix); |
| } |
| |
| hi = cast; |
| |
| LOG.debug("Applied simplifyColwiseAggregate2"); |
| } |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| @SuppressWarnings("unchecked") |
| private static Hop simplifyRowwiseAggregate( Hop parent, Hop hi, int pos ) { |
| if( hi instanceof AggUnaryOp ) |
| { |
| AggUnaryOp uhi = (AggUnaryOp)hi; |
| Hop input = uhi.getInput().get(0); |
| |
| if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) ) { |
| if( uhi.getDirection() == Direction.Row ) |
| { |
| if( input.getDim2() == 1 ) |
| { |
| if (uhi.getOp() == AggOp.VAR) { |
| // For the row variance aggregation, if the input is a column vector, |
| // the row variances will each be zero. |
| // Therefore, perform a rewrite from ROWVAR(X) to a column vector of |
| // zeros. |
| Hop emptyCol = HopRewriteUtils.createDataGenOp(input, uhi, 0); |
| HopRewriteUtils.replaceChildReference(parent, hi, emptyCol, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, input); |
| |
| // replace current HOP with new empty column HOP |
| hi = emptyCol; |
| |
| LOG.debug("Applied simplifyRowwiseAggregate for rowVars"); |
| } else { |
| // All other valid row aggregations over a column vector will result |
| // in the column vector itself. |
| // Therefore, remove unnecessary row aggregation for 1 col |
| HopRewriteUtils.replaceChildReference(parent, hi, input, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = input; |
| |
| LOG.debug("Applied simplifyRowwiseAggregate1"); |
| } |
| } |
| else if( input.getDim1() == 1 ) |
| { |
| //get old parents (before creating cast over aggregate) |
| ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); |
| |
| //simplify row-aggregate to full aggregate |
| uhi.setDirection(Direction.RowCol); |
| uhi.setDataType(DataType.SCALAR); |
| |
| //create cast to keep same output datatype |
| UnaryOp cast = HopRewriteUtils.createUnary(uhi, OpOp1.CAST_AS_MATRIX); |
| |
| //rehang cast under all parents |
| for( Hop p : parents ) { |
| int ix = HopRewriteUtils.getChildReferencePos(p, hi); |
| HopRewriteUtils.replaceChildReference(p, hi, cast, ix); |
| } |
| |
| hi = cast; |
| |
| LOG.debug("Applied simplifyRowwiseAggregate2"); |
| } |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyColSumsMVMult( Hop parent, Hop hi, int pos ) |
| { |
| //colSums(X*Y) -> t(Y) %*% X, if Y col vector; additional transpose later |
| //removed by other rewrite if unnecessary, i.e., if Y==t(Z) |
| if( hi instanceof AggUnaryOp ) |
| { |
| AggUnaryOp uhi = (AggUnaryOp)hi; |
| Hop input = uhi.getInput().get(0); |
| |
| if( uhi.getOp() == AggOp.SUM && uhi.getDirection() == Direction.Col //colsums |
| && HopRewriteUtils.isBinary(input, OpOp2.MULT) ) //b(*) |
| { |
| Hop left = input.getInput().get(0); |
| Hop right = input.getInput().get(1); |
| |
| if( left.getDim1()>1 && left.getDim2()>1 |
| && right.getDim1()>1 && right.getDim2()==1 ) // MV (col vector) |
| { |
| //create new operators |
| ReorgOp trans = HopRewriteUtils.createTranspose(right); |
| AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, left); |
| |
| //relink new child |
| HopRewriteUtils.replaceChildReference(parent, hi, mmult, pos); |
| HopRewriteUtils.cleanupUnreferenced(uhi, input); |
| hi = mmult; |
| |
| LOG.debug("Applied simplifyColSumsMVMult"); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyRowSumsMVMult( Hop parent, Hop hi, int pos ) |
| { |
| //rowSums(X * Y) -> X %*% t(Y), if Y row vector; additional transpose later |
| //removed by other rewrite if unnecessary, i.e., if Y==t(Z) |
| if( hi instanceof AggUnaryOp ) |
| { |
| AggUnaryOp uhi = (AggUnaryOp)hi; |
| Hop input = uhi.getInput().get(0); |
| |
| if( uhi.getOp() == AggOp.SUM && uhi.getDirection() == Direction.Row //rowsums |
| && HopRewriteUtils.isBinary(input, OpOp2.MULT) ) //b(*) |
| { |
| Hop left = input.getInput().get(0); |
| Hop right = input.getInput().get(1); |
| |
| if( left.getDim1()>1 && left.getDim2()>1 |
| && right.getDim1()==1 && right.getDim2()>1 ) // MV (row vector) |
| { |
| //create new operators |
| ReorgOp trans = HopRewriteUtils.createTranspose(right); |
| AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(left, trans); |
| |
| //relink new child |
| HopRewriteUtils.replaceChildReference(parent, hi, mmult, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, input); |
| hi = mmult; |
| |
| LOG.debug("Applied simplifyRowSumsMVMult"); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyUnnecessaryAggregate(Hop parent, Hop hi, int pos) |
| { |
| // TODO implement for tensor |
| //e.g., sum(X) -> as.scalar(X) if 1x1 (applies to sum, min, max, prod, trace) |
| if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol ) |
| { |
| AggUnaryOp uhi = (AggUnaryOp)hi; |
| Hop input = uhi.getInput().get(0); |
| |
| if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_UNNECESSARY_AGGREGATE) ){ |
| |
| if( input.getDim1()==1 && input.getDim2()==1 && input.getDataType()==DataType.MATRIX) |
| { |
| UnaryOp cast = HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_SCALAR); |
| |
| //remove unnecessary aggregation |
| HopRewriteUtils.replaceChildReference(parent, hi, cast, pos); |
| hi = cast; |
| |
| LOG.debug("Applied simplifyUnncessaryAggregate"); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof AggUnaryOp ) |
| { |
| AggUnaryOp uhi = (AggUnaryOp)hi; |
| Hop input = uhi.getInput().get(0); |
| |
| //check for valid empty aggregates, except for matrices with zero rows/cols |
| if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_EMPTY_AGGREGATE) |
| && HopRewriteUtils.isEmpty(input) |
| && input.getDim1()>=1 && input.getDim2() >= 1 ) |
| { |
| Hop hnew = null; |
| if( uhi.getDirection() == Direction.RowCol ) |
| hnew = new LiteralOp(0.0); |
| else if( uhi.getDirection() == Direction.Col ) |
| hnew = HopRewriteUtils.createDataGenOp(uhi, input, 0); //nrow(uhi)=1 |
| else //if( uhi.getDirection() == Direction.Row ) |
| hnew = HopRewriteUtils.createDataGenOp(input, uhi, 0); //ncol(uhi)=1 |
| |
| //add new child to parent input |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| LOG.debug("Applied simplifyEmptyAggregate"); |
| } |
| } |
| return hi; |
| } |
| |
| private static Hop simplifyEmptyUnaryOperation(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof UnaryOp ) |
| { |
| UnaryOp uhi = (UnaryOp)hi; |
| Hop input = uhi.getInput().get(0); |
| |
| if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_EMPTY_UNARY) ){ |
| |
| if( HopRewriteUtils.isEmpty(input) ) |
| { |
| //create literal add it to parent |
| Hop hnew = HopRewriteUtils.createDataGenOp(input, 0); |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| |
| LOG.debug("Applied simplifyEmptyUnaryOperation"); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyEmptyReorgOperation(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof ReorgOp ) |
| { |
| ReorgOp rhi = (ReorgOp)hi; |
| Hop input = rhi.getInput().get(0); |
| |
| if( HopRewriteUtils.isEmpty(input) ) //empty input |
| { |
| //reorg-operation-specific rewrite |
| Hop hnew = null; |
| if( rhi.getOp() == ReOrgOp.TRANS ) |
| hnew = HopRewriteUtils.createDataGenOp(input, true, input, true, 0); |
| else if( rhi.getOp() == ReOrgOp.REV ) |
| hnew = HopRewriteUtils.createDataGenOp(input, 0); |
| else if( rhi.getOp() == ReOrgOp.DIAG ) { |
| if( HopRewriteUtils.isDimsKnown(input) ) { |
| if( input.getDim2()==1 ) //diagv2m |
| hnew = HopRewriteUtils.createDataGenOp(input, false, input, true, 0); |
| else //diagm2v TODO support tensor operation |
| hnew = HopRewriteUtils.createDataGenOpByVal( |
| HopRewriteUtils.createValueHop(input,true), new LiteralOp(1), |
| null, DataType.MATRIX, ValueType.FP64, 0); |
| } |
| } |
| else if( rhi.getOp() == ReOrgOp.RESHAPE ) |
| hnew = HopRewriteUtils.createDataGenOpByVal(rhi.getInput().get(1), rhi.getInput().get(2), |
| rhi.getInput().get(3), rhi.getDataType(), rhi.getValueType(), 0); |
| |
| //modify dag if one of the above rules applied |
| if( hnew != null ){ |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| |
| LOG.debug("Applied simplifyEmptyReorgOperation"); |
| } |
| } |
| |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyEmptySortOperation(Hop parent, Hop hi, int pos) |
| { |
| //order(X, indexreturn=FALSE) -> matrix(0,nrow(X),1) |
| //order(X, indexreturn=TRUE) -> seq(1,nrow(X),1) |
| if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.SORT ) |
| { |
| ReorgOp rhi = (ReorgOp)hi; |
| Hop input = rhi.getInput().get(0); |
| |
| if( HopRewriteUtils.isEmpty(input) ) //empty input |
| { |
| //reorg-operation-specific rewrite |
| Hop hnew = null; |
| boolean ixret = false; |
| |
| if( rhi.getInput().get(3) instanceof LiteralOp ) //index return known |
| { |
| ixret = HopRewriteUtils.getBooleanValue((LiteralOp)rhi.getInput().get(3)); |
| if( ixret ) |
| hnew = HopRewriteUtils.createSeqDataGenOp(input); |
| else |
| hnew = HopRewriteUtils.createDataGenOp(input, 0); |
| } |
| |
| //modify dag if one of the above rules applied |
| if( hnew != null ){ |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| |
| LOG.debug("Applied simplifyEmptySortOperation (indexreturn="+ixret+")."); |
| } |
| } |
| |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyEmptyMatrixMult(Hop parent, Hop hi, int pos) { |
| if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y -> matrix(0, ) |
| { |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| |
| if( HopRewriteUtils.isEmpty(left) //one input empty |
| || HopRewriteUtils.isEmpty(right) ) |
| { |
| //create datagen and add it to parent |
| Hop hnew = HopRewriteUtils.createDataGenOp(left, right, 0); |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| |
| LOG.debug("Applied simplifyEmptyMatrixMult"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyIdentityRepMatrixMult(Hop parent, Hop hi, int pos) |
| { |
| if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y -> X, if y is matrix(1,1,1) |
| { |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| |
| // X %*% y -> X |
| if( HopRewriteUtils.isDimsKnown(right) && right.getDim1()==1 && right.getDim2()==1 && //scalar right |
| right instanceof DataGenOp && ((DataGenOp)right).getOp()==OpOpDG.RAND |
| && ((DataGenOp)right).hasConstantValue(1.0)) //matrix(1,) |
| { |
| HopRewriteUtils.replaceChildReference(parent, hi, left, pos); |
| hi = left; |
| |
| LOG.debug("Applied simplifyIdentiyMatrixMult"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyScalarMatrixMult(Hop parent, Hop hi, int pos) |
| { |
| if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y |
| { |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| |
| // y %*% X -> as.scalar(y) * X |
| if( HopRewriteUtils.isDimsKnown(left) && left.getDim1()==1 && left.getDim2()==1 ) //scalar left |
| { |
| UnaryOp cast = HopRewriteUtils.createUnary(left, OpOp1.CAST_AS_SCALAR); |
| BinaryOp mult = HopRewriteUtils.createBinary(cast, right, OpOp2.MULT); |
| |
| //add mult to parent |
| HopRewriteUtils.replaceChildReference(parent, hi, mult, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| |
| hi = mult; |
| LOG.debug("Applied simplifyScalarMatrixMult1"); |
| } |
| // X %*% y -> X * as.scalar(y) |
| else if( HopRewriteUtils.isDimsKnown(right) && right.getDim1()==1 && right.getDim2()==1 ) //scalar right |
| { |
| UnaryOp cast = HopRewriteUtils.createUnary(right, OpOp1.CAST_AS_SCALAR); |
| BinaryOp mult = HopRewriteUtils.createBinary(cast, left, OpOp2.MULT); |
| |
| //add mult to parent |
| HopRewriteUtils.replaceChildReference(parent, hi, mult, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| |
| hi = mult; |
| LOG.debug("Applied simplifyScalarMatrixMult2"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyMatrixMultDiag(Hop parent, Hop hi, int pos) |
| { |
| Hop hnew = null; |
| |
| if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y |
| { |
| |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| |
| // diag(X) %*% Y -> X * Y / diag(X) %*% Y -> Y * X |
| // previously rep required for the second case: diag(X) %*% Y -> (X%*%ones) * Y |
| if( left instanceof ReorgOp && ((ReorgOp)left).getOp()==ReOrgOp.DIAG //left diag |
| && HopRewriteUtils.isDimsKnown(left) && left.getDim2()>1 ) //diagV2M |
| { |
| if( right.getDim2()==1 ) //right column vector |
| { |
| //create binary operation over input and right |
| Hop input = left.getInput().get(0); //diag input |
| hnew = HopRewriteUtils.createBinary(input, right, OpOp2.MULT); |
| |
| LOG.debug("Applied simplifyMatrixMultDiag1"); |
| } |
| else if( right.getDim2()>1 ) //multi column vector |
| { |
| //create binary operation over input and right; in contrast to above rewrite, |
| //we need to switch the order because MV binary cell operations require vector on the right |
| Hop input = left.getInput().get(0); //diag input |
| hnew = HopRewriteUtils.createBinary(right, input, OpOp2.MULT); |
| |
| //NOTE: previously to MV binary cell operations we replicated the left |
| //(if moderate number of columns: 2), but this is no longer required |
| |
| LOG.debug("Applied simplifyMatrixMultDiag2"); |
| } |
| } |
| |
| //notes: similar rewrites would be possible for the right side as well, just transposed into the right alignment |
| } |
| |
| //if one of the above rewrites applied |
| if( hnew !=null ){ |
| //add mult to parent |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| |
| hi = hnew; |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.DIAG && hi.getDim2()==1 ) //diagM2V |
| { |
| Hop hi2 = hi.getInput().get(0); |
| if( HopRewriteUtils.isMatrixMultiply(hi2) ) //X%*%Y |
| { |
| Hop left = hi2.getInput().get(0); |
| Hop right = hi2.getInput().get(1); |
| |
| //create new operators (incl refresh size inside for transpose) |
| ReorgOp trans = HopRewriteUtils.createTranspose(right); |
| BinaryOp mult = HopRewriteUtils.createBinary(left, trans, OpOp2.MULT); |
| AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(mult, AggOp.SUM, Direction.Row); |
| |
| //rehang new subdag under parent node |
| HopRewriteUtils.replaceChildReference(parent, hi, rowSum, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, hi2); |
| |
| hi = rowSum; |
| LOG.debug("Applied simplifyDiagMatrixMult"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifySumDiagToTrace(Hop hi) |
| { |
| if( hi instanceof AggUnaryOp ) |
| { |
| AggUnaryOp au = (AggUnaryOp) hi; |
| if( au.getOp()==AggOp.SUM && au.getDirection()==Direction.RowCol ) //sum |
| { |
| Hop hi2 = au.getInput().get(0); |
| if( hi2 instanceof ReorgOp && ((ReorgOp)hi2).getOp()==ReOrgOp.DIAG && hi2.getDim2()==1 ) //diagM2V |
| { |
| Hop hi3 = hi2.getInput().get(0); |
| |
| //remove diag operator |
| HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0); |
| HopRewriteUtils.cleanupUnreferenced(hi2); |
| |
| //change sum to trace |
| au.setOp( AggOp.TRACE ); |
| |
| LOG.debug("Applied simplifySumDiagToTrace"); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyLowerTriExtraction(Hop parent, Hop hi, int pos) { |
| //pattern: X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri (only right) |
| if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) |
| && hi.getDim1() == hi.getDim2() && hi.getDim1() > 1 ) { |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| |
| if( HopRewriteUtils.isUnary(right, OpOp1.CUMSUM) && right.getParent().size()==1 |
| && HopRewriteUtils.isReorg(right.getInput().get(0), ReOrgOp.DIAG) |
| && HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0).getInput().get(0), 1d)) |
| { |
| LinkedHashMap<String,Hop> args = new LinkedHashMap<>(); |
| args.put("target", left); |
| args.put("diag", new LiteralOp(true)); |
| args.put("values", new LiteralOp(true)); |
| Hop hnew = HopRewriteUtils.createParameterizedBuiltinOp( |
| left, args, ParamBuiltinOp.LOWER_TRI); |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew); |
| HopRewriteUtils.removeAllChildReferences(right); |
| |
| hi = hnew; |
| LOG.debug("Applied simplifyLowerTriExtraction"); |
| } |
| } |
| return hi; |
| } |
| |
/**
 * Pushes a scalar multiplication below a vector-to-matrix diag, e.g.,
 * diag(X)*7 -> diag(X*7). The existing multiply and diag operators are
 * reused and rewired in place; afterwards the diag takes the position of
 * the multiply under all former parents of the multiply.
 *
 * @param parent parent high-level operator (not used for the rewiring,
 *               which handles all parents of hi)
 * @param hi high-level operator
 * @param pos position (not used for the rewiring)
 * @return the diag operator if the rewrite applied, otherwise hi
 */
@SuppressWarnings("unchecked")
private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, int pos)
{
	//diag(X)*7 --> diag(X*7) in order to (1) reduce required memory for b(*) and
	//(2) in order to make the binary operation more efficient (dense vector vs sparse matrix)
	if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) )
	{
		Hop left = hi.getInput().get(0);
		Hop right = hi.getInput().get(1);

		boolean applyLeft = false;
		boolean applyRight = false;

		//left input is diag
		if( left instanceof ReorgOp && ((ReorgOp)left).getOp()==ReOrgOp.DIAG
			&& left.getParent().size()==1 //binary op only parent
			&& left.getInput().get(0).getDim2()==1 //col vector
			&& right.getDataType() == DataType.SCALAR )
		{
			applyLeft = true;
		}
		//right input is diag
		else if( right instanceof ReorgOp && ((ReorgOp)right).getOp()==ReOrgOp.DIAG
			&& right.getParent().size()==1 //binary op only parent
			&& right.getInput().get(0).getDim2()==1 //col vector
			&& left.getDataType() == DataType.SCALAR )
		{
			applyRight = true;
		}

		//perform actual rewrite
		if( applyLeft || applyRight )
		{
			//remove all parent links to binary op (since we want to reorder
			//we cannot just look at the current parent)
			//iterate over a copy because the parent list is modified in the loop
			ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone();
			ArrayList<Integer> parentspos = new ArrayList<>();
			for(Hop lparent : parents) {
				int lpos = HopRewriteUtils.getChildReferencePos(lparent, hi);
				HopRewriteUtils.removeChildReferenceByPos(lparent, hi, lpos);
				//remember the position for relinking the diag later
				parentspos.add(lpos);
			}

			//rewire binop-diag-input into diag-binop-input
			if( applyLeft ) {
				Hop input = left.getInput().get(0);
				//detach diag from binop and diag input from diag
				HopRewriteUtils.removeChildReferenceByPos(hi, left, 0);
				HopRewriteUtils.removeChildReferenceByPos(left, input, 0);
				//binop becomes the diag input, diag input becomes left binop input
				HopRewriteUtils.addChildReference(left, hi, 0);
				HopRewriteUtils.addChildReference(hi, input, 0);
				hi.refreshSizeInformation();
				hi = left;
			}
			else if ( applyRight ) {
				Hop input = right.getInput().get(0);
				//detach diag from binop and diag input from diag
				HopRewriteUtils.removeChildReferenceByPos(hi, right, 1);
				HopRewriteUtils.removeChildReferenceByPos(right, input, 0);
				//binop becomes the diag input, diag input becomes right binop input
				HopRewriteUtils.addChildReference(right, hi, 0);
				HopRewriteUtils.addChildReference(hi, input, 1);
				hi.refreshSizeInformation();
				hi = right;
			}

			//relink all parents to the diag operation
			for( int i=0; i<parents.size(); i++ ) {
				Hop lparent = parents.get(i);
				int lpos = parentspos.get(i);
				HopRewriteUtils.addChildReference(lparent, hi, lpos);
			}

			LOG.debug("Applied pushdownBinaryOperationOnDiag.");
		}
	}

	return hi;
}
| |
| /** |
| * patterns: sum(A+B)->sum(A)+sum(B); sum(A-B)->sum(A)-sum(B) |
| * |
| * @param parent the parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos) |
| { |
| //all patterns headed by full sum over binary operation |
| if( hi instanceof AggUnaryOp //full sum root over binaryop |
| && ((AggUnaryOp)hi).getDirection()==Direction.RowCol |
| && ((AggUnaryOp)hi).getOp() == AggOp.SUM |
| && hi.getInput().get(0) instanceof BinaryOp |
| && hi.getInput().get(0).getParent().size()==1 ) //single parent |
| { |
| BinaryOp bop = (BinaryOp) hi.getInput().get(0); |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| |
| if( HopRewriteUtils.isEqualSize(left, right) //dims(A) == dims(B) |
| && left.getDataType() == DataType.MATRIX |
| && right.getDataType() == DataType.MATRIX ) |
| { |
| OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS //pattern a: sum(A+B)->sum(A)+sum(B) |
| || bop.getOp() == OpOp2.MINUS ) //pattern b: sum(A-B)->sum(A)-sum(B) |
| ? bop.getOp() : null; |
| |
| if( applyOp != null ) { |
| //create new subdag sum(A) bop sum(B) |
| AggUnaryOp sum1 = HopRewriteUtils.createSum(left); |
| AggUnaryOp sum2 = HopRewriteUtils.createSum(right); |
| BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp); |
| |
| //rewire new subdag |
| HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, bop); |
| |
| hi = newBin; |
| |
| LOG.debug("Applied pushdownSumOnAdditiveBinary."); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
/**
 * Searches for weighted squared loss expressions and replaces them with a quaternary operator.
 * Currently, this search includes the following four patterns:
 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
 * 4) sumSq (X - U %*% t(V)) (no weighting sumSq)
 *
 * For each pattern, the variant with swapped operands of the subtraction is matched as well.
 *
 * NOTE: We include transpose into the pattern because during runtime we need to compute
 * U%*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
 * without additional memory requirements for internal transpose.
 *
 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
 * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
 * as we generalized the runtime or hop/lop compilation.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return the new quaternary operator if one of the patterns matched, otherwise hi
 */
private static Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos)
{
	//NOTE: there might be also a general simplification without custom operator
	//via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
	Hop hnew = null;
	boolean appliedPattern = false;

	if( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol) //all patterns rooted by sum()
		&& hi.getInput().get(0) instanceof BinaryOp //all patterns subrooted by binary op
		&& hi.getInput().get(0).getDim2() > 1 ) //not applied for vector-vector mult
	{
		BinaryOp bop = (BinaryOp) hi.getInput().get(0);

		//Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
		//alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
		if( bop.getOp()==OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW)
			&& bop.getInput().get(0).getDataType()==DataType.MATRIX
			&& HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) //prevent mv
			&& HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1).getInput().get(1), 2) )
		{
			Hop W = bop.getInput().get(0);
			Hop tmp = bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V))

			if( HopRewriteUtils.isBinary(tmp, OpOp2.MINUS)
				&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
				&& tmp.getInput().get(0).getDataType() == DataType.MATRIX )
			{
				//a) sum (W * (X - U %*% t(V)) ^ 2)
				//uvIndex is the input position of the matrix multiply within the minus
				int uvIndex = -1;
				if( tmp.getInput().get(1) instanceof AggBinaryOp //ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) { //BLOCKSIZE CONSTRAINT
					uvIndex = 1;
				}
				//b) sum (W * (U %*% t(V) - X) ^ 2)
				else if(tmp.getInput().get(0) instanceof AggBinaryOp //ba guarantees matrices
					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) { //BLOCKSIZE CONSTRAINT
					uvIndex = 0;
				}

				if( uvIndex >= 0 ) { //rewrite match
					Hop X = tmp.getInput().get((uvIndex==0)?1:0);
					Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
					Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
					//pass V instead of t(V): strip an existing transpose or add one
					V = !HopRewriteUtils.isTransposeOperation(V) ?
						HopRewriteUtils.createTranspose(V) : V.getInput().get(0);

					//handle special case of post_nz
					if( HopRewriteUtils.isNonZeroIndicator(W, X) ){
						W = new LiteralOp(1);
					}

					//construct quaternary hop (last argument is true only for post weighting)
					hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR,
						ValueType.FP64, OpOp4.WSLOSS, X, U, V, W, true);
					HopRewriteUtils.setOutputParametersForScalar(hnew);

					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")");
				}
			}
		}

		//Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
		//alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
		if( !appliedPattern
			&& bop.getOp()==OpOp2.POW && HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 2)
			&& HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
			&& HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput().get(0)))
		{
			Hop lleft = bop.getInput().get(0).getInput().get(0);
			Hop lright = bop.getInput().get(0).getInput().get(1);

			//a) sum ((X - W * (U %*% t(V))) ^ 2)
			//wuvIndex is the input position of W * (U %*% t(V)) within the minus
			int wuvIndex = -1;
			if( lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp ){
				wuvIndex = 1;
			}
			//b) sum ((W * (U %*% t(V)) - X) ^ 2)
			else if( lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp ){
				wuvIndex = 0;
			}

			if( wuvIndex >= 0 ) //rewrite match
			{
				Hop X = bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0);
				Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V)))

				if( ((BinaryOp)tmp).getOp()==OpOp2.MULT
					&& tmp.getInput().get(0).getDataType() == DataType.MATRIX
					&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
				{
					Hop W = tmp.getInput().get(0);
					Hop U = tmp.getInput().get(1).getInput().get(0);
					Hop V = tmp.getInput().get(1).getInput().get(1);
					//pass V instead of t(V): strip an existing transpose or add one
					V = !HopRewriteUtils.isTransposeOperation(V) ?
						HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
					hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR,
						ValueType.FP64, OpOp4.WSLOSS, X, U, V, W, false);
					HopRewriteUtils.setOutputParametersForScalar(hnew);
					appliedPattern = true;
					LOG.debug("Applied simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")");
				}
			}
		}

		//Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
		//alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
		if( !appliedPattern
			&& bop.getOp()==OpOp2.POW && HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 2)
			&& HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
			&& HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput().get(0))) //prevent mv
		{
			Hop lleft = bop.getInput().get(0).getInput().get(0);
			Hop lright = bop.getInput().get(0).getInput().get(1);

			//a) sum ((X - (U %*% t(V))) ^ 2)
			int uvIndex = -1;
			if( lright instanceof AggBinaryOp //ba guarantees matrices
				&& HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) { //BLOCKSIZE CONSTRAINT
				uvIndex = 1;
			}
			//b) sum (((U %*% t(V)) - X) ^ 2)
			else if( lleft instanceof AggBinaryOp //ba guarantees matrices
				&& HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) { //BLOCKSIZE CONSTRAINT
				uvIndex = 0;
			}

			if( uvIndex >= 0 ) { //rewrite match
				Hop X = bop.getInput().get(0).getInput().get((uvIndex==0)?1:0);
				Hop tmp = bop.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
				Hop W = new LiteralOp(1); //no weighting
				Hop U = tmp.getInput().get(0);
				Hop V = tmp.getInput().get(1);
				//pass V instead of t(V): strip an existing transpose or add one
				V = !HopRewriteUtils.isTransposeOperation(V) ?
					HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
				hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR,
					ValueType.FP64, OpOp4.WSLOSS, X, U, V, W, false);
				HopRewriteUtils.setOutputParametersForScalar(hnew);
				appliedPattern = true;

				LOG.debug("Applied simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")");
			}
		}
	}

	//Pattern 4) sumSq (X - U %*% t(V)) (no weighting)
	//alternative pattern: sumSq (U %*% t(V) - X)
	//(checked outside the block above because the root is sumSq instead of sum)
	if( !appliedPattern
		&& HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM_SQ, Direction.RowCol)
		&& HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.MINUS)
		&& HopRewriteUtils.isEqualMatrixSize((BinaryOp)hi.getInput().get(0))) //prevent mv
	{
		Hop lleft = hi.getInput().get(0).getInput().get(0);
		Hop lright = hi.getInput().get(0).getInput().get(1);

		//a) sumSq (X - U %*% t(V))
		int uvIndex = -1;
		if( lright instanceof AggBinaryOp //ba guarantees matrices
			&& HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) { //BLOCKSIZE CONSTRAINT
			uvIndex = 1;
		}
		//b) sumSq (U %*% t(V) - X)
		else if( lleft instanceof AggBinaryOp //ba guarantees matrices
			&& HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) { //BLOCKSIZE CONSTRAINT
			uvIndex = 0;
		}

		if( uvIndex >= 0 ) { //rewrite match
			Hop X = hi.getInput().get(0).getInput().get((uvIndex==0)?1:0);
			Hop tmp = hi.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
			Hop W = new LiteralOp(1); //no weighting
			Hop U = tmp.getInput().get(0);
			Hop V = tmp.getInput().get(1);
			//pass V instead of t(V): strip an existing transpose or add one
			V = !HopRewriteUtils.isTransposeOperation(V) ?
				HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
			hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR,
				ValueType.FP64, OpOp4.WSLOSS, X, U, V, W, false);
			HopRewriteUtils.setOutputParametersForScalar(hnew);
			appliedPattern = true;

			LOG.debug("Applied simplifyWeightedSquaredLoss4"+uvIndex+" (line "+hi.getBeginLine()+")");
		}
	}

	//relink new hop into original position
	if( hnew != null ) {
		HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
		hi = hnew;
	}

	return hi;
}
| |
/**
 * Searches for weighted sigmoid expressions over a matrix multiplication and replaces
 * them with a WSIGMOID quaternary operator. The following four patterns are matched:
 * 1) W * sigmoid(Y%*%t(X)) (basic)
 * 2) W * sigmoid(-(Y%*%t(X))) (minus)
 * 3) W * log(sigmoid(Y%*%t(X))) (log)
 * 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
 *
 * In all patterns, W must be a matrix on the left of an element-wise multiply with more
 * than one output column and equally sized inputs, and the left factor of the matrix
 * multiplication must pass the single-block check. The negation is matched as a binary
 * minus whose left input is a literal of value 0.
 *
 * @param parent parent high-level operator
 * @param hi high-level operator
 * @param pos position
 * @return the new quaternary operator if one of the patterns matched, otherwise hi
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
{
	Hop hnew = null;

	if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) //all patterns subrooted by W *
		&& hi.getDim2() > 1 //not applied for vector-vector mult
		&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
		&& hi.getInput().get(0).getDataType()==DataType.MATRIX
		&& hi.getInput().get(1) instanceof UnaryOp ) //sigmoid/log
	{
		UnaryOp uop = (UnaryOp) hi.getInput().get(1);
		boolean appliedPattern = false;

		//NOTE(review): the two trailing boolean constructor arguments below appear
		//to be the (log, minus) flags, judging from the four pattern variants

		//Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
		if( uop.getOp() == OpOp1.SIGMOID
			&& uop.getInput().get(0) instanceof AggBinaryOp
			&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0),true) )
		{
			Hop W = hi.getInput().get(0);
			Hop Y = uop.getInput().get(0).getInput().get(0);
			Hop tX = uop.getInput().get(0).getInput().get(1);

			//pass X instead of t(X): add a transpose or strip an existing one
			if( !HopRewriteUtils.isTransposeOperation(tX) ) {
				tX = HopRewriteUtils.createTranspose(tX);
			}
			else
				tX = tX.getInput().get(0);

			hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
				OpOp4.WSIGMOID, W, Y, tX, false, false);
			hnew.setBlocksize(W.getBlocksize());
			hnew.refreshSizeInformation();

			appliedPattern = true;
			LOG.debug("Applied simplifyWeightedSigmoid1 (line "+hi.getBeginLine()+")");
		}

		//Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
		if( !appliedPattern
			&& uop.getOp() == OpOp1.SIGMOID
			&& HopRewriteUtils.isBinary(uop.getInput().get(0), OpOp2.MINUS)
			&& uop.getInput().get(0).getInput().get(0) instanceof LiteralOp
			&& HopRewriteUtils.getDoubleValueSafe(
				(LiteralOp)uop.getInput().get(0).getInput().get(0))==0
			&& uop.getInput().get(0).getInput().get(1) instanceof AggBinaryOp
			&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(1).getInput().get(0),true))
		{
			Hop W = hi.getInput().get(0);
			Hop Y = uop.getInput().get(0).getInput().get(1).getInput().get(0);
			Hop tX = uop.getInput().get(0).getInput().get(1).getInput().get(1);

			//pass X instead of t(X): add a transpose or strip an existing one
			if( !HopRewriteUtils.isTransposeOperation(tX) ) {
				tX = HopRewriteUtils.createTranspose(tX);
			}
			else
				tX = tX.getInput().get(0);

			hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
				OpOp4.WSIGMOID, W, Y, tX, false, true);
			hnew.setBlocksize(W.getBlocksize());
			hnew.refreshSizeInformation();

			appliedPattern = true;
			LOG.debug("Applied simplifyWeightedSigmoid2 (line "+hi.getBeginLine()+")");
		}

		//Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
		if( !appliedPattern
			&& uop.getOp() == OpOp1.LOG
			&& HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)
			&& uop.getInput().get(0).getInput().get(0) instanceof AggBinaryOp
			&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0).getInput().get(0),true) )
		{
			Hop W = hi.getInput().get(0);
			Hop Y = uop.getInput().get(0).getInput().get(0).getInput().get(0);
			Hop tX = uop.getInput().get(0).getInput().get(0).getInput().get(1);

			//pass X instead of t(X): add a transpose or strip an existing one
			if( !HopRewriteUtils.isTransposeOperation(tX) ) {
				tX = HopRewriteUtils.createTranspose(tX);
			}
			else
				tX = tX.getInput().get(0);

			hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
				OpOp4.WSIGMOID, W, Y, tX, true, false);
			hnew.setBlocksize(W.getBlocksize());
			hnew.refreshSizeInformation();

			appliedPattern = true;
			LOG.debug("Applied simplifyWeightedSigmoid3 (line "+hi.getBeginLine()+")");
		}

		//Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
		if( !appliedPattern
			&& uop.getOp() == OpOp1.LOG
			&& HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)
			&& HopRewriteUtils.isBinary(uop.getInput().get(0).getInput().get(0), OpOp2.MINUS) )
		{
			BinaryOp bop = (BinaryOp) uop.getInput().get(0).getInput().get(0);

			//negation expressed as 0 - (Y%*%t(X))
			if( bop.getInput().get(0) instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)bop.getInput().get(0))==0
				&& bop.getInput().get(1) instanceof AggBinaryOp
				&& HopRewriteUtils.isSingleBlock(bop.getInput().get(1).getInput().get(0),true))
			{
				Hop W = hi.getInput().get(0);
				Hop Y = bop.getInput().get(1).getInput().get(0);
				Hop tX = bop.getInput().get(1).getInput().get(1);

				//pass X instead of t(X): add a transpose or strip an existing one
				if( !HopRewriteUtils.isTransposeOperation(tX) ) {
					tX = HopRewriteUtils.createTranspose(tX);
				}
				else
					tX = tX.getInput().get(0);

				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64,
					OpOp4.WSIGMOID, W, Y, tX, true, true);
				hnew.setBlocksize(W.getBlocksize());
				hnew.refreshSizeInformation();

				appliedPattern = true;
				LOG.debug("Applied simplifyWeightedSigmoid4 (line "+hi.getBeginLine()+")");
			}
		}
	}

	//relink new hop into original position
	if( hnew != null ) {
		HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
		hi = hnew;
	}

	return hi;
}
| |
| private static Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos) { |
| Hop hnew = null; |
| boolean appliedPattern = false; |
| |
| //left/right patterns rooted by 'ab - b(div)' or 'ab - b(mult)' |
| //note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops) |
| if( HopRewriteUtils.isMatrixMultiply(hi) |
| && (hi.getInput().get(0) instanceof BinaryOp |
| && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY) |
| || hi.getInput().get(1) instanceof BinaryOp |
| && hi.getDim2() > 1 //not applied for vector-vector mult |
| && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)) ) |
| { |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| |
| //Pattern 1) t(U) %*% (W/(U%*%t(V))) |
| //alternative pattern: t(U) %*% (W*(U%*%t(V))) |
| if( right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY) |
| && HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv |
| && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1)) |
| && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| Hop W = right.getInput().get(0); |
| Hop U = right.getInput().get(1).getInput().get(0); |
| Hop V = right.getInput().get(1).getInput().get(1); |
| |
| if( HopRewriteUtils.isTransposeOfItself(left, U) ) |
| { |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = HopRewriteUtils.createTranspose(V); |
| else |
| V = V.getInput().get(0); |
| |
| boolean mult = ((BinaryOp)right).getOp() == OpOp2.MULT; |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 1, mult, false); |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| //add output transpose for efficient target indexing (redundant t() removed by other rewrites) |
| hnew = HopRewriteUtils.createTranspose(hnew); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedDivMM1 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| //Pattern 1e) t(U) %*% (W/(U%*%t(V) + x)) |
| if( !appliedPattern |
| && HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV |
| && HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv |
| && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.PLUS) |
| && right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR |
| && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) |
| && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| Hop W = right.getInput().get(0); |
| Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0); |
| Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1); |
| Hop X = right.getInput().get(1).getInput().get(1); |
| |
| if( HopRewriteUtils.isTransposeOfItself(left, U) ) |
| { |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = HopRewriteUtils.createTranspose(V); |
| else |
| V = V.getInput().get(0); |
| |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WDIVMM, W, U, V, X, 3, false, false); // 3=>DIV_LEFT_EPS |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| //add output transpose for efficient target indexing (redundant t() removed by other rewrites) |
| hnew = HopRewriteUtils.createTranspose(hnew); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedDivMM1e (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| //Pattern 2) (W/(U%*%t(V))) %*% V |
| //alternative pattern: (W*(U%*%t(V))) %*% V |
| if( !appliedPattern |
| && left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY) |
| && HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv |
| && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1)) |
| && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| Hop W = left.getInput().get(0); |
| Hop U = left.getInput().get(1).getInput().get(0); |
| Hop V = left.getInput().get(1).getInput().get(1); |
| |
| if( HopRewriteUtils.isTransposeOfItself(right, V) ) |
| { |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = right; |
| else |
| V = V.getInput().get(0); |
| |
| boolean mult = ((BinaryOp)left).getOp() == OpOp2.MULT; |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 2, mult, false); |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedDivMM2 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| //Pattern 2e) (W/(U%*%t(V) + x)) %*% V |
| if( !appliedPattern |
| && HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV |
| && HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv |
| && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.PLUS) |
| && left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR |
| && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) |
| && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| Hop W = left.getInput().get(0); |
| Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0); |
| Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1); |
| Hop X = left.getInput().get(1).getInput().get(1); |
| |
| if( HopRewriteUtils.isTransposeOfItself(right, V) ) |
| { |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = right; |
| else |
| V = V.getInput().get(0); |
| |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WDIVMM, W, U, V, X, 4, false, false); // 4=>DIV_RIGHT_EPS |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedDivMM2e (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| //Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X)) |
| if( !appliedPattern |
| && HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT |
| && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS) |
| && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) |
| && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX |
| && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| Hop W = right.getInput().get(0); |
| Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0); |
| Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1); |
| Hop X = right.getInput().get(1).getInput().get(1); |
| |
| if( HopRewriteUtils.isNonZeroIndicator(W, X) //W-X constraint |
| && HopRewriteUtils.isTransposeOfItself(left, U) ) //t(U)-U constraint |
| { |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = HopRewriteUtils.createTranspose(V); |
| else |
| V = V.getInput().get(0); |
| |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 1, true, true); |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| //add output transpose for efficient target indexing (redundant t() removed by other rewrites) |
| hnew = HopRewriteUtils.createTranspose(hnew); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedDivMM3 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| //Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V |
| if( !appliedPattern |
| && HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT |
| && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS) |
| && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) |
| && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX |
| && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| Hop W = left.getInput().get(0); |
| Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0); |
| Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1); |
| Hop X = left.getInput().get(1).getInput().get(1); |
| |
| if( HopRewriteUtils.isNonZeroIndicator(W, X) //W-X constraint |
| && HopRewriteUtils.isTransposeOfItself(right, V) ) //V-t(V) constraint |
| { |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = right; |
| else |
| V = V.getInput().get(0); |
| |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WDIVMM, X, U, V, new LiteralOp(-1), 2, true, true); |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedDivMM4 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| //Pattern 5) t(U) %*% (W*(U%*%t(V)-X)) |
| if( !appliedPattern |
| && HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT |
| && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS) |
| && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) |
| && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX |
| && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| Hop W = right.getInput().get(0); |
| Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0); |
| Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1); |
| Hop X = right.getInput().get(1).getInput().get(1); |
| |
| if( HopRewriteUtils.isTransposeOfItself(left, U) ) //t(U)-U constraint |
| { |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = HopRewriteUtils.createTranspose(V); |
| else |
| V = V.getInput().get(0); |
| |
| //note: x and w exchanged compared to patterns 1-4, 7 |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WDIVMM, W, U, V, X, 1, true, true); |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| //add output transpose for efficient target indexing (redundant t() removed by other rewrites) |
| hnew = HopRewriteUtils.createTranspose(hnew); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedDivMM5 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| //Pattern 6) (W*(U%*%t(V)-X)) %*% V |
| if( !appliedPattern |
| && HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT |
| && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS) |
| && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) |
| && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX |
| && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| Hop W = left.getInput().get(0); |
| Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0); |
| Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1); |
| Hop X = left.getInput().get(1).getInput().get(1); |
| |
| if( HopRewriteUtils.isTransposeOfItself(right, V) ) //V-t(V) constraint |
| { |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = right; |
| else |
| V = V.getInput().get(0); |
| |
| //note: x and w exchanged compared to patterns 1-4, 7 |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WDIVMM, W, U, V, X, 2, true, true); |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedDivMM6 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| } |
| |
| //Pattern 7) (W*(U%*%t(V))) |
| if( !appliedPattern |
| && HopRewriteUtils.isBinary(hi, LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT |
| && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv |
| && hi.getDim2() > 1 //not applied for vector-vector mult |
| && hi.getInput().get(0).getDataType() == DataType.MATRIX |
| && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getBlocksize() |
| && HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1)) |
| && (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain |
| && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| Hop W = hi.getInput().get(0); |
| Hop U = hi.getInput().get(1).getInput().get(0); |
| Hop V = hi.getInput().get(1).getInput().get(1); |
| |
| //for this basic pattern, we're more conservative and only apply wdivmm if |
| //W is sparse and U/V unknown or dense; or if U/V are dense |
| if( (HopRewriteUtils.isSparse(W) && !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V)) |
| || (HopRewriteUtils.isDense(U) && HopRewriteUtils.isDense(V)) ) { |
| V = !HopRewriteUtils.isTransposeOperation(V) ? |
| HopRewriteUtils.createTranspose(V) : V.getInput().get(0); |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false); |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedDivMM7 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| //relink new hop into original position |
| if( hnew != null ) { |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) |
| { |
| Hop hnew = null; |
| boolean appliedPattern = false; |
| |
| if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol |
| && ((AggUnaryOp)hi).getOp() == AggOp.SUM //pattern rooted by sum() |
| && hi.getInput().get(0) instanceof BinaryOp //pattern subrooted by binary op |
| && hi.getInput().get(0).getDim2() > 1 ) //not applied for vector-vector mult |
| { |
| BinaryOp bop = (BinaryOp) hi.getInput().get(0); |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| |
| //Pattern 1) sum( X * log(U %*% t(V))) |
| if( bop.getOp()==OpOp2.MULT && left.getDataType()==DataType.MATRIX |
| && HopRewriteUtils.isEqualSize(left, right) //prevent mb |
| && HopRewriteUtils.isUnary(right, OpOp1.LOG) |
| && right.getInput().get(0) instanceof AggBinaryOp //ba gurantees matrices |
| && HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT |
| { |
| Hop X = left; |
| Hop U = right.getInput().get(0).getInput().get(0); |
| Hop V = right.getInput().get(0).getInput().get(1); |
| |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = HopRewriteUtils.createTranspose(V); |
| else |
| V = V.getInput().get(0); |
| |
| hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.FP64, OpOp4.WCEMM, X, U, V, |
| new LiteralOp(0.0), 0, false, false); |
| hnew.setBlocksize(X.getBlocksize()); |
| appliedPattern = true; |
| |
| LOG.debug("Applied simplifyWeightedCEMM (line "+hi.getBeginLine()+")"); |
| } |
| |
| //Pattern 2) sum( X * log(U %*% t(V) + eps)) |
| if( !appliedPattern |
| && bop.getOp()==OpOp2.MULT && left.getDataType()==DataType.MATRIX |
| && HopRewriteUtils.isEqualSize(left, right) |
| && HopRewriteUtils.isUnary(right, OpOp1.LOG) |
| && HopRewriteUtils.isBinary(right.getInput().get(0), OpOp2.PLUS) |
| && right.getInput().get(0).getInput().get(0) instanceof AggBinaryOp |
| && right.getInput().get(0).getInput().get(1) instanceof LiteralOp |
| && right.getInput().get(0).getInput().get(1).getDataType() == DataType.SCALAR |
| && HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0).getInput().get(0),true)) |
| { |
| Hop X = left; |
| Hop U = right.getInput().get(0).getInput().get(0).getInput().get(0); |
| Hop V = right.getInput().get(0).getInput().get(0).getInput().get(1); |
| Hop eps = right.getInput().get(0).getInput().get(1); |
| |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = HopRewriteUtils.createTranspose(V); |
| else |
| V = V.getInput().get(0); |
| |
| hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.FP64, |
| OpOp4.WCEMM, X, U, V, eps, 1, false, false); // 1 => BASIC_EPS |
| hnew.setBlocksize(X.getBlocksize()); |
| |
| LOG.debug("Applied simplifyWeightedCEMMEps (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| //relink new hop into original position |
| if( hnew != null ) { |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos) { |
| Hop hnew = null; |
| boolean appliedPattern = false; |
| |
| //Pattern 1) (W*uop(U%*%t(V))) |
| if( hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(),LOOKUP_VALID_WDIVMM_BINARY) |
| && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv |
| && hi.getDim2() > 1 //not applied for vector-vector mult |
| && hi.getInput().get(0).getDataType() == DataType.MATRIX |
| && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getBlocksize() |
| && hi.getInput().get(1) instanceof UnaryOp |
| && HopRewriteUtils.isValidOp(((UnaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_UNARY) |
| && hi.getInput().get(1).getInput().get(0) instanceof AggBinaryOp |
| && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| Hop W = hi.getInput().get(0); |
| Hop U = hi.getInput().get(1).getInput().get(0).getInput().get(0); |
| Hop V = hi.getInput().get(1).getInput().get(0).getInput().get(1); |
| boolean mult = ((BinaryOp)hi).getOp()==OpOp2.MULT; |
| OpOp1 op = ((UnaryOp)hi.getInput().get(1)).getOp(); |
| |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = HopRewriteUtils.createTranspose(V); |
| else |
| V = V.getInput().get(0); |
| |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WUMM, W, U, V, mult, op, null); |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")"); |
| } |
| |
| //Pattern 2.7) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V)) |
| if( !appliedPattern |
| && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT) |
| && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2) |
| || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2))) |
| { |
| final Hop nl; // non-literal |
| if( hi.getInput().get(0) instanceof LiteralOp ) { |
| nl = hi.getInput().get(1); |
| } else { |
| nl = hi.getInput().get(0); |
| } |
| |
| if ( HopRewriteUtils.isBinary(nl, OpOp2.MULT) |
| && nl.getParent().size()==1 // ensure no foreign parents |
| && HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) //prevent mv |
| && nl.getDim2() > 1 //not applied for vector-vector mult |
| && nl.getInput().get(0).getDataType() == DataType.MATRIX |
| && nl.getInput().get(0).getDim2() > nl.getInput().get(0).getBlocksize() |
| && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1)) |
| && (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain |
| && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) ) |
| { |
| final Hop W = nl.getInput().get(0); |
| final Hop U = nl.getInput().get(1).getInput().get(0); |
| Hop V = nl.getInput().get(1).getInput().get(1); |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = HopRewriteUtils.createTranspose(V); |
| else |
| V = V.getInput().get(0); |
| |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT); |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| //Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops |
| if( !appliedPattern |
| && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(),LOOKUP_VALID_WDIVMM_BINARY) |
| && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv |
| && hi.getDim2() > 1 //not applied for vector-vector mult |
| && hi.getInput().get(0).getDataType() == DataType.MATRIX |
| && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getBlocksize() |
| && hi.getInput().get(1) instanceof BinaryOp |
| && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_BINARY) ) |
| { |
| Hop left = hi.getInput().get(1).getInput().get(0); |
| Hop right = hi.getInput().get(1).getInput().get(1); |
| Hop abop = null; |
| |
| //pattern 2a) matrix-scalar operations |
| if( right.getDataType()==DataType.SCALAR && right instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValue((LiteralOp)right)==2 //pow2, mult2 |
| && left instanceof AggBinaryOp |
| && HopRewriteUtils.isSingleBlock(left.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| abop = left; |
| } |
| //pattern 2b) scalar-matrix operations |
| else if( left.getDataType()==DataType.SCALAR && left instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValue((LiteralOp)left)==2 //mult2 |
| && ((BinaryOp)hi.getInput().get(1)).getOp() == OpOp2.MULT |
| && right instanceof AggBinaryOp |
| && HopRewriteUtils.isSingleBlock(right.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT |
| { |
| abop = right; |
| } |
| |
| if( abop != null ) { |
| Hop W = hi.getInput().get(0); |
| Hop U = abop.getInput().get(0); |
| Hop V = abop.getInput().get(1); |
| boolean mult = ((BinaryOp)hi).getOp()==OpOp2.MULT; |
| OpOp2 op = ((BinaryOp)hi.getInput().get(1)).getOp(); |
| |
| if( !HopRewriteUtils.isTransposeOperation(V) ) |
| V = HopRewriteUtils.createTranspose(V); |
| else |
| V = V.getInput().get(0); |
| |
| hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.FP64, |
| OpOp4.WUMM, W, U, V, mult, null, op); |
| hnew.setBlocksize(W.getBlocksize()); |
| hnew.refreshSizeInformation(); |
| |
| appliedPattern = true; |
| LOG.debug("Applied simplifyWeightedUnaryMM2 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| |
| //relink new hop into original position |
| if( hnew != null ) { |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| } |
| |
| return hi; |
| } |
| |
| /** |
| * NOTE: dot-product-sum could be also applied to sum(a*b). However, we |
| * restrict ourselfs to sum(a^2) and transitively sum(a*a) since a general mm |
| * a%*%b on MR can be also counter-productive (e.g., MMCJ) while tsmm is always |
| * beneficial. |
| * |
| * @param parent parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) { |
| //sum(v^2)/sum(v1*v2) --> as.scalar(t(v)%*%v) in order to exploit tsmm vector dotproduct |
| //w/o materialization of intermediates |
| if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum |
| && ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full aggregate |
| && hi.getInput().get(0).getDim2() == 1 ) //vector (for correctness) |
| { |
| Hop baLeft = null; |
| Hop baRight = null; |
| |
| Hop hi2 = hi.getInput().get(0); //check for ^2 w/o multiple consumers |
| //check for sum(v^2), might have been rewritten from sum(v*v) |
| if( HopRewriteUtils.isBinary(hi2, OpOp2.POW) |
| && hi2.getInput().get(1) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValue((LiteralOp)hi2.getInput().get(1))==2 |
| && hi2.getParent().size() == 1 ) //no other consumer than sum |
| { |
| Hop input = hi2.getInput().get(0); |
| baLeft = input; |
| baRight = input; |
| } |
| //check for sum(v1*v2), but prevent to rewrite sum(v1*v2*v3) which is later compiled into a ta+* lop |
| else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) //no other consumer than sum |
| && hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1 |
| && !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT) |
| && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT) |
| && ( !ALLOW_SUM_PRODUCT_REWRITES |
| || !( HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not rewrite (A^2)*B |
| && hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp // let tak+* handle it |
| && ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2 )) |
| && ( !ALLOW_SUM_PRODUCT_REWRITES |
| || !( HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not rewrite B*(A^2) |
| && hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp // let tak+* handle it |
| && ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2 )) |
| ) |
| { |
| baLeft = hi2.getInput().get(0); |
| baRight = hi2.getInput().get(1); |
| } |
| |
| //perform actual rewrite (if necessary) |
| if( baLeft != null && baRight != null ) |
| { |
| //create new operator chain |
| ReorgOp trans = HopRewriteUtils.createTranspose(baLeft); |
| AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(trans, baRight); |
| UnaryOp cast = HopRewriteUtils.createUnary(mmult, OpOp1.CAST_AS_SCALAR); |
| |
| //rehang new subdag under parent node |
| HopRewriteUtils.replaceChildReference(parent, hi, cast, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, hi2); |
| |
| hi = cast; |
| |
| LOG.debug("Applied simplifyDotProductSum."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| /** |
| * Replace SUM(X^2) with a fused SUM_SQ(X) HOP. |
| * |
| * @param parent Parent HOP for which hi is an input. |
| * @param hi Current HOP for potential rewrite. |
| * @param pos Position of hi in parent's list of inputs. |
| * |
| * @return Either hi or the rewritten HOP replacing it. |
| */ |
| private static Hop fuseSumSquared(Hop parent, Hop hi, int pos) { |
| // if SUM |
| if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.SUM) { |
| Hop sumInput = hi.getInput().get(0); |
| |
| // if input to SUM is POW(X,2), and no other consumers of the POW(X,2) HOP |
| if( HopRewriteUtils.isBinary(sumInput, OpOp2.POW) |
| && sumInput.getInput().get(1) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValue((LiteralOp) sumInput.getInput().get(1)) == 2 |
| && sumInput.getParent().size() == 1) { |
| Hop x = sumInput.getInput().get(0); |
| |
| // if X is NOT a column vector |
| if (x.getDim2() > 1) { |
| // perform rewrite from SUM(POW(X,2)) to SUM_SQ(X) |
| Direction dir = ((AggUnaryOp) hi).getDirection(); |
| AggUnaryOp sumSq = HopRewriteUtils.createAggUnaryOp(x, AggOp.SUM_SQ, dir); |
| HopRewriteUtils.replaceChildReference(parent, hi, sumSq, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi, sumInput); |
| hi = sumSq; |
| |
| LOG.debug("Applied fuseSumSquared (line " +hi.getBeginLine()+")."); |
| } |
| } |
| } |
| return hi; |
| } |
| |
| private static Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) |
| { |
| //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X - s*Y -> X -* sY |
| if( hi instanceof BinaryOp && !((BinaryOp) hi).isOuter() |
| && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) ) |
| { |
| BinaryOp bop = (BinaryOp) hi; |
| Hop left = bop.getInput().get(0); |
| Hop right = bop.getInput().get(1); |
| Hop ternop = null; |
| |
| //pattern (a) X + s*Y -> X +* sY |
| if( bop.getOp() == OpOp2.PLUS && left.getDataType()==DataType.MATRIX |
| && HopRewriteUtils.isScalarMatrixBinaryMult(right) |
| && HopRewriteUtils.isEqualSize(left, right) |
| && right.getParent().size() == 1 ) //single consumer s*Y |
| { |
| Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); |
| Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); |
| ternop = (smid instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? |
| left : HopRewriteUtils.createTernary(left, smid, mright, OpOp3.PLUS_MULT); |
| LOG.debug("Applied fuseAxpyBinaryOperationChain1. (line " +hi.getBeginLine()+")"); |
| } |
| //pattern (b) s*Y + X -> X +* sY |
| else if( bop.getOp() == OpOp2.PLUS && right.getDataType()==DataType.MATRIX |
| && HopRewriteUtils.isScalarMatrixBinaryMult(left) |
| && HopRewriteUtils.isEqualSize(left, right) |
| && left.getParent().size() == 1 ) //single consumer s*Y |
| { |
| Hop smid = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); |
| Hop mright = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); |
| ternop = (smid instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? |
| right : HopRewriteUtils.createTernary(right, smid, mright, OpOp3.PLUS_MULT); |
| LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")"); |
| } |
| //pattern (c) X - s*Y -> X -* sY |
| else if( bop.getOp() == OpOp2.MINUS && left.getDataType()==DataType.MATRIX |
| && HopRewriteUtils.isScalarMatrixBinaryMult(right) |
| && HopRewriteUtils.isEqualSize(left, right) |
| && right.getParent().size() == 1 ) //single consumer s*Y |
| { |
| Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); |
| Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0); |
| ternop = (smid instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ? |
| left : HopRewriteUtils.createTernary(left, smid, mright, OpOp3.MINUS_MULT); |
| LOG.debug("Applied fuseAxpyBinaryOperationChain3. (line " +hi.getBeginLine()+")"); |
| } |
| |
| //rewire parent-child operators if rewrite applied |
| if( ternop != null ) { |
| HopRewriteUtils.replaceChildReference(parent, hi, ternop, pos); |
| hi = ternop; |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyEmptyBinaryOperation(Hop parent, Hop hi, int pos) |
| { |
| if( hi instanceof BinaryOp ) //b(?) X Y |
| { |
| BinaryOp bop = (BinaryOp) hi; |
| Hop left = hi.getInput().get(0); |
| Hop right = hi.getInput().get(1); |
| |
| if( left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) |
| { |
| Hop hnew = null; |
| |
| //NOTE: these rewrites of binary cell operations need to be aware that right is |
| //potentially a vector but the result is of the size of left |
| |
| boolean notBinaryMV = HopRewriteUtils.isNotMatrixVectorBinaryOperation(bop); |
| |
| switch( bop.getOp() ){ |
| //X * Y -> matrix(0,nrow(X),ncol(X)); |
| case MULT: { |
| if( HopRewriteUtils.isEmpty(left) ) //empty left and size known |
| hnew = HopRewriteUtils.createDataGenOp(left, left, 0); |
| else if( HopRewriteUtils.isEmpty(right) //empty right and right not a vector |
| && right.getDim1()>1 && right.getDim2()>1 ) { |
| hnew = HopRewriteUtils.createDataGenOp(right, right, 0); |
| } |
| else if( HopRewriteUtils.isEmpty(right) )//empty right and right potentially a vector |
| hnew = HopRewriteUtils.createDataGenOp(left, left, 0); |
| break; |
| } |
| case PLUS: { |
| if( HopRewriteUtils.isEmpty(left) && HopRewriteUtils.isEmpty(right) ) //empty left/right and size known |
| hnew = HopRewriteUtils.createDataGenOp(left, left, 0); |
| else if( HopRewriteUtils.isEmpty(left) && notBinaryMV ) //empty left |
| hnew = right; |
| else if( HopRewriteUtils.isEmpty(right) ) //empty right |
| hnew = left; |
| break; |
| } |
| case MINUS: { |
| if( HopRewriteUtils.isEmpty(left) && notBinaryMV ) { //empty left |
| HopRewriteUtils.removeChildReference(hi, left); |
| HopRewriteUtils.addChildReference(hi, new LiteralOp(0), 0); |
| hnew = hi; |
| } |
| else if( HopRewriteUtils.isEmpty(right) ) //empty and size known |
| hnew = left; |
| break; |
| } |
| default: |
| //do nothing (hnew = null) |
| } |
| |
| if( hnew != null ) { |
| //create datagen and add it to parent |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| hi = hnew; |
| |
| LOG.debug("Applied simplifyEmptyBinaryOperation"); |
| } |
| } |
| } |
| |
| return hi; |
| } |
| |
| /** |
| * This is rewrite tries to reorder minus operators from inputs of matrix |
| * multiply to its output because the output is (except for outer products) |
| * usually significantly smaller. Furthermore, this rewrite is a precondition |
| * for the important hops-lops rewrite of transpose-matrixmult if the transpose |
| * is hidden under the minus. |
| * |
| * NOTE: in this rewrite we need to modify the links to all parents because we |
| * remove existing links of subdags and hence affect all consumers. |
| * |
| * @param parent the parent high-level operator |
| * @param hi high-level operator |
| * @param pos position |
| * @return high-level operator |
| */ |
| @SuppressWarnings("unchecked") |
| private static Hop reorderMinusMatrixMult(Hop parent, Hop hi, int pos) |
| { |
| if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y |
| { |
| Hop hileft = hi.getInput().get(0); |
| Hop hiright = hi.getInput().get(1); |
| |
| if( HopRewriteUtils.isBinary(hileft, OpOp2.MINUS) //X=-Z |
| && hileft.getInput().get(0) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValue((LiteralOp)hileft.getInput().get(0))==0.0 |
| && hi.dimsKnown() && hileft.getInput().get(1).dimsKnown() //size comparison |
| && HopRewriteUtils.compareSize(hi, hileft.getInput().get(1)) < 0 ) |
| { |
| Hop hi2 = hileft.getInput().get(1); |
| |
| //remove link from matrixmult to minus |
| HopRewriteUtils.removeChildReference(hi, hileft); |
| |
| //get old parents (before creating minus over matrix mult) |
| ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); |
| |
| //create new operators |
| BinaryOp minus = HopRewriteUtils.createBinary(new LiteralOp(0), hi, OpOp2.MINUS); |
| |
| //rehang minus under all parents |
| for( Hop p : parents ) { |
| int ix = HopRewriteUtils.getChildReferencePos(p, hi); |
| HopRewriteUtils.removeChildReference(p, hi); |
| HopRewriteUtils.addChildReference(p, minus, ix); |
| } |
| |
| //rehang child of minus under matrix mult |
| HopRewriteUtils.addChildReference(hi, hi2, 0); |
| |
| //cleanup if only consumer of minus |
| HopRewriteUtils.cleanupUnreferenced(hileft); |
| |
| hi = minus; |
| |
| LOG.debug("Applied reorderMinusMatrixMult (line "+hi.getBeginLine()+")."); |
| } |
| else if( HopRewriteUtils.isBinary(hiright, OpOp2.MINUS) //X=-Z |
| && hiright.getInput().get(0) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValue((LiteralOp)hiright.getInput().get(0))==0.0 |
| && hi.dimsKnown() && hiright.getInput().get(1).dimsKnown() //size comparison |
| && HopRewriteUtils.compareSize(hi, hiright.getInput().get(1)) < 0 ) |
| { |
| Hop hi2 = hiright.getInput().get(1); |
| |
| //remove link from matrixmult to minus |
| HopRewriteUtils.removeChildReference(hi, hiright); |
| |
| //get old parents (before creating minus over matrix mult) |
| ArrayList<Hop> parents = (ArrayList<Hop>) hi.getParent().clone(); |
| |
| //create new operators |
| BinaryOp minus = HopRewriteUtils.createBinary(new LiteralOp(0), hi, OpOp2.MINUS); |
| |
| //rehang minus under all parents |
| for( Hop p : parents ) { |
| int ix = HopRewriteUtils.getChildReferencePos(p, hi); |
| HopRewriteUtils.removeChildReference(p, hi); |
| HopRewriteUtils.addChildReference(p, minus, ix); |
| } |
| |
| //rehang child of minus under matrix mult |
| HopRewriteUtils.addChildReference(hi, hi2, 1); |
| |
| //cleanup if only consumer of minus |
| HopRewriteUtils.cleanupUnreferenced(hiright); |
| |
| hi = minus; |
| |
| LOG.debug("Applied reorderMinusMatrixMult (line "+hi.getBeginLine()+")."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| |
| private static Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos) |
| { |
| //sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product |
| //colSums(A%*%B) -> colSums(A)%*%B |
| //rowSums(A%*%B) -> A%*%rowSums(B) |
| //-- if not dot product, not applied since aggregate removed |
| //-- if sum not the only consumer, not applied to prevent redundancy |
| if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum |
| && hi.getInput().get(0) instanceof AggBinaryOp //A%*%B |
| && (hi.getInput().get(0).getDim1()>1 || hi.getInput().get(0).getDim2()>1) //not dot product |
| && hi.getInput().get(0).getParent().size()==1 ) //not multiple consumers of matrix mult |
| { |
| Hop hi2 = hi.getInput().get(0); |
| Hop left = hi2.getInput().get(0); |
| Hop right = hi2.getInput().get(1); |
| |
| //remove link from parent to matrix mult |
| HopRewriteUtils.removeChildReference(hi, hi2); |
| |
| //create new operators |
| Hop root = null; |
| //pattern: sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product |
| if( ((AggUnaryOp)hi).getDirection() == Direction.RowCol ) { |
| AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col); |
| ReorgOp trans = HopRewriteUtils.createTranspose(colSum); |
| AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row); |
| root = HopRewriteUtils.createBinary(trans, rowSum, OpOp2.MULT); |
| LOG.debug("Applied simplifySumMatrixMult RC."); |
| } |
| //colSums(A%*%B) -> colSums(A)%*%B |
| else if( ((AggUnaryOp)hi).getDirection() == Direction.Col ) { |
| AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col); |
| root = HopRewriteUtils.createMatrixMultiply(colSum, right); |
| LOG.debug("Applied simplifySumMatrixMult C."); |
| } |
| //rowSums(A%*%B) -> A%*%rowSums(B) |
| else if( ((AggUnaryOp)hi).getDirection() == Direction.Row ) { |
| AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row); |
| root = HopRewriteUtils.createMatrixMultiply(left, rowSum); |
| LOG.debug("Applied simplifySumMatrixMult R."); |
| } |
| |
| //rehang new subdag under current node (keep hi intact) |
| HopRewriteUtils.addChildReference(hi, root, 0); |
| hi.refreshSizeInformation(); |
| |
| //cleanup if only consumer of intermediate |
| HopRewriteUtils.cleanupUnreferenced(hi2); |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyScalarMVBinaryOperation(Hop hi) |
| { |
| if( hi instanceof BinaryOp && ((BinaryOp)hi).supportsMatrixScalarOperations() //e.g., X * s |
| && hi.getInput().get(0).getDataType()==DataType.MATRIX |
| && hi.getInput().get(1).getDataType()==DataType.MATRIX ) |
| { |
| Hop right = hi.getInput().get(1); |
| |
| //X * s -> X * as.scalar(s) |
| if( HopRewriteUtils.isDimsKnown(right) && right.getDim1()==1 && right.getDim2()==1 ) //scalar right |
| { |
| //remove link to right child and introduce cast |
| UnaryOp cast = HopRewriteUtils.createUnary(right, OpOp1.CAST_AS_SCALAR); |
| HopRewriteUtils.replaceChildReference(hi, right, cast, 1); |
| |
| LOG.debug("Applied simplifyScalarMVBinaryOperation."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyNnzComputation(Hop parent, Hop hi, int pos) |
| { |
| //sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known |
| if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum |
| && ((AggUnaryOp)hi).getDirection() == Direction.RowCol //full aggregate |
| && HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.NOTEQUAL) ) |
| { |
| Hop ppred = hi.getInput().get(0); |
| Hop X = null; |
| if( ppred.getInput().get(0) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValue((LiteralOp)ppred.getInput().get(0))==0 ) |
| { |
| X = ppred.getInput().get(1); |
| } |
| else if( ppred.getInput().get(1) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValue((LiteralOp)ppred.getInput().get(1))==0 ) |
| { |
| X = ppred.getInput().get(0); |
| } |
| |
| //apply rewrite if known nnz |
| if( X != null && X.getNnz() > 0 ){ |
| Hop hnew = new LiteralOp(X.getNnz()); |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = hnew; |
| |
| LOG.debug("Applied simplifyNnzComputation."); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyNrowNcolComputation(Hop parent, Hop hi, int pos) |
| { |
| //nrow(X) -> literal(nrow(X)), ncol(X) -> literal(ncol(X)), if respective dims known |
| //(this rewrite aims to remove unnecessary data dependencies to X which trigger computation |
| //even if the intermediate is otherwise not required, e.g., when part of a fused operator) |
| if( hi instanceof UnaryOp ) |
| { |
| if( ((UnaryOp)hi).getOp()==OpOp1.NROW && hi.getInput().get(0).rowsKnown() ) { |
| Hop hnew = new LiteralOp(hi.getInput().get(0).getDim1()); |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos, false); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| LOG.debug("Applied simplifyNrowComputation nrow("+hi.getHopID()+") -> " |
| + hnew.getName()+" (line "+hi.getBeginLine()+")."); |
| hi = hnew; |
| } |
| else if( ((UnaryOp)hi).getOp()==OpOp1.NCOL && hi.getInput().get(0).colsKnown() ) { |
| Hop hnew = new LiteralOp(hi.getInput().get(0).getDim2()); |
| HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos, false); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| LOG.debug("Applied simplifyNcolComputation ncol("+hi.getHopID()+") -> " |
| + hnew.getName()+" (line "+hi.getBeginLine()+")."); |
| hi = hnew; |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop simplifyTableSeqExpand(Hop parent, Hop hi, int pos) |
| { |
| //pattern: table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) |
| //note: this rewrite supports both left/right sequence |
| if( hi instanceof TernaryOp && hi.getInput().size()==5 //table without weights |
| && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) ) //i.e., weight of 1 |
| { |
| Hop first = hi.getInput().get(0); |
| Hop second = hi.getInput().get(1); |
| |
| //pattern a: table(seq(1,nrow(v)), v, nrow(v), m, 1) |
| if( HopRewriteUtils.isBasic1NSequence(first, second, true) |
| && HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(3), second, true) ) |
| { |
| //setup input parameter hops |
| LinkedHashMap<String,Hop> args = new LinkedHashMap<>(); |
| args.put("target", second); |
| args.put("max", hi.getInput().get(4)); |
| args.put("dir", new LiteralOp("cols")); |
| args.put("ignore", new LiteralOp(false)); |
| args.put("cast", new LiteralOp(true)); |
| |
| //create new hop |
| ParameterizedBuiltinOp pbop = HopRewriteUtils |
| .createParameterizedBuiltinOp(second, args, ParamBuiltinOp.REXPAND); |
| HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = pbop; |
| |
| LOG.debug("Applied simplifyTableSeqExpand1 (line "+hi.getBeginLine()+")"); |
| } |
| //pattern b: table(v, seq(1,nrow(v)), m, nrow(v)) |
| else if( HopRewriteUtils.isBasic1NSequence(second, first, true) |
| && HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(4), first, true) ) |
| { |
| //setup input parameter hops |
| LinkedHashMap<String,Hop> args = new LinkedHashMap<>(); |
| args.put("target", first); |
| args.put("max", hi.getInput().get(3)); |
| args.put("dir", new LiteralOp("rows")); |
| args.put("ignore", new LiteralOp(false)); |
| args.put("cast", new LiteralOp(true)); |
| |
| //create new hop |
| ParameterizedBuiltinOp pbop = HopRewriteUtils |
| .createParameterizedBuiltinOp(first, args, ParamBuiltinOp.REXPAND); |
| HopRewriteUtils.replaceChildReference(parent, hi, pbop, pos); |
| HopRewriteUtils.cleanupUnreferenced(hi); |
| hi = pbop; |
| |
| LOG.debug("Applied simplifyTableSeqExpand2 (line "+hi.getBeginLine()+")"); |
| } |
| } |
| |
| return hi; |
| } |
| |
| private static Hop foldMultipleMinMaxOperations(Hop hi) |
| { |
| if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX, OpOp2.PLUS) |
| || HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS)) |
| && hi.getValueType() != ValueType.STRING //exclude string concat |
| && HopRewriteUtils.isNotMatrixVectorBinaryOperation(hi)) |
| { |
| OpOp2 bop = (hi instanceof BinaryOp) ? ((BinaryOp)hi).getOp() : |
| OpOp2.valueOf(((NaryOp)hi).getOp().name()); |
| OpOpN nop = (hi instanceof NaryOp) ? ((NaryOp)hi).getOp() : |
| OpOpN.valueOf(((BinaryOp)hi).getOp().name()); |
| |
| boolean converged = false; |
| while( !converged ) { |
| //get first matching min/max |
| Hop first = hi.getInput().stream() |
| .filter(h -> HopRewriteUtils.isBinary(h, bop) || HopRewriteUtils.isNary(h, nop)) |
| .findFirst().orElse(null); |
| |
| //replace current op with new nary min/max |
| final Hop lhi = hi; |
| if( first != null && first.getParent().size()==1 |
| && first.getInput().stream().allMatch(c -> c.getDataType()==DataType.SCALAR |
| || HopRewriteUtils.isEqualSize(lhi, c))) { |
| //construct new list of inputs (in original order) |
| ArrayList<Hop> linputs = new ArrayList<>(); |
| for(Hop in : hi.getInput()) |
| if( in == first ) |
| linputs.addAll(first.getInput()); |
| else |
| linputs.add(in); |
| Hop hnew = HopRewriteUtils.createNary(nop, linputs.toArray(new Hop[0])); |
| //clear dangling references |
| HopRewriteUtils.removeAllChildReferences(hi); |
| HopRewriteUtils.removeAllChildReferences(first); |
| //rewire all parents (avoid anomalies with refs to hi) |
| List<Hop> parents = new ArrayList<>(hi.getParent()); |
| for( Hop p : parents ) |
| HopRewriteUtils.replaceChildReference(p, hi, hnew); |
| hi = hnew; |
| LOG.debug("Applied foldMultipleMinMaxPlusOperations (line "+hi.getBeginLine()+")."); |
| } |
| else { |
| converged = true; |
| } |
| } |
| } |
| |
| return hi; |
| } |
| } |