blob: ce55d926753a147e27f87a84fe7abbe4b38c6de1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.utils.Explain;
/**
 * Rule: Determine the optimal order of execution for a chain of
 * matrix multiplications. Solution: classic dynamic programming
 * approach; currently, the approach is based only on matrix dimensions.
 * Goal: reduce the number of computations in the run-time
 * (map-reduce) layer.
 */
public class RewriteMatrixMultChainOptimization extends HopRewriteRule
{
	private static final Log LOG = LogFactory.getLog(RewriteMatrixMultChainOptimization.class.getName());
	private static final boolean LDEBUG = false;

	static
	{
		// for internal debugging only
		if( LDEBUG ) {
			Logger.getLogger("org.apache.sysml.hops.rewrite.RewriteMatrixMultChainOptimization")
				.setLevel(Level.TRACE);
		}
	}

	/**
	 * Applies matrix-multiplication chain optimization to the DAG below each
	 * of the given roots. The DAGs are modified in place.
	 *
	 * @param roots DAG roots; may be null
	 * @param state program rewrite status (not used by this rule)
	 * @return the same {@code roots} list, or null if {@code roots} is null
	 * @throws HopsException if a malformed multiplication chain is encountered
	 */
	@Override
	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
		throws HopsException
	{
		if( roots == null )
			return null;

		for( Hop h : roots ) {
			// Find the optimal order for the chain whose result is the current HOP
			rule_OptimizeMMChains(h);
		}
		return roots;
	}

	/**
	 * Applies matrix-multiplication chain optimization to the DAG below a
	 * single root. The DAG is modified in place.
	 *
	 * @param root DAG root; may be null
	 * @param state program rewrite status (not used by this rule)
	 * @return the same {@code root}, or null if {@code root} is null
	 * @throws HopsException if a malformed multiplication chain is encountered
	 */
	@Override
	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
		throws HopsException
	{
		if( root == null )
			return null;

		// Find the optimal order for the chain whose result is the current HOP
		rule_OptimizeMMChains(root);
		return root;
	}

	/**
	 * Recursively walks all hops reachable from {@code hop} (skipping hops
	 * already marked DONE) and optimizes every chain whose last operator is a
	 * plain matrix multiply.
	 *
	 * @param hop current hop
	 * @throws HopsException if a malformed multiplication chain is encountered
	 */
	private void rule_OptimizeMMChains(Hop hop)
		throws HopsException
	{
		if( hop.getVisited() == Hop.VisitStatus.DONE )
			return;

		// A matrix multiply is a chain root candidate unless it has a left
		// "PM" input (presumably a permutation matrix, see AggBinaryOp).
		// No visit-status re-check is needed: we returned above if DONE.
		if( hop instanceof AggBinaryOp && ((AggBinaryOp) hop).isMatrixMultiply()
			&& !((AggBinaryOp) hop).hasLeftPMInput() )
		{
			// Try to find and optimize the chain in which current Hop is the
			// last operator
			optimizeMMChain(hop);
		}

		for( Hop hi : hop.getInput() )
			rule_OptimizeMMChains(hi);

		hop.setVisited(Hop.VisitStatus.DONE);
	}

	/**
	 * Optimizes the matrix multiplication chain whose last operator is {@code hop}.
	 * <ol>
	 *   <li>Identify the chain: the leaf operands (mmChain) and the
	 *       multiply operators (mmOperators, with {@code hop} at index 0).</li>
	 *   <li>Build the dimension array; abort without changes if any
	 *       dimension is unknown.</li>
	 *   <li>Clear all links among the operators of the chain.</li>
	 *   <li>Find the optimal ordering via dynamic programming.</li>
	 *   <li>Relink the operators according to that ordering.</li>
	 * </ol>
	 *
	 * @param hop matrix-multiply hop that is the last operator of the chain
	 * @throws HopsException if the chain is malformed (wrong arity, dimension mismatch)
	 */
	private void optimizeMMChain( Hop hop ) throws HopsException
	{
		if( LOG.isTraceEnabled() ) {
			LOG.trace("MM Chain Optimization for HOP: (" + " " + hop.getClass().getSimpleName() + ", " + hop.getHopID() + ", "
				+ hop.getName() + ")");
		}

		// leaf operands of the chain, in multiplication order
		ArrayList<Hop> mmChain = new ArrayList<Hop>();
		// matrix-multiply hops that form the chain; index 0 is the chain root
		ArrayList<Hop> mmOperators = new ArrayList<Hop>();

		// Step 1: Identify the chain (mmChain) and its operators (mmOperators).
		mmOperators.add(hop);
		// Initialize mmChain with my inputs
		mmChain.addAll(hop.getInput());

		// Expand each Hop in mmChain to find the entire matrix multiplication chain.
		// An expanded operand is replaced in place by its two inputs, and the same
		// position is examined again, so nested multiplies are flattened left to right.
		int i = 0;
		while( i < mmChain.size() ) {
			Hop h = mmChain.get(i);
			boolean expandable = false;

			/*
			 * mmChain[i] is expandable only if:
			 * 1) it is a matrix multiply without a left PM input,
			 * 2) it has not been visited already
			 *    (one MATMULT should get expanded only in one chain),
			 * 3) its output is not used in multiple places
			 *    (either within the chain or outside of it).
			 */
			if( h instanceof AggBinaryOp && ((AggBinaryOp) h).isMatrixMultiply()
				&& !((AggBinaryOp) h).hasLeftPMInput()
				&& h.getVisited() != Hop.VisitStatus.DONE )
			{
				if( h.getParent().size() > 1 || inputCount(h.getParent().get(0), h) > 1 ) {
					// NOTE(review): this ends the expansion of the whole chain, not
					// just of h; h and all later operands remain unexpanded leaves
					// and are not marked visited.
					break;
				}
				expandable = true;
			}

			// NOTE(review): leaf operands are marked DONE as well, so
			// rule_OptimizeMMChains will not descend into them from this root.
			h.setVisited(Hop.VisitStatus.DONE);

			if( !expandable ) {
				i++;
			}
			else {
				ArrayList<Hop> inputs = h.getInput();
				if( inputs.size() != 2 ) {
					throw new HopsException(hop.printErrorLocation() + "Hops::rule_OptimizeMMChain(): AggBinary must have exactly two inputs.");
				}
				// add current operator to mmOperators, and its input nodes to mmChain
				mmOperators.add(h);
				mmChain.set(i, inputs.get(0));
				mmChain.add(i + 1, inputs.get(1));
			}
		}

		// print the MMChain
		if( LOG.isTraceEnabled() ) {
			LOG.trace("Identified MM Chain: ");
			for( Hop h : mmChain ) {
				logTraceHop(h, 1);
			}
		}

		// A chain of two operands has only one possible ordering.
		if( mmChain.size() == 2 )
			return;

		// Step 2: construct dims array (n operands -> n+1 dimensions)
		int size = mmChain.size();
		double[] dimsArray = new double[size + 1];
		if( !getDimsArray(hop, mmChain, dimsArray) )
			return; // unknown dimensions: leave the chain unchanged

		// Step 3: clear the links among Hops within the identified chain
		clearLinksWithinChain(hop, mmOperators);

		// Step 4: Find the optimal ordering via dynamic programming.
		int[][] split = mmChainDP(dimsArray, size);

		// Step 5: Relink the hops using the optimal ordering (split[][]) found from DP.
		// The root keeps index 0; the remaining operators are handed out from index 1.
		if( LOG.isTraceEnabled() )
			LOG.trace("Optimal MM Chain: ");
		mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, 1, split, 1);
	}

	/**
	 * Core dynamic program for the matrix chain ordering problem on a given
	 * array of matrix dimensions. Operand {@code i} has dimensions
	 * {@code dimArray[i] x dimArray[i+1]}; the cost of a single multiply is the
	 * product of its three dimensions.
	 *
	 * Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein
	 * Introduction to Algorithms, Third Edition, MIT Press, page 395.
	 *
	 * @param dimArray dimension array of length {@code size + 1}
	 * @param size number of operands in the chain
	 * @return split table: {@code split[i][j]} is the index k such that the
	 *         sub-chain i..j is optimally computed as (i..k) x (k+1..j);
	 *         entries with {@code i >= j} are -1
	 */
	private int[][] mmChainDP(double[] dimArray, int size)
	{
		// min cost table; Java zero-initializes it, which is the cost of
		// chains of length 1
		double[][] dpMatrix = new double[size][size];
		// min cost index table; -1 marks "no split"
		int[][] split = new int[size][size];
		for( int i = 0; i < size; i++ )
			Arrays.fill(split[i], -1);

		//compute cost-optimal chains for increasing chain sizes
		for( int l = 2; l <= size; l++ ) { // chain length
			for( int i = 0; i < size - l + 1; i++ ) {
				int j = i + l - 1;
				// find cost of (i,j)
				dpMatrix[i][j] = Double.MAX_VALUE;
				for( int k = i; k <= j - 1; k++ ) {
					//recursive cost computation
					double cost = dpMatrix[i][k] + dpMatrix[k + 1][j]
						+ (dimArray[i] * dimArray[k + 1] * dimArray[j + 1]);

					//prune suboptimal
					if( cost < dpMatrix[i][j] ) {
						dpMatrix[i][j] = cost;
						split[i][j] = k;
					}
				}

				if( LOG.isTraceEnabled() ){
					LOG.trace("mmchainopt [i="+(i+1)+",j="+(j+1)+"]: costs = "+dpMatrix[i][j]+", split = "+(split[i][j]+1));
				}
			}
		}

		return split;
	}

	/**
	 * Relinks the hops of a chain according to the optimal ordering
	 * ({@code split}) found by {@link #mmChainDP}. mmChain holds the basic
	 * operands of the entire matrix multiplication chain; mmOperators holds the
	 * hops that compute the intermediate results. For example, for
	 * {@code A = B %*% (C %*% D)} there are three hops in mmChain (B,C,D) and
	 * two hops in mmOperators (one for each %*%).
	 *
	 * The next unused operator index is threaded through the recursion and
	 * returned. Passing it by value to both subtrees would let the right
	 * subtree reuse operators already consumed by the left one, for orderings
	 * such as ((AB)C)((DE)F).
	 *
	 * @param h operator hop covering operands i..j (a leaf operand if i == j)
	 * @param i first operand index of the sub-chain
	 * @param j last operand index of the sub-chain
	 * @param mmChain leaf operands of the chain
	 * @param mmOperators matrix-multiply hops available for relinking
	 * @param opIndex index of the next unused operator in mmOperators
	 * @param split split table from {@link #mmChainDP}
	 * @param level nesting level (used for trace indentation only)
	 * @return index of the next unused operator after relinking this sub-chain
	 */
	private int mmChainRelinkHops(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators,
		int opIndex, int[][] split, int level)
	{
		//single matrix - end of recursion
		if( i == j ) {
			logTraceHop(h, level);
			return opIndex;
		}

		if( LOG.isTraceEnabled() ){
			String offset = Explain.getIdentation(level);
			LOG.trace(offset + "(");
		}

		int k = split[i][j];

		// Input1: the leaf operand if the left sub-chain is a single matrix,
		// otherwise the next unused operator
		Hop left = (i == k) ? mmChain.get(i) : mmOperators.get(opIndex++);
		h.getInput().add(left);
		left.getParent().add(h);

		// Input2: analogous for the right sub-chain
		Hop right = (k + 1 == j) ? mmChain.get(j) : mmOperators.get(opIndex++);
		h.getInput().add(right);
		right.getParent().add(h);

		// Relink both sub-chains, continuing with the operators the left one did not use
		opIndex = mmChainRelinkHops(left, i, k, mmChain, mmOperators, opIndex, split, level+1);
		opIndex = mmChainRelinkHops(right, k + 1, j, mmChain, mmOperators, opIndex, split, level+1);

		// Propagate properties of input hops to current hop h
		h.refreshSizeInformation();

		if( LOG.isTraceEnabled() ){
			String offset = Explain.getIdentation(level);
			LOG.trace(offset + ")");
		}

		return opIndex;
	}

	/**
	 * Disconnects every operator of the chain from its inputs, removing both the
	 * input links and the corresponding parent back-links.
	 *
	 * @param hop chain root, used for error location reporting
	 * @param operators matrix-multiply hops of the chain; index 0 is the root
	 * @throws HopsException if an operator does not have exactly two inputs, or a
	 *         non-root operator has more than one parent
	 */
	private void clearLinksWithinChain ( Hop hop, ArrayList<Hop> operators )
		throws HopsException
	{
		for( int i = 0; i < operators.size(); i++ ) {
			Hop op = operators.get(i);
			// only the root may be shared; inner operators must be private to the chain
			if( op.getInput().size() != 2 || (i != 0 && op.getParent().size() > 1) ) {
				throw new HopsException(hop.printErrorLocation() + "Unexpected error while applying optimization on matrix-mult chain. \n");
			}
			Hop input1 = op.getInput().get(0);
			Hop input2 = op.getInput().get(1);

			op.getInput().clear();
			// remove(Object) drops a single occurrence, so a self-product
			// (input1 == input2) correctly loses both back-links
			input1.getParent().remove(op);
			input2.getParent().remove(op);
		}
	}

	/**
	 * Obtains all dimension information of the chain and fills the dims array
	 * ({@code dimsArray[0]} = rows of the first operand, {@code dimsArray[i+1]} =
	 * columns of operand i). If any dimension is unknown, it returns false and
	 * leaves the array untouched; the mmchain rewrite should then end without
	 * modifications.
	 *
	 * @param hop chain root, used for error location reporting
	 * @param chain leaf operands of the chain
	 * @param dimsArray output array of length {@code chain.size() + 1}
	 * @return true if all dimensions are known and the array was populated
	 * @throws HopsException if adjacent operands have mismatching inner dimensions
	 */
	private boolean getDimsArray( Hop hop, ArrayList<Hop> chain, double[] dimsArray )
		throws HopsException
	{
		// check the dimensions in the matrix chain to ensure all dimensions are known
		for( Hop h : chain ) {
			if( h.getDim1() <= 0 || h.getDim2() <= 0 )
				return false;
		}

		// populate the dims array (all dims are known and positive at this point)
		for( int i = 0; i < chain.size(); i++ ) {
			if( i == 0 ) {
				dimsArray[i] = chain.get(i).getDim1();
			}
			else if( chain.get(i - 1).getDim2() != chain.get(i).getDim1() ) {
				throw new HopsException(hop.printErrorLocation() +
					"Hops::optimizeMMChain() : Matrix Dimension Mismatch: "+chain.get(i - 1).getDim2()+" != "+chain.get(i).getDim1());
			}
			dimsArray[i + 1] = chain.get(i).getDim2();
		}
		return true;
	}

	/**
	 * Counts how many input slots of {@code p} refer to {@code h}.
	 *
	 * @param p parent hop
	 * @param h candidate input hop
	 * @return number of occurrences of {@code h} among the inputs of {@code p}
	 */
	private int inputCount ( Hop p, Hop h ) {
		int count = 0;
		for( Hop in : p.getInput() )
			if( in.equals(h) )
				count++;
		return count;
	}

	/**
	 * Logs a one-line description (name, class, id, dimensions) of a hop at
	 * trace level, indented by nesting level.
	 *
	 * @param hop hop to describe
	 * @param level indentation level
	 */
	private void logTraceHop( Hop hop, int level )
	{
		if( LOG.isTraceEnabled() ) {
			String offset = Explain.getIdentation(level);
			LOG.trace(offset+ "Hop " + hop.getName() + "(" + hop.getClass().getSimpleName() + ", " + hop.getHopID() + ")" + " "
				+ hop.getDim1() + "x" + hop.getDim2());
		}
	}
}