blob: eefd8d09ca421576d8781e285e63f783d46c9f17 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops.recompile;
import java.util.ArrayList;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.Hop.VisitStatus;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.utils.Statistics;
public class LiteralReplacement 
{
	//internal configuration parameters
	private static final long REPLACE_LITERALS_MAX_MATRIX_SIZE = 1000000; //10^6 cells (8MB)
	private static final boolean REPORT_LITERAL_REPLACE_OPS_STATS = true;

	/**
	 * Recursively replaces inputs of the given hop (and, transitively, of its
	 * descendants) by literals wherever one of the literal replacement rules
	 * applies against the current variable map. A replaced input is not
	 * descended into; every other input is processed recursively. Hops already
	 * marked {@link VisitStatus#DONE} are skipped, and every processed hop is
	 * marked DONE on exit.
	 * 
	 * @param hop  root of the (sub-)DAG to process
	 * @param vars live variables used to resolve scalar values and matrix contents
	 * @throws DMLRuntimeException if a replacement rule fails to evaluate
	 *         (e.g., casting a non-1x1 matrix to a scalar)
	 */
	protected static void rReplaceLiterals( Hop hop, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		if( hop.getVisited() == VisitStatus.DONE )
			return;
		
		if( hop.getInput() != null )
		{
			//indexed access to allow parent-child modifications
			for( int i=0; i<hop.getInput().size(); i++ )
			{
				Hop c = hop.getInput().get(i);
				LiteralOp lit = tryReplaceLiteral(c, vars);
				
				if( lit != null ) {
					//replace hop w/ literal on demand
					replaceHopByLiteral(hop, c, i, lit);
				}
				else {
					//recursively process children
					rReplaceLiterals(c, vars);
				}
			}
		}
		
		hop.setVisited(VisitStatus.DONE);
	}
	
	/**
	 * Applies the literal replacement rules in a fixed order and returns the
	 * result of the first rule that matches; later rules are not evaluated.
	 * 
	 * @param c    candidate hop
	 * @param vars live variables
	 * @return the replacement literal, or null if no rule applies
	 * @throws DMLRuntimeException if a matching rule fails to evaluate
	 */
	private static LiteralOp tryReplaceLiteral( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp lit = replaceLiteralScalarRead(c, vars);
		if( lit == null )
			lit = replaceLiteralValueTypeCastScalarRead(c, vars);
		if( lit == null )
			lit = replaceLiteralValueTypeCastLiteral(c, vars);
		if( lit == null )
			lit = replaceLiteralDataTypeCastMatrixRead(c, vars);
		if( lit == null )
			lit = replaceLiteralValueTypeCastRightIndexing(c, vars);
		if( lit == null )
			lit = replaceLiteralFullUnaryAggregate(c, vars);
		if( lit == null )
			lit = replaceLiteralFullUnaryAggregateRightIndexing(c, vars);
		return lit;
	}
	
	/**
	 * Rewires the DAG so that the literal takes the place of hop c.
	 * 
	 * If c has multiple parents, it is replaced in ALL of them. This prevents
	 * (1) missed opportunities, because c is marked as visited, and (2) repeated
	 * evaluation of unary aggregates. Otherwise only the reference of the
	 * current hop at the given input position is replaced.
	 * 
	 * @param hop current parent being processed
	 * @param c   hop to replace
	 * @param pos input position of c within hop
	 * @param lit replacement literal
	 * @throws DMLRuntimeException declared defensively for the rewrite utilities
	 */
	private static void replaceHopByLiteral( Hop hop, Hop c, int pos, Hop lit ) 
		throws DMLRuntimeException
	{
		if( c.getParent().size() > 1 ) { //multiple parents
			//copy, because the parent list of c is modified while iterating
			ArrayList<Hop> parents = new ArrayList<Hop>(c.getParent());
			for( Hop p : parents ) {
				int ppos = HopRewriteUtils.getChildReferencePos(p, c);
				HopRewriteUtils.removeChildReferenceByPos(p, c, ppos);
				HopRewriteUtils.addChildReference(p, lit, ppos);
			}
		}
		else { //current hop is only parent
			HopRewriteUtils.removeChildReferenceByPos(hop, c, pos);
			HopRewriteUtils.addChildReference(hop, lit, pos);
		}
	}
	
	///////////////////////////////
	// Literal replacement rules
	///////////////////////////////
	
	/**
	 * Rule: a non-persistent scalar read whose variable is bound to a scalar
	 * of value type INT, DOUBLE or BOOLEAN is replaced by a literal of that value.
	 * 
	 * @param c    candidate hop
	 * @param vars live variables
	 * @return the replacement literal, or null if the rule does not apply
	 */
	private static LiteralOp replaceLiteralScalarRead( Hop c, LocalVariableMap vars )
	{
		LiteralOp ret = null;
		
		//scalar read - literal replacement
		if( c instanceof DataOp && ((DataOp)c).getDataOpType() != DataOpTypes.PERSISTENTREAD 
			&& c.getDataType()==DataType.SCALAR )
		{
			Data dat = vars.get(c.getName());
			//null check required for selective constant propagation; the type check
			//skips the replacement (instead of a ClassCastException) for non-scalar bindings
			if( dat instanceof ScalarObject )
			{
				ScalarObject sdat = (ScalarObject)dat;
				switch( sdat.getValueType() ) {
					case INT:
						ret = new LiteralOp(sdat.getLongValue());
						break;
					case DOUBLE:
						ret = new LiteralOp(sdat.getDoubleValue());
						break;
					case BOOLEAN:
						ret = new LiteralOp(sdat.getBooleanValue());
						break;
					default:	
						//otherwise: do nothing
				}
			}
		}
		
		return ret;
	}
	
	/**
	 * Rule: as.double/as.integer/as.boolean over a scalar read is replaced by
	 * a literal holding the correspondingly converted value of the bound scalar.
	 * 
	 * @param c    candidate hop
	 * @param vars live variables
	 * @return the replacement literal, or null if the rule does not apply
	 */
	private static LiteralOp replaceLiteralValueTypeCastScalarRead( Hop c, LocalVariableMap vars )
	{
		LiteralOp ret = null;
		
		//as.double/as.integer/as.boolean over scalar read - literal replacement
		if( c instanceof UnaryOp && (((UnaryOp)c).getOp() == OpOp1.CAST_AS_DOUBLE 
			|| ((UnaryOp)c).getOp() == OpOp1.CAST_AS_INT || ((UnaryOp)c).getOp() == OpOp1.CAST_AS_BOOLEAN )	
				&& c.getInput().get(0) instanceof DataOp && c.getDataType()==DataType.SCALAR )
		{
			Data dat = vars.get(c.getInput().get(0).getName());
			//null check required for selective constant propagation; the type check
			//skips the replacement (instead of a ClassCastException) for non-scalar bindings
			if( dat instanceof ScalarObject )
			{
				ScalarObject sdat = (ScalarObject)dat;
				UnaryOp cast = (UnaryOp) c;
				switch( cast.getOp() ) {
					case CAST_AS_INT:
						ret = new LiteralOp(sdat.getLongValue());		
						break;
					case CAST_AS_DOUBLE:
						ret = new LiteralOp(sdat.getDoubleValue());		
						break;						
					case CAST_AS_BOOLEAN:
						ret = new LiteralOp(sdat.getBooleanValue());		
						break;
					default:	
						//otherwise: do nothing
				}
			}	
		}
		
		return ret;
	}
	
	/**
	 * Rule: as.double/as.integer/as.boolean over a scalar literal (potentially
	 * created by another replacement rule in the same DAG) is replaced by the
	 * converted literal.
	 * 
	 * @param c    candidate hop
	 * @param vars live variables (unused by this rule)
	 * @return the replacement literal, or null if the rule does not apply
	 * @throws DMLRuntimeException if the literal cannot be converted
	 */
	private static LiteralOp replaceLiteralValueTypeCastLiteral( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp ret = null;
		
		//as.double/as.integer/as.boolean over scalar literal (potentially created by other replacement 
		//rewrite in same dag) - literal replacement
		if( c instanceof UnaryOp && (((UnaryOp)c).getOp() == OpOp1.CAST_AS_DOUBLE 
			|| ((UnaryOp)c).getOp() == OpOp1.CAST_AS_INT || ((UnaryOp)c).getOp() == OpOp1.CAST_AS_BOOLEAN )	
				&& c.getInput().get(0) instanceof LiteralOp )
		{
			LiteralOp sdat = (LiteralOp)c.getInput().get(0);
			UnaryOp cast = (UnaryOp) c;
			try
			{
				switch( cast.getOp() ) {
					case CAST_AS_INT:
						long ival = HopRewriteUtils.getIntValue(sdat);
						ret = new LiteralOp(ival);		
						break;
					case CAST_AS_DOUBLE:
						double dval = HopRewriteUtils.getDoubleValue(sdat);
						ret = new LiteralOp(dval);		
						break;						
					case CAST_AS_BOOLEAN:
						boolean bval = HopRewriteUtils.getBooleanValue(sdat);
						ret = new LiteralOp(bval);		
						break;
					default:	
						//otherwise: do nothing
				}
			}
			catch(HopsException ex) {
				throw new DMLRuntimeException("Failed to cast literal for literal replacement", ex);
			}
		}
		
		return ret;
	}
	
	/**
	 * Rule: as.scalar over a matrix read is replaced by a (double) literal of
	 * the single cell of the bound matrix.
	 * 
	 * @param c    candidate hop
	 * @param vars live variables
	 * @return the replacement literal, or null if the rule does not apply
	 * @throws DMLRuntimeException if the bound matrix is not 1x1
	 */
	private static LiteralOp replaceLiteralDataTypeCastMatrixRead( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp ret = null;
		
		//as.scalar/matrix read - literal replacement
		if( c instanceof UnaryOp && ((UnaryOp)c).getOp() == OpOp1.CAST_AS_SCALAR 
			&& c.getInput().get(0) instanceof DataOp
			&& c.getInput().get(0).getDataType() == DataType.MATRIX )
		{
			Data dat = vars.get(c.getInput().get(0).getName());
			if( dat instanceof MatrixObject ) //null check required for selective constant propagation
			{
				//cast as scalar (see VariableCPInstruction)
				MatrixObject mo = (MatrixObject)dat;
				MatrixBlock mBlock = mo.acquireRead();
				double value;
				try {
					if( mBlock.getNumRows()!=1 || mBlock.getNumColumns()!=1 )
						throw new DMLRuntimeException("Dimension mismatch - unable to cast matrix of dimension ("+mBlock.getNumRows()+" x "+mBlock.getNumColumns()+") to scalar.");
					value = mBlock.getValue(0,0);
				}
				finally {
					//always release the read lock, also if the cast fails
					mo.release();
				}
				
				//literal substitution (always double)
				ret = new LiteralOp(value);
			}
		}
		
		return ret;
	}
	
	/**
	 * Rule: as.scalar over a 1x1 right indexing of a small bound matrix, with
	 * literal/variable/nrow/ncol index expressions, is replaced by a (double)
	 * literal of the indexed cell. Indexes outside the matrix are not replaced.
	 * 
	 * @param c    candidate hop
	 * @param vars live variables
	 * @return the replacement literal, or null if the rule does not apply
	 * @throws DMLRuntimeException if an index expression cannot be evaluated
	 */
	private static LiteralOp replaceLiteralValueTypeCastRightIndexing( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp ret = null;
		
		//as.scalar/right indexing w/ literals/vars and matrix less than 10^6 cells
		if( c instanceof UnaryOp && ((UnaryOp)c).getOp() == OpOp1.CAST_AS_SCALAR 
			&& c.getInput().get(0) instanceof IndexingOp
			&& c.getInput().get(0).getDataType() == DataType.MATRIX)
		{
			IndexingOp rix = (IndexingOp)c.getInput().get(0);
			Hop data = rix.getInput().get(0);
			Hop rl = rix.getInput().get(1);
			Hop ru = rix.getInput().get(2);
			Hop cl = rix.getInput().get(3);
			Hop cu = rix.getInput().get(4);
			if(   rix.dimsKnown() && rix.getDim1()==1 && rix.getDim2()==1
				&& data instanceof DataOp && vars.get(data.getName()) instanceof MatrixObject
				&& isIntValueDataLiteral(rl, vars) && isIntValueDataLiteral(ru, vars) 
				&& isIntValueDataLiteral(cl, vars) && isIntValueDataLiteral(cu, vars) )
			{
				long rlval = getIntValueDataLiteral(rl, vars);
				long clval = getIntValueDataLiteral(cl, vars);
				MatrixObject mo = (MatrixObject)vars.get(data.getName());
				
				//get the dimension information from the matrix object because the hop
				//dimensions might not have been updated during recompile; out-of-range
				//indexes are left to regular runtime indexing instead of being passed
				//unchecked to MatrixBlock.getValue
				if( isSmallMatrix(mo) 
					&& isValidRange(rlval, rlval, mo.getNumRows())
					&& isValidRange(clval, clval, mo.getNumColumns()) )
				{
					MatrixBlock mBlock = mo.acquireRead();
					double value;
					try {
						//1-based DML indexes to 0-based block indexes
						value = mBlock.getValue((int)rlval-1,(int)clval-1);
					}
					finally {
						mo.release();
					}
					
					//literal substitution (always double)
					ret = new LiteralOp(value);
				}
			}
		}
		
		return ret;
	}
	
	/**
	 * Rule: a full (RowCol) sum/sumSq/min/max aggregate over a small bound
	 * matrix is replaced by a (double) literal of the aggregate value.
	 * 
	 * @param c    candidate hop
	 * @param vars live variables
	 * @return the replacement literal, or null if the rule does not apply
	 * @throws DMLRuntimeException if the aggregate cannot be computed
	 */
	private static LiteralOp replaceLiteralFullUnaryAggregate( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp ret = null;
		
		//full unary aggregate w/ matrix less than 10^6 cells
		if( c instanceof AggUnaryOp 
			&& isReplaceableUnaryAggregate((AggUnaryOp)c)
			&& c.getInput().get(0) instanceof DataOp
			&& vars.get(c.getInput().get(0).getName()) instanceof MatrixObject )
		{
			Hop data = c.getInput().get(0);
			MatrixObject mo = (MatrixObject) vars.get(data.getName());
				
			//get the dimension information from the matrix object because the hop
			//dimensions might not have been updated during recompile
			if( isSmallMatrix(mo) )
			{
				MatrixBlock mBlock = mo.acquireRead();
				double value;
				try {
					value = replaceUnaryAggregate((AggUnaryOp)c, mBlock);
				}
				finally {
					mo.release();
				}
					
				//literal substitution (always double)
				ret = new LiteralOp(value);
			}
		}
		
		return ret;
	}
	
	/**
	 * Rule: a full (RowCol) sum/sumSq/min/max aggregate over a right indexing
	 * of a small bound matrix, with literal/variable/nrow/ncol index
	 * expressions, is replaced by a (double) literal of the aggregate over the
	 * indexed slice. Index ranges outside the matrix are not replaced.
	 * 
	 * @param c    candidate hop
	 * @param vars live variables
	 * @return the replacement literal, or null if the rule does not apply
	 * @throws DMLRuntimeException if an index expression or the aggregate cannot be evaluated
	 */
	private static LiteralOp replaceLiteralFullUnaryAggregateRightIndexing( Hop c, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		LiteralOp ret = null;
		
		//full unary aggregate w/ indexed matrix less than 10^6 cells
		if( c instanceof AggUnaryOp 
			&& isReplaceableUnaryAggregate((AggUnaryOp)c)
			&& c.getInput().get(0) instanceof IndexingOp )
		{
			IndexingOp rix = (IndexingOp)c.getInput().get(0);
			Hop data = rix.getInput().get(0);
			Hop rl = rix.getInput().get(1);
			Hop ru = rix.getInput().get(2);
			Hop cl = rix.getInput().get(3);
			Hop cu = rix.getInput().get(4);
			
			if(   data instanceof DataOp && vars.get(data.getName()) instanceof MatrixObject
				&& isIntValueDataLiteral(rl, vars) && isIntValueDataLiteral(ru, vars) 
				&& isIntValueDataLiteral(cl, vars) && isIntValueDataLiteral(cu, vars) )
			{
				long rlval = getIntValueDataLiteral(rl, vars);
				long ruval = getIntValueDataLiteral(ru, vars);
				long clval = getIntValueDataLiteral(cl, vars);
				long cuval = getIntValueDataLiteral(cu, vars);

				MatrixObject mo = (MatrixObject) vars.get(data.getName());
				
				//get the dimension information from the matrix object because the hop
				//dimensions might not have been updated during recompile; invalid index
				//ranges are left to regular runtime indexing
				if( isSmallMatrix(mo)
					&& isValidRange(rlval, ruval, mo.getNumRows())
					&& isValidRange(clval, cuval, mo.getNumColumns()) )
				{
					MatrixBlock mBlock = mo.acquireRead();
					double value;
					try {
						//1-based inclusive DML ranges to 0-based block ranges
						MatrixBlock mBlock2 = mBlock.sliceOperations((int)(rlval-1), (int)(ruval-1), (int)(clval-1), (int)(cuval-1), new MatrixBlock());
						value = replaceUnaryAggregate((AggUnaryOp)c, mBlock2);
					}
					finally {
						mo.release();
					}
						
					//literal substitution (always double)
					ret = new LiteralOp(value);
				}
			}
		}
		
		return ret;
	}
	
	///////////////////////////////
	// Utility functions
	///////////////////////////////

	/**
	 * Checks if the matrix is small enough to be read for literal replacement.
	 * Both dimensions must be non-negative and the cell count must be below
	 * REPLACE_LITERALS_MAX_MATRIX_SIZE. The non-negativity check matters
	 * because a negative (presumably unknown, e.g. -1 — TODO confirm) dimension
	 * would otherwise produce a negative cell count that trivially passes the
	 * threshold.
	 * 
	 * @param mo matrix object
	 * @return true if the matrix may be read for literal replacement
	 */
	private static boolean isSmallMatrix( MatrixObject mo )
	{
		long rows = mo.getNumRows();
		long cols = mo.getNumColumns();
		return rows >= 0 && cols >= 0 
			&& rows * cols < REPLACE_LITERALS_MAX_MATRIX_SIZE;
	}
	
	/**
	 * Checks if the 1-based inclusive index range [lower, upper] lies within
	 * a dimension of the given size.
	 * 
	 * @param lower 1-based lower index
	 * @param upper 1-based upper index
	 * @param dim   dimension size
	 * @return true if 1 &lt;= lower &lt;= upper &lt;= dim
	 */
	private static boolean isValidRange( long lower, long upper, long dim )
	{
		return 1 <= lower && lower <= upper && upper <= dim;
	}
	
	/**
	 * Checks if the hop's integer value can be obtained without execution.
	 * That is the case for a literal, a data hop bound to a scalar variable,
	 * or nrow/ncol over a data hop bound to a matrix variable.
	 * 
	 * @param h    hop to check
	 * @param vars live variables
	 * @return true if {@link #getIntValueDataLiteral} can evaluate the hop
	 */
	private static boolean isIntValueDataLiteral( Hop h, LocalVariableMap vars )
	{
		return ( (h instanceof DataOp && vars.get(h.getName()) instanceof ScalarObject) 
				|| h instanceof LiteralOp
				|| (h instanceof UnaryOp && (((UnaryOp)h).getOp()==OpOp1.NROW || ((UnaryOp)h).getOp()==OpOp1.NCOL)
				    && h.getInput().get(0) instanceof DataOp 
				    && vars.get(h.getInput().get(0).getName()) instanceof MatrixObject) );
	}
	
	/**
	 * Evaluates the integer value of a hop accepted by {@link #isIntValueDataLiteral}.
	 * 
	 * @param hop  literal, scalar data hop, or nrow/ncol over a matrix data hop
	 * @param vars live variables
	 * @return the integer value
	 * @throws DMLRuntimeException if a literal cannot be interpreted as an integer
	 */
	private static long getIntValueDataLiteral( Hop hop, LocalVariableMap vars ) 
		throws DMLRuntimeException
	{
		long value = -1;
		
		try 
		{
			if( hop instanceof LiteralOp ) 
			{
				value = HopRewriteUtils.getIntValue((LiteralOp)hop);
			}
			else if( hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.NROW )
			{
				//get the dimension information from the matrix object because the hop
				//dimensions might not have been updated during recompile
				MatrixObject mo = (MatrixObject)vars.get(hop.getInput().get(0).getName());
				value = mo.getNumRows();
			}
			else if( hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.NCOL )
			{
				//get the dimension information from the matrix object because the hop
				//dimensions might not have been updated during recompile
				MatrixObject mo = (MatrixObject)vars.get(hop.getInput().get(0).getName());
				value = mo.getNumColumns();
			}
			else
			{
				ScalarObject sdat = (ScalarObject) vars.get(hop.getName());
				value = sdat.getLongValue();
			}
		}
		catch(HopsException ex)
		{
			throw new DMLRuntimeException("Failed to get int value for literal replacement", ex);
		}
		
		return value;
	}
	
	/**
	 * Checks if the unary aggregate is a full (RowCol) aggregate with an
	 * operation supported by {@link #replaceUnaryAggregate}: sum, sumSq, min
	 * or max.
	 * 
	 * @param auop aggregate unary hop
	 * @return true if the aggregate can be replaced by a literal
	 */
	private static boolean isReplaceableUnaryAggregate( AggUnaryOp auop )
	{
		boolean cdir = (auop.getDirection() == Direction.RowCol);
		boolean cop = (   auop.getOp() == AggOp.SUM
				       || auop.getOp() == AggOp.SUM_SQ
				       || auop.getOp() == AggOp.MIN
				       || auop.getOp() == AggOp.MAX );
		return cdir && cop;
	}
	
	/**
	 * Computes the full aggregate of the given matrix block for the aggregate's
	 * operation. If DMLScript.STATISTICS is enabled, the elapsed time is
	 * reported as heavy hitter "rlit".
	 * 
	 * @param auop aggregate unary hop (sum, sumSq, min or max)
	 * @param mb   input matrix block
	 * @return the aggregate value
	 * @throws DMLRuntimeException if the operation is unsupported
	 */
	private static double replaceUnaryAggregate( AggUnaryOp auop, MatrixBlock mb ) 
		throws DMLRuntimeException
	{
		//setup stats reporting if necessary
		boolean reportStats = (DMLScript.STATISTICS && REPORT_LITERAL_REPLACE_OPS_STATS); 
		long t0 = reportStats ? System.nanoTime() : 0;
		
		//compute required unary aggregate 
		double val;
		switch( auop.getOp() ) {
			case SUM: 
				val = mb.sum(); 
				break;
			case SUM_SQ:
				val = mb.sumSq();
				break;
			case MIN:
				val = mb.min(); 
				break;
			case MAX: 
				val = mb.max(); 
				break;
			default:
				throw new DMLRuntimeException("Unsupported unary aggregate replacement: "+auop.getOp());
		}

		//report statistics if necessary
		if( reportStats ){
			long t1 = System.nanoTime();
			Statistics.maintainCPHeavyHitters("rlit", t1-t0);
		}
		
		return val;
	}
}