/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
| |
| package org.apache.sysml.hops.rewrite; |
| |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.log4j.Level; |
| import org.apache.log4j.Logger; |
| |
| import org.apache.sysml.hops.AggBinaryOp; |
| import org.apache.sysml.hops.Hop; |
| import org.apache.sysml.hops.HopsException; |
| import org.apache.sysml.utils.Explain; |
| |
| /** |
| * Rule: Determine the optimal order of execution for a chain of |
| * matrix multiplications Solution: Classic Dynamic Programming |
| * Approach Currently, the approach based only on matrix dimensions |
| * Goal: To reduce the number of computations in the run-time |
| * (map-reduce) layer |
| */ |
| public class RewriteMatrixMultChainOptimization extends HopRewriteRule |
| { |
| |
| private static final Log LOG = LogFactory.getLog(RewriteMatrixMultChainOptimization.class.getName()); |
| private static final boolean LDEBUG = false; |
| |
| static |
| { |
| // for internal debugging only |
| if( LDEBUG ) { |
| Logger.getLogger("org.apache.sysml.hops.rewrite.RewriteMatrixMultChainOptimization") |
| .setLevel((Level) Level.TRACE); |
| } |
| } |
| |
| @Override |
| public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) |
| throws HopsException |
| { |
| if( roots == null ) |
| return null; |
| |
| for( Hop h : roots ) |
| { |
| // Find the optimal order for the chain whose result is the current HOP |
| rule_OptimizeMMChains(h); |
| } |
| |
| return roots; |
| } |
| |
| @Override |
| public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) |
| throws HopsException |
| { |
| if( root == null ) |
| return null; |
| |
| // Find the optimal order for the chain whose result is the current HOP |
| rule_OptimizeMMChains(root); |
| |
| return root; |
| } |
| |
| /** |
| * rule_OptimizeMMChains(): This method recurses through all Hops in the DAG |
| * to find chains that need to be optimized. |
| */ |
| private void rule_OptimizeMMChains(Hop hop) |
| throws HopsException |
| { |
| if(hop.getVisited() == Hop.VisitStatus.DONE) |
| return; |
| |
| if ( hop instanceof AggBinaryOp && ((AggBinaryOp) hop).isMatrixMultiply() |
| && !((AggBinaryOp)hop).hasLeftPMInput() |
| && hop.getVisited() != Hop.VisitStatus.DONE ) |
| { |
| // Try to find and optimize the chain in which current Hop is the |
| // last operator |
| optimizeMMChain(hop); |
| } |
| |
| for (Hop hi : hop.getInput()) |
| rule_OptimizeMMChains(hi); |
| |
| hop.setVisited(Hop.VisitStatus.DONE); |
| } |
| |
| |
| /** |
| * optimizeMMChain(): It optimizes the matrix multiplication chain in which |
| * the last Hop is "this". Step-1) Identify the chain (mmChain). (Step-2) clear all |
| * links among the Hops that are involved in mmChain. (Step-3) Find the |
| * optimal ordering (dynamic programming) (Step-4) Relink the hops in |
| * mmChain. |
| */ |
| private void optimizeMMChain( Hop hop ) throws HopsException |
| { |
| if( LOG.isTraceEnabled() ) { |
| LOG.trace("MM Chain Optimization for HOP: (" + " " + hop.getClass().getSimpleName() + ", " + hop.getHopID() + ", " |
| + hop.getName() + ")"); |
| } |
| |
| ArrayList<Hop> mmChain = new ArrayList<Hop>(); |
| ArrayList<Hop> mmOperators = new ArrayList<Hop>(); |
| ArrayList<Hop> tempList; |
| |
| // Step 1: Identify the chain (mmChain) & clear all links among the Hops |
| // that are involved in mmChain. |
| |
| mmOperators.add(hop); |
| // Initialize mmChain with my inputs |
| for (Hop hi : hop.getInput()) { |
| mmChain.add(hi); |
| } |
| |
| // expand each Hop in mmChain to find the entire matrix multiplication |
| // chain |
| int i = 0; |
| while (i < mmChain.size()) { |
| |
| boolean expandable = false; |
| |
| Hop h = mmChain.get(i); |
| /* |
| * Check if mmChain[i] is expandable: |
| * 1) It must be MATMULT |
| * 2) It must not have been visited already |
| * (one MATMULT should get expanded only in one chain) |
| * 3) Its output should not be used in multiple places |
| * (either within chain or outside the chain) |
| */ |
| |
| if ( h instanceof AggBinaryOp && ((AggBinaryOp) h).isMatrixMultiply() |
| && !((AggBinaryOp)hop).hasLeftPMInput() |
| && h.getVisited() != Hop.VisitStatus.DONE ) |
| { |
| // check if the output of "h" is used at multiple places. If yes, it can |
| // not be expanded. |
| if (h.getParent().size() > 1 || inputCount( (Hop) ((h.getParent().toArray())[0]), h) > 1 ) { |
| expandable = false; |
| break; |
| } |
| else |
| expandable = true; |
| } |
| |
| h.setVisited(Hop.VisitStatus.DONE); |
| |
| if ( !expandable ) { |
| i = i + 1; |
| } else { |
| tempList = mmChain.get(i).getInput(); |
| if (tempList.size() != 2) { |
| throw new HopsException(hop.printErrorLocation() + "Hops::rule_OptimizeMMChain(): AggBinary must have exactly two inputs."); |
| } |
| |
| // add current operator to mmOperators, and its input nodes to mmChain |
| mmOperators.add(mmChain.get(i)); |
| mmChain.set(i, tempList.get(0)); |
| mmChain.add(i + 1, tempList.get(1)); |
| } |
| } |
| |
| // print the MMChain |
| if( LOG.isTraceEnabled() ) { |
| LOG.trace("Identified MM Chain: "); |
| for (Hop h : mmChain) { |
| logTraceHop(h, 1); |
| } |
| } |
| |
| if (mmChain.size() == 2) { |
| // If the chain size is 2, then there is nothing to optimize. |
| return; |
| } |
| else |
| { |
| // Step 2: construct dims array |
| double[] dimsArray = new double[mmChain.size() + 1]; |
| boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray ); |
| |
| if( dimsKnown ) { |
| // Step 3: clear the links among Hops within the identified chain |
| clearLinksWithinChain ( hop, mmOperators ); |
| |
| // Step 4: Find the optimal ordering via dynamic programming. |
| |
| // Invoke Dynamic Programming |
| int size = mmChain.size(); |
| int[][] split = mmChainDP(dimsArray, mmChain.size()); |
| |
| // Step 5: Relink the hops using the optimal ordering (split[][]) found from DP. |
| LOG.trace("Optimal MM Chain: "); |
| mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, 1, split, 1); |
| } |
| } |
| } |
| |
| /** |
| * mmChainDP(): Core method to perform dynamic programming on a given array |
| * of matrix dimensions. |
| * |
| * Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein |
| * Introduction to Algorithms, Third Edition, MIT Press, page 395. |
| */ |
| private int[][] mmChainDP(double[] dimArray, int size) |
| { |
| double[][] dpMatrix = new double[size][size]; //min cost table |
| int[][] split = new int[size][size]; //min cost index table |
| |
| //init minimum costs for chains of length 1 |
| for (int i = 0; i < size; i++) { |
| Arrays.fill(dpMatrix[i], 0); |
| Arrays.fill(split[i], -1); |
| } |
| |
| //compute cost-optimal chains for increasing chain sizes |
| for (int l = 2; l <= size; l++) { // chain length |
| for (int i = 0; i < size - l + 1; i++) { |
| int j = i + l - 1; |
| // find cost of (i,j) |
| dpMatrix[i][j] = Double.MAX_VALUE; |
| for (int k = i; k <= j - 1; k++) |
| { |
| //recursive cost computation |
| double cost = dpMatrix[i][k] + dpMatrix[k + 1][j] |
| + (dimArray[i] * dimArray[k + 1] * dimArray[j + 1]); |
| |
| //prune suboptimal |
| if (cost < dpMatrix[i][j]) { |
| dpMatrix[i][j] = cost; |
| split[i][j] = k; |
| } |
| } |
| |
| if( LOG.isTraceEnabled() ){ |
| LOG.trace("mmchainopt [i="+(i+1)+",j="+(j+1)+"]: costs = "+dpMatrix[i][j]+", split = "+(split[i][j]+1)); |
| } |
| } |
| } |
| |
| return split; |
| } |
| |
| /** |
| * mmChainRelinkHops(): This method gets invoked after finding the optimal |
| * order (split[][]) from dynamic programming. It relinks the Hops that are |
| * part of the mmChain. mmChain : basic operands in the entire matrix |
| * multiplication chain. mmOperators : Hops that store the intermediate |
| * results in the chain. For example: A = B %*% (C %*% D) there will be |
| * three Hops in mmChain (B,C,D), and two Hops in mmOperators (one for each |
| * %*%) . |
| */ |
| private void mmChainRelinkHops(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators, |
| int opIndex, int[][] split, int level) |
| { |
| //single matrix - end of recursion |
| if (i == j) { |
| logTraceHop(h, level); |
| return; |
| } |
| |
| if( LOG.isTraceEnabled() ){ |
| String offset = Explain.getIdentation(level); |
| LOG.trace(offset + "("); |
| } |
| |
| // Set Input1 for current Hop h |
| if (i == split[i][j]) { |
| h.getInput().add(mmChain.get(i)); |
| mmChain.get(i).getParent().add(h); |
| } else { |
| h.getInput().add(mmOperators.get(opIndex)); |
| mmOperators.get(opIndex).getParent().add(h); |
| opIndex = opIndex + 1; |
| } |
| |
| // Set Input2 for current Hop h |
| if (split[i][j] + 1 == j) { |
| h.getInput().add(mmChain.get(j)); |
| mmChain.get(j).getParent().add(h); |
| } else { |
| h.getInput().add(mmOperators.get(opIndex)); |
| mmOperators.get(opIndex).getParent().add(h); |
| opIndex = opIndex + 1; |
| } |
| |
| // Find children for both the inputs |
| mmChainRelinkHops(h.getInput().get(0), i, split[i][j], mmChain, mmOperators, opIndex, split, level+1); |
| mmChainRelinkHops(h.getInput().get(1), split[i][j] + 1, j, mmChain, mmOperators, opIndex, split, level+1); |
| |
| // Propagate properties of input hops to current hop h |
| h.refreshSizeInformation(); |
| |
| if( LOG.isTraceEnabled() ){ |
| String offset = Explain.getIdentation(level); |
| LOG.trace(offset + ")"); |
| } |
| } |
| |
| /** |
| * |
| * @param operators |
| * @throws HopsException |
| */ |
| private void clearLinksWithinChain ( Hop hop, ArrayList<Hop> operators ) |
| throws HopsException |
| { |
| Hop op, input1, input2; |
| |
| for ( int i=0; i < operators.size(); i++ ) { |
| op = operators.get(i); |
| if ( op.getInput().size() != 2 || (i != 0 && op.getParent().size() > 1 ) ) { |
| throw new HopsException(hop.printErrorLocation() + "Unexpected error while applying optimization on matrix-mult chain. \n"); |
| } |
| input1 = op.getInput().get(0); |
| input2 = op.getInput().get(1); |
| |
| op.getInput().clear(); |
| input1.getParent().remove(op); |
| input2.getParent().remove(op); |
| } |
| } |
| |
| /** |
| * Obtains all dimension information of the chain and constructs the dimArray. |
| * If all dimensions are known it returns true; othrewise the mmchain rewrite |
| * should be ended without modifications. |
| * |
| * @param hop |
| * @param chain |
| * @param dimArray |
| * @return |
| * @throws HopsException |
| */ |
| private boolean getDimsArray( Hop hop, ArrayList<Hop> chain, double[] dimsArray ) |
| throws HopsException |
| { |
| boolean dimsKnown = true; |
| |
| // Build the array containing dimensions from all matrices in the chain |
| // check the dimensions in the matrix chain to insure all dimensions are known |
| for (int i=0; i< chain.size(); i++){ |
| if (chain.get(i).getDim1() <= 0 || chain.get(i).getDim2() <= 0) |
| dimsKnown = false; |
| } |
| |
| if( dimsKnown ) { //populate dims array if all dims known |
| for (int i = 0; i < chain.size(); i++) |
| { |
| if (i == 0) { |
| dimsArray[i] = chain.get(i).getDim1(); |
| if (dimsArray[i] <= 0) { |
| throw new HopsException(hop.printErrorLocation() + |
| "Hops::optimizeMMChain() : Invalid Matrix Dimension: "+ dimsArray[i]); |
| } |
| } else { |
| if (chain.get(i - 1).getDim2() != chain.get(i).getDim1()) { |
| throw new HopsException(hop.printErrorLocation() + |
| "Hops::optimizeMMChain() : Matrix Dimension Mismatch: "+chain.get(i - 1).getDim2()+" != "+chain.get(i).getDim1()); |
| } |
| } |
| dimsArray[i + 1] = chain.get(i).getDim2(); |
| if (dimsArray[i + 1] <= 0) { |
| throw new HopsException(hop.printErrorLocation() + |
| "Hops::optimizeMMChain() : Invalid Matrix Dimension: " + dimsArray[i + 1]); |
| } |
| } |
| } |
| |
| return dimsKnown; |
| } |
| |
| |
| /** |
| * |
| * @param p |
| * @param h |
| * @return |
| */ |
| private int inputCount ( Hop p, Hop h ) { |
| int count = 0; |
| for ( int i=0; i < p.getInput().size(); i++ ) |
| if ( p.getInput().get(i).equals(h) ) |
| count++; |
| return count; |
| } |
| |
| /** |
| * |
| * @param hop |
| * @param level |
| */ |
| private void logTraceHop( Hop hop, int level ) |
| { |
| if( LOG.isTraceEnabled() ) { |
| String offset = Explain.getIdentation(level); |
| LOG.trace(offset+ "Hop " + hop.getName() + "(" + hop.getClass().getSimpleName() + ", " + hop.getHopID() + ")" + " " |
| + hop.getDim1() + "x" + hop.getDim2()); |
| } |
| } |
| } |