blob: 145b76c088bc9394b697251c2a88697d234a6b79 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.utils.Explain;
/**
 * Rule: Determine the optimal order of execution for a chain of
 * matrix multiplications.
 *
 * Solution: Classic dynamic programming.
 * Approach: Currently, the ordering is based only on matrix dimensions.
 * Goal: To reduce the number of computations performed in the runtime layer.
 */
public class RewriteMatrixMultChainOptimization extends HopRewriteRule
{
	@Override
	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
	{
		if( roots == null )
			return null;
		// Find the optimal order for the chain whose result is the current HOP
		for( Hop h : roots )
			rule_OptimizeMMChains(h, state);
		return roots;
	}

	@Override
	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
	{
		if( root == null )
			return null;
		// Find the optimal order for the chain whose result is the current HOP
		rule_OptimizeMMChains(root, state);
		return root;
	}

	/**
	 * rule_OptimizeMMChains(): Recursively traverses all Hops reachable from the
	 * given hop and optimizes every matrix multiplication chain whose last
	 * operator is an unvisited matrix multiply without a left PM input.
	 * Each hop is processed at most once, guarded by its visit status.
	 *
	 * @param hop high-level operator
	 * @param state program rewrite status
	 */
	private void rule_OptimizeMMChains(Hop hop, ProgramRewriteStatus state)
	{
		if( hop.isVisited() )
			return;
		// visit status was already checked above
		if( HopRewriteUtils.isMatrixMultiply(hop)
			&& !((AggBinaryOp)hop).hasLeftPMInput() )
		{
			// Try to find and optimize the chain in which current Hop is the
			// last operator
			prepAndOptimizeMMChain(hop, state);
		}
		for( Hop hi : hop.getInput() )
			rule_OptimizeMMChains(hi, state);
		hop.setVisited();
	}

	/**
	 * prepAndOptimizeMMChain(): Identifies the matrix multiplication chain in
	 * which the given hop is the last operator and, if the chain has more than
	 * two operands, passes it to {@link #optimizeMMChain}.
	 *
	 * Step 1) Identify the chain (mmChain: operands, mmOperators: the matrix
	 * multiplications combining them, starting with the given hop).
	 * The remaining steps (dims array, clearing links, DP ordering, relinking)
	 * are performed by {@link #optimizeMMChain}.
	 *
	 * @param hop high-level operator (last operator of the chain)
	 * @param state program rewrite status
	 */
	private void prepAndOptimizeMMChain( Hop hop, ProgramRewriteStatus state )
	{
		if( LOG.isTraceEnabled() ) {
			LOG.trace("MM Chain Optimization for HOP: (" + hop.getClass().getSimpleName()
				+ ", " + hop.getHopID() + ", " + hop.getName() + ")");
		}
		ArrayList<Hop> mmChain = new ArrayList<>();
		ArrayList<Hop> mmOperators = new ArrayList<>();
		ArrayList<Hop> tempList;
		// Step 1: Identify the chain (mmChain) & the operators (mmOperators)
		// that are involved in it.
		// Initialize mmChain with my inputs
		mmOperators.add(hop);
		for( Hop hi : hop.getInput() )
			mmChain.add(hi);
		// expand each Hop in mmChain to find the entire matrix multiplication
		// chain
		int i = 0;
		while( i < mmChain.size() )
		{
			boolean expandable = false;
			Hop h = mmChain.get(i);
			/*
			 * Check if mmChain[i] is expandable:
			 * 1) It must be MATMULT
			 * 2) It must not have a left PM input (same condition as for the chain root)
			 * 3) It must not have been visited already
			 *    (one MATMULT should get expanded only in one chain)
			 * 4) Its output should not be used in multiple places
			 *    (either within chain or outside the chain)
			 */
			if ( HopRewriteUtils.isMatrixMultiply(h)
				&& !((AggBinaryOp)h).hasLeftPMInput() && !h.isVisited() )
			{
				// check if the output of "h" is used at multiple places. If yes, it can
				// not be expanded.
				expandable = !(h.getParent().size() > 1
					|| inputCount(h.getParent().get(0), h) > 1);
				// stop growing the chain; h and all operands after it stay unvisited
				if( !expandable )
					break;
			}
			// NOTE(review): non-expanded operands are marked visited as well, so
			// rule_OptimizeMMChains returns early for them and does not descend into
			// their inputs in this pass - confirm this is intended.
			h.setVisited();
			if( !expandable ) {
				i = i + 1;
			}
			else {
				tempList = mmChain.get(i).getInput();
				if( tempList.size() != 2 ) {
					throw new HopsException(hop.printErrorLocation() + "Hops::rule_OptimizeMMChain(): AggBinary must have exactly two inputs.");
				}
				// add current operator to mmOperators, and replace it in mmChain by its inputs
				mmOperators.add(mmChain.get(i));
				mmChain.set(i, tempList.get(0));
				mmChain.add(i + 1, tempList.get(1));
			}
		}
		// print the MMChain
		if( LOG.isTraceEnabled() ) {
			LOG.trace("Identified MM Chain: ");
			for( Hop h : mmChain ) {
				logTraceHop(h, 1);
			}
		}
		// a chain of two operands has only one possible order
		if( mmChain.size() == 2 )
			return;
		//core mmchain optimization (potentially overridden)
		optimizeMMChain(hop, mmChain, mmOperators, state);
	}

	/**
	 * Reorders the identified chain if all operand dimensions are known;
	 * otherwise the chain is left untouched. Subclasses may override this to
	 * use a different cost model.
	 *
	 * @param hop last operator of the chain (mmOperators.get(0))
	 * @param mmChain operands of the chain, in order
	 * @param mmOperators matrix multiplications of the chain, starting with hop
	 * @param state program rewrite status
	 */
	protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators, ProgramRewriteStatus state) {
		// Step 2: construct dims array
		double[] dimsArray = new double[mmChain.size() + 1];
		boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray );
		if( dimsKnown ) {
			// Step 3: clear the links among Hops within the identified chain
			clearLinksWithinChain( hop, mmOperators );
			// Step 4: Find the optimal ordering via dynamic programming.
			int size = mmChain.size();
			int[][] split = mmChainDP(dimsArray, size);
			// Step 5: Relink the hops using the optimal ordering (split[][]) found from DP.
			if( LOG.isTraceEnabled() )
				LOG.trace("Optimal MM Chain: ");
			mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, 1, split, 1);
		}
	}

	/**
	 * mmChainDP(): Core method to perform dynamic programming on a given array
	 * of matrix dimensions. Operand i has dimensions dimArray[i] x dimArray[i+1];
	 * the cost of multiplying two sub-chains is the product of the three
	 * involved dimensions.
	 *
	 * Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein
	 * Introduction to Algorithms, Third Edition, MIT Press, page 395.
	 *
	 * @param dimArray dimension array of length size+1
	 * @param size number of operands in the chain
	 * @return split table; split[i][j] is the optimal split point k of (i..j),
	 *         i.e., (i..k) x (k+1..j), or -1 for single operands
	 */
	private static int[][] mmChainDP(double[] dimArray, int size)
	{
		// min cost table; zero-initialized, i.e., chains of length 1 cost nothing
		double[][] dpMatrix = new double[size][size];
		int[][] split = new int[size][size]; //min cost index table
		for( int[] row : split )
			Arrays.fill(row, -1);
		//compute cost-optimal chains for increasing chain sizes
		for( int l = 2; l <= size; l++ ) { // chain length
			for( int i = 0; i < size - l + 1; i++ ) {
				int j = i + l - 1;
				// find cost of (i,j)
				dpMatrix[i][j] = Double.MAX_VALUE;
				for( int k = i; k <= j - 1; k++ )
				{
					//recursive cost computation
					double cost = dpMatrix[i][k] + dpMatrix[k + 1][j]
						+ (dimArray[i] * dimArray[k + 1] * dimArray[j + 1]);
					//prune suboptimal
					if( cost < dpMatrix[i][j] ) {
						dpMatrix[i][j] = cost;
						split[i][j] = k;
					}
				}
				if( LOG.isTraceEnabled() ){
					LOG.trace("mmchainopt [i="+(i+1)+",j="+(j+1)+"]: costs = "+dpMatrix[i][j]+", split = "+(split[i][j]+1));
				}
			}
		}
		return split;
	}

	/**
	 * mmChainRelinkHops(): This method gets invoked after finding the optimal
	 * order (split[][]) from dynamic programming. It relinks the Hops that are
	 * part of the mmChain.
	 * @param mmChain : basic operands in the entire matrix multiplication chain.
	 * @param mmOperators : Hops that store the intermediate results in the chain.
	 *                      For example: A = B %*% (C %*% D) there will be three
	 *                      Hops in mmChain (B,C,D), and two Hops in mmOperators
	 *                      (one for each %*%).
	 * @param h high level operator to relink (covers operands i..j)
	 * @param i array index i (first operand covered by h)
	 * @param j array index j (last operand covered by h)
	 * @param opIndex index of the next unused operator in mmOperators
	 * @param split optimal order
	 * @param level indentation level for trace logging
	 */
	protected final void mmChainRelinkHops(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators,
		int opIndex, int[][] split, int level)
	{
		relinkHops(h, i, j, mmChain, mmOperators, opIndex, split, level);
	}

	/**
	 * Recursive worker of mmChainRelinkHops. Returns the index of the next
	 * unused operator in mmOperators, so that the right subtree continues
	 * where the left subtree stopped and no operator is linked twice.
	 */
	private static int relinkHops(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators,
		int opIndex, int[][] split, int level)
	{
		//single matrix - end of recursion
		if( i == j ) {
			logTraceHop(h, level);
			return opIndex;
		}
		if( LOG.isTraceEnabled() ){
			String offset = Explain.getIdentation(level);
			LOG.trace(offset + "(");
		}
		int k = split[i][j];
		int next = opIndex;
		// Set Input1 for current Hop h
		if( i == k ) {
			h.getInput().add(mmChain.get(i));
			mmChain.get(i).getParent().add(h);
		}
		else {
			h.getInput().add(mmOperators.get(next));
			mmOperators.get(next).getParent().add(h);
			next++;
		}
		// Set Input2 for current Hop h
		if( k + 1 == j ) {
			h.getInput().add(mmChain.get(j));
			mmChain.get(j).getParent().add(h);
		}
		else {
			h.getInput().add(mmOperators.get(next));
			mmOperators.get(next).getParent().add(h);
			next++;
		}
		// Find children for both the inputs; thread the operator index through
		// both recursions so sibling subtrees consume disjoint operators
		next = relinkHops(h.getInput().get(0), i, k, mmChain, mmOperators, next, split, level+1);
		next = relinkHops(h.getInput().get(1), k + 1, j, mmChain, mmOperators, next, split, level+1);
		// Propagate properties of input hops to current hop h
		h.refreshSizeInformation();
		if( LOG.isTraceEnabled() ){
			String offset = Explain.getIdentation(level);
			LOG.trace(offset + ")");
		}
		return next;
	}

	/**
	 * Removes all input links of the given chain operators (and the matching
	 * parent links of their inputs). Only the first operator (the chain root)
	 * may have more than one parent.
	 *
	 * @param hop chain root, used for error locations
	 * @param operators matrix multiplications of the chain
	 */
	protected static void clearLinksWithinChain( Hop hop, ArrayList<Hop> operators )
	{
		for( int i=0; i < operators.size(); i++ ) {
			Hop op = operators.get(i);
			if( op.getInput().size() != 2 || (i != 0 && op.getParent().size() > 1 ) ) {
				throw new HopsException(hop.printErrorLocation() +
					"Unexpected error while applying optimization on matrix-mult chain. \n");
			}
			Hop input1 = op.getInput().get(0);
			Hop input2 = op.getInput().get(1);
			op.getInput().clear();
			input1.getParent().remove(op);
			input2.getParent().remove(op);
		}
	}

	/**
	 * Obtains all dimension information of the chain and constructs the dimArray.
	 * If all dimensions are known it returns true; otherwise the mmchain rewrite
	 * should be ended without modifications.
	 *
	 * @param hop high-level operator
	 * @param chain list of high-level operators
	 * @param dimsArray dimension array (length chain.size()+1), filled only if all dims are known
	 * @return true if all dimensions known
	 * @throws HopsException if adjacent operands have mismatching dimensions
	 */
	protected static boolean getDimsArray( Hop hop, ArrayList<Hop> chain, double[] dimsArray )
	{
		// check the dimensions in the matrix chain to ensure all dimensions are known
		for( Hop h : chain )
			if( h.getDim1() <= 0 || h.getDim2() <= 0 )
				return false;
		// populate dims array; all dims are known to be positive at this point
		for( int i = 0; i < chain.size(); i++ ) {
			Hop cur = chain.get(i);
			if( i == 0 ) {
				dimsArray[i] = cur.getDim1();
			}
			else if( chain.get(i - 1).getDim2() != cur.getDim1() ) {
				throw new HopsException(hop.printErrorLocation() +
					"Hops::optimizeMMChain() : Matrix Dimension Mismatch: " +
					chain.get(i - 1).getDim2()+" != "+cur.getDim1());
			}
			dimsArray[i + 1] = cur.getDim2();
		}
		return true;
	}

	/** Returns how often h appears among the inputs of p. */
	private static int inputCount( Hop p, Hop h ) {
		return CollectionUtils.cardinality(h, p.getInput());
	}

	/** Logs name, type, id, and dimensions of the given hop at trace level. */
	private static void logTraceHop( Hop hop, int level ) {
		if( LOG.isTraceEnabled() ) {
			String offset = Explain.getIdentation(level);
			LOG.trace(offset+ "Hop " + hop.getName() + "(" + hop.getClass().getSimpleName()
				+ ", " + hop.getHopID() + ")" + " " + hop.getDim1() + "x" + hop.getDim2());
		}
	}
}