// blob: 7a31a97b5c41a61dc428a7497f73937a971f49dd [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.AggBinaryOp.SparkAggType;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.Aggregate;
import org.apache.sysml.lops.Data;
import org.apache.sysml.lops.Group;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.RangeBasedReIndex;
import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
//NOTE: for now, this hop only supports range-based indexing operations
public class IndexingOp extends Hop
{
//operation string of this hop, as reported by getOpString()
public static String OPSTRING = "rix"; //"Indexing";
//flag passed to the constructor as 'passedRowsLEU'; if set, the output
//is treated as having a single row (see refreshSizeInformation)
private boolean _rowLowerEqualsUpper = false;
//flag passed to the constructor as 'passedColsLEU'; if set, the output
//is treated as having a single column (see refreshSizeInformation)
private boolean _colLowerEqualsUpper = false;
/**
 * Physical indexing methods. constructLops only distinguishes MR_RIX
 * (reindex followed by a sort-group and a sum-aggregate in MR) from
 * MR_VRIX (reindex lop only); CP_RIX is not referenced in this class.
 */
private enum IndexingMethod {
CP_RIX, //in-memory range index
MR_RIX, //general case range reindex
MR_VRIX, //vector (row/col) range index
};
/**
 * No-argument constructor; used by {@link #clone()} to create an empty
 * instance before attributes are copied into it.
 */
private IndexingOp()
{
	//intentionally empty (default constructor for clone)
}
//right indexing doesn't really need the dimensionality of the left matrix
//private static Lops dummy=new Data(null, Data.OperationTypes.READ, null, "-1", DataType.SCALAR, ValueType.INT, false);
public IndexingOp(String l, DataType dt, ValueType vt, Hop inpMatrix, Hop inpRowL, Hop inpRowU, Hop inpColL, Hop inpColU, boolean passedRowsLEU, boolean passedColsLEU) {
super(l, dt, vt);
getInput().add(0, inpMatrix);
getInput().add(1, inpRowL);
getInput().add(2, inpRowU);
getInput().add(3, inpColL);
getInput().add(4, inpColU);
// create hops if one of them is null
inpMatrix.getParent().add(this);
inpRowL.getParent().add(this);
inpRowU.getParent().add(this);
inpColL.getParent().add(this);
inpColU.getParent().add(this);
// set information whether left indexing operation involves row (n x 1) or column (1 x m) matrix
setRowLowerEqualsUpper(passedRowsLEU);
setColLowerEqualsUpper(passedColsLEU);
}
/**
 * @return the current value of the "row lower equals upper" flag
 */
public boolean getRowLowerEqualsUpper()
{
	return this._rowLowerEqualsUpper;
}
/**
 * @return the current value of the "column lower equals upper" flag
 */
public boolean getColLowerEqualsUpper()
{
	return this._colLowerEqualsUpper;
}
/**
 * Sets the "row lower equals upper" flag.
 *
 * @param passed new value of the flag
 */
public void setRowLowerEqualsUpper(boolean passed)
{
	this._rowLowerEqualsUpper = passed;
}
/**
 * Sets the "column lower equals upper" flag.
 *
 * @param passed new value of the flag
 */
public void setColLowerEqualsUpper(boolean passed)
{
	this._colLowerEqualsUpper = passed;
}
/**
 * Returns the lops of this hop, constructing them on the first call.
 * If the output has the same known dimensions as the input (and is not
 * 1x1), the lops of the input are reused directly; otherwise a
 * range-based reindex lop is created for the selected execution type.
 *
 * @return the lops set for this hop
 * @throws HopsException if the lop construction fails
 * @throws LopsException declared by the overridden method
 */
@Override
public Lop constructLops()
	throws HopsException, LopsException
{
	//reuse lops created by an earlier call
	if( getLops() != null )
		return getLops();

	final Hop data = getInput().get(0);

	//rewrite: right indexing that preserves the known input dimensions
	//is unnecessary (the 1x1 case is excluded from this rewrite)
	final boolean keepsAllCells = dimsKnown() && data.dimsKnown()
		&& getDim1() == data.getDim1() && getDim2() == data.getDim2()
		&& !(getDim1()==1 && getDim2()==1);

	if( keepsAllCells ) {
		setLops( data.constructLops() );
	}
	else {
		//actual lop construction, incl operator selection
		try {
			final ExecType et = optFindExecType();

			if( et == ExecType.MR )
			{
				final IndexingMethod method = optFindIndexingMethod(
					_rowLowerEqualsUpper, _colLowerEqualsUpper,
					data._dim1, data._dim2, _dim1, _dim2);

				final Lop placeholder = Data.createLiteralLop(ValueType.INT, Integer.toString(-1));
				final RangeBasedReIndex rix = new RangeBasedReIndex(
					data.constructLops(),
					getInput().get(1).constructLops(), getInput().get(2).constructLops(),
					getInput().get(3).constructLops(), getInput().get(4).constructLops(),
					placeholder, placeholder,
					getDataType(), getValueType(), et);
				setOutputDimensions(rix);
				setLineNumbers(rix);

				if( method != IndexingMethod.MR_RIX ) {
					//method == IndexingMethod.MR_VRIX: reindex lop only
					setLops(rix);
				}
				else {
					//general case: sort-group and sum-aggregate after the reindex
					final Group grp = new Group( rix, Group.OperationTypes.Sort,
						DataType.MATRIX, getValueType());
					setOutputDimensions(grp);
					setLineNumbers(grp);

					final Aggregate agg = new Aggregate(
						grp, Aggregate.OperationTypes.Sum, DataType.MATRIX,
						getValueType(), et);
					setOutputDimensions(agg);
					setLineNumbers(agg);

					setLops(agg);
				}
			}
			else if( et == ExecType.SPARK )
			{
				final IndexingMethod method = optFindIndexingMethod(
					_rowLowerEqualsUpper, _colLowerEqualsUpper,
					data._dim1, data._dim2, _dim1, _dim2);

				//no aggregation for vector indexing or block-aligned ranges
				final SparkAggType aggtype =
					(method==IndexingMethod.MR_VRIX || isBlockAligned()) ?
					SparkAggType.NONE : SparkAggType.MULTI_BLOCK;

				final Lop placeholder = Data.createLiteralLop(ValueType.INT, Integer.toString(-1));
				final RangeBasedReIndex rix = new RangeBasedReIndex(
					data.constructLops(),
					getInput().get(1).constructLops(), getInput().get(2).constructLops(),
					getInput().get(3).constructLops(), getInput().get(4).constructLops(),
					placeholder, placeholder,
					getDataType(), getValueType(), aggtype, et);
				setOutputDimensions(rix);
				setLineNumbers(rix);
				setLops(rix);
			}
			else //CP
			{
				final Lop placeholder = Data.createLiteralLop(ValueType.INT, Integer.toString(-1));
				final RangeBasedReIndex rix = new RangeBasedReIndex(
					data.constructLops(),
					getInput().get(1).constructLops(), getInput().get(2).constructLops(),
					getInput().get(3).constructLops(), getInput().get(4).constructLops(),
					placeholder, placeholder,
					getDataType(), getValueType(), et);
				setOutputDimensions(rix);
				setLineNumbers(rix);
				setLops(rix);
			}
		}
		catch (Exception e) {
			throw new HopsException(this.printErrorLocation() + "In IndexingOp Hop, error constructing Lops " , e);
		}
	}

	//add reblock/checkpoint lops if necessary
	constructAndSetLopsDataFlowProperties();

	return getLops();
}
/**
 * Returns the operation string of this hop.
 *
 * @return the value of {@link #OPSTRING}
 */
@Override
public String getOpString() {
	//return the constant directly; no need to allocate an empty
	//string and concatenate onto it
	return OPSTRING;
}
/**
 * Prints this hop (via the superclass) followed by all of its inputs,
 * unless this hop is already marked as visited; in either case the hop
 * is marked as visited afterwards.
 *
 * @throws HopsException declared by the overridden method and by the
 *         printMe calls on the inputs
 */
@Override
public void printMe() throws HopsException {
	if (getVisited() != VisitStatus.DONE) {
		super.printMe();
		for (Hop h : getInput()) {
			h.printMe();
		}
	}
	setVisited(VisitStatus.DONE);
}
/**
 * @return always true for this hop
 */
@Override
public boolean allowsAllExecTypes() {
	return true;
}
/**
 * Computes the memory estimates of this hop. After the default
 * computation of the superclass, the output estimate is tightened if the
 * output dimensions are known and the input statistics provide a number
 * of non-zeros, which serves as worst-case number of output non-zeros.
 *
 * @param memo memo table providing the input statistics
 */
@Override
public void computeMemEstimate( MemoTable memo )
{
	//default behavior
	super.computeMemEstimate(memo);

	//try to infer via worstcase input statistics (for the case of dims known
	//but nnz initially unknown); the statistics are null-checked in the same
	//way as in inferOutputCharacteristics
	MatrixCharacteristics mcM1 = memo.getAllInputStats(getInput().get(0));
	if( mcM1 != null && dimsKnown() && mcM1.getNonZeros()>=0 ) {
		long lnnz = mcM1.getNonZeros(); //worst-case output nnz
		double lOutMemEst = computeOutputMemEstimate( _dim1, _dim2, lnnz );
		//only ever lower the existing estimate, never raise it
		if( lOutMemEst<_outputMemEstimate ) {
			_outputMemEstimate = lOutMemEst;
			_memEstimate = getInputOutputSize();
		}
	}
}
/**
 * Estimates the output size from the given dimensions and number of
 * non-zeros, using the sparsity derived from these three values.
 *
 * @param dim1 number of rows
 * @param dim2 number of columns
 * @param nnz number of non-zeros
 * @return result of OptimizerUtils.estimateSizeExactSparsity
 */
@Override
protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
{
	return OptimizerUtils.estimateSizeExactSparsity(
		dim1, dim2, OptimizerUtils.getSparsity(dim1, dim2, nnz));
}
/**
 * @param dim1 number of rows (unused)
 * @param dim2 number of columns (unused)
 * @param nnz number of non-zeros (unused)
 * @return always 0, i.e., no intermediate memory is accounted for
 */
@Override
protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz ) {
	return 0;
}
/**
 * Infers worst-case output characteristics from the statistics of the
 * indexed matrix: starts from the input rows/columns/non-zeros and then
 * narrows rows and/or columns to 1 (single row/column flags) or to the
 * constant of a recognized block indexing expression.
 *
 * @param memo memo table providing the input statistics
 * @return array {rows, columns, nnz}, or null if no input statistics
 *         are available
 */
@Override
protected long[] inferOutputCharacteristics( MemoTable memo )
{
	final Hop data = getInput().get(0); //original matrix
	final MatrixCharacteristics stats = memo.getAllInputStats(data);
	if( stats == null )
		return null;

	//number of non-zeros, capped by the number of cells if dims are known
	long nnzBound = -1;
	if( stats.dimsKnown() )
		nnzBound = Math.min(stats.getRows()*stats.getCols(), stats.getNonZeros());

	//worst-case is input size, but dense
	final long[] out = new long[]{stats.getRows(), stats.getCols(), nnzBound};

	//exploit column/row indexing information
	if( _rowLowerEqualsUpper )
		out[0] = 1;
	if( _colLowerEqualsUpper )
		out[1] = 1;

	//infer tight block indexing size
	final Hop rowLo = getInput().get(1);
	final Hop rowHi = getInput().get(2);
	final Hop colLo = getInput().get(3);
	final Hop colHi = getInput().get(4);
	if( isBlockIndexingExpression(rowLo, rowHi) )
		out[0] = getBlockIndexingExpressionSize(rowLo, rowHi);
	if( isBlockIndexingExpression(colLo, colHi) )
		out[1] = getBlockIndexingExpressionSize(colLo, colHi);

	return out;
}
/**
 * Indicates if the lbound:ubound expression is of the form
 * "(c * (i - 1) + 1) : (c * i)", where we could use c as a tight size estimate.
 * Only this exact operand order is recognized: the literal c must be the
 * first operand of both multiplications and i must be a data operator
 * with the same name in both bounds.
 *
 * @param lbound hop of the lower bound expression
 * @param ubound hop of the upper bound expression
 * @return true if both bounds match the pattern with the same constant c
 */
private boolean isBlockIndexingExpression(Hop lbound, Hop ubound)
{
	boolean ret = false;
	LiteralOp constant = null;
	DataOp var = null;

	//handle lower bound: (c * (i - 1)) + 1
	if( lbound instanceof BinaryOp && ((BinaryOp)lbound).getOp()==OpOp2.PLUS
		&& lbound.getInput().get(1) instanceof LiteralOp
		&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)lbound.getInput().get(1))==1
		&& lbound.getInput().get(0) instanceof BinaryOp)
	{
		BinaryOp lmult = (BinaryOp)lbound.getInput().get(0);
		if( lmult.getOp()==OpOp2.MULT && lmult.getInput().get(0) instanceof LiteralOp
			&& lmult.getInput().get(1) instanceof BinaryOp )
		{
			BinaryOp lminus = (BinaryOp)lmult.getInput().get(1);
			if( lminus.getOp()==OpOp2.MINUS && lminus.getInput().get(1) instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)lminus.getInput().get(1))==1
				&& lminus.getInput().get(0) instanceof DataOp )
			{
				constant = (LiteralOp)lmult.getInput().get(0);
				var = (DataOp) lminus.getInput().get(0);
			}
		}
	}

	//handle upper bound: c * i; the operator has to be a multiplication,
	//otherwise expressions such as "c + i" would be mistaken for a block
	//of size c although the actual range size depends on i
	if( var != null && constant != null && ubound instanceof BinaryOp
		&& ((BinaryOp)ubound).getOp()==OpOp2.MULT
		&& ubound.getInput().get(0) instanceof LiteralOp
		&& ubound.getInput().get(1) instanceof DataOp
		&& ubound.getInput().get(1).getName().equals(var.getName()) )
	{
		LiteralOp constant2 = (LiteralOp)ubound.getInput().get(0);
		ret = ( HopRewriteUtils.getDoubleValueSafe(constant) ==
			HopRewriteUtils.getDoubleValueSafe(constant2) );
	}

	return ret;
}
/**
 * Indicates if the right indexing range is block aligned, i.e., it does not
 * require aggregation across blocks due to shifting. Bounds given as literals
 * are passed by value, all other bounds are passed as -1; the decision itself
 * is delegated to OptimizerUtils.isIndexingRangeBlockAligned.
 *
 * @return true if the indexing range is block aligned
 */
private boolean isBlockAligned() {
	final Hop data = getInput().get(0); //original matrix
	final Hop[] bounds = new Hop[] {
		getInput().get(1),   //inpRowL
		getInput().get(2),   //inpRowU
		getInput().get(3),   //inpColL
		getInput().get(4) }; //inpColU

	//extract literal bounds in the order rl, ru, cl, cu (-1 if not a literal)
	final long[] range = new long[]{ -1, -1, -1, -1 };
	for( int i = 0; i < bounds.length; i++ ) {
		if( bounds[i] instanceof LiteralOp )
			range[i] = HopRewriteUtils.getIntValueSafe((LiteralOp)bounds[i]);
	}

	final int brlen = (int)data.getRowsInBlock();
	final int bclen = (int)data.getColsInBlock();

	return OptimizerUtils.isIndexingRangeBlockAligned(
		range[0], range[1], range[2], range[3], brlen, bclen);
}
/**
 * Returns the constant c of an upper bound expression "(c * i)", i.e., the
 * integer value of the literal that is the first input of the upper bound.
 * Callers are expected to have checked the bounds with
 * isBlockIndexingExpression before, because the hop is cast unchecked.
 *
 * @param lbound hop of the lower bound expression (not used)
 * @param ubound hop of the upper bound expression
 * @return integer value of the first input literal of ubound
 */
private long getBlockIndexingExpressionSize(Hop lbound, Hop ubound)
{
	//NOTE: ensure consistency with isBlockIndexingExpression
	final Hop multiplier = ubound.getInput().get(0); //c of (c*i)
	return HopRewriteUtils.getIntValueSafe((LiteralOp) multiplier);
}
/**
 * Selects and stores the execution type of this hop. A forced execution
 * type takes precedence; otherwise the type is chosen by memory estimate
 * (memory-based optimization level) or by the dimension threshold of the
 * indexed matrix, followed by the check for valid CP dimensions and size.
 * The hop is marked for recompilation if dynamic recompilation is enabled,
 * the dimensions are not fully known, and the remote type was selected.
 *
 * @return the selected execution type
 * @throws HopsException declared by the overridden method
 */
@Override
protected ExecType optFindExecType() throws HopsException {

	checkAndSetForcedPlatform();

	//remote backend in effect for this compilation
	final ExecType remoteType = OptimizerUtils.isSparkExecutionMode() ?
		ExecType.SPARK : ExecType.MR;

	if( _etypeForced == null )
	{
		if( OptimizerUtils.isMemoryBasedOptLevel() )
			_etype = findExecTypeByMemEstimate();
		else
			_etype = getInput().get(0).areDimsBelowThreshold() ?
				ExecType.CP : remoteType;

		//check for valid CP dimensions and matrix size
		checkAndSetInvalidCPDimsAndSize();
	}
	else
	{
		_etype = _etypeForced;
	}

	//mark for recompile (forever)
	if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && _etype==remoteType ) {
		setRequiresRecompile();
	}

	return _etype;
}
/**
 * Selects the indexing method. Vector indexing (MR_VRIX) is chosen if a
 * single row is indexed and the known number of output columns equals the
 * number of input columns, or if a single column is indexed and the known
 * number of output rows equals the number of input rows. All other cases
 * use the general method (MR_RIX).
 *
 * @param singleRow true if a single row is indexed
 * @param singleCol true if a single column is indexed
 * @param m1_dim1 number of rows of the input
 * @param m1_dim2 number of columns of the input
 * @param m2_dim1 number of rows of the output (-1 if unknown)
 * @param m2_dim2 number of columns of the output (-1 if unknown)
 * @return the selected indexing method
 */
private static IndexingMethod optFindIndexingMethod( boolean singleRow, boolean singleCol, long m1_dim1, long m1_dim2, long m2_dim1, long m2_dim2 )
{
	final boolean fullRowVector = singleRow && m1_dim2 == m2_dim2 && m2_dim2 != -1;
	final boolean fullColVector = singleCol && m1_dim1 == m2_dim1 && m2_dim1 != -1;

	return ( fullRowVector || fullColVector ) ?
		IndexingMethod.MR_VRIX :
		IndexingMethod.MR_RIX; //general case
}
/**
 * Updates the output dimensions from the index bounds. Per dimension, the
 * first matching rule applies: single row/column flag (size 1), full range
 * "1:nrow(..)" / "1:ncol(..)" (size of the input), two literal bounds
 * (upper - lower + 1), or a block indexing expression (its constant).
 * If no rule matches, the dimension is left unchanged.
 */
@Override
public void refreshSizeInformation()
{
	final Hop data = getInput().get(0);  //original matrix
	final Hop rowLo = getInput().get(1); //inpRowL
	final Hop rowHi = getInput().get(2); //inpRowU
	final Hop colLo = getInput().get(3); //inpColL
	final Hop colHi = getInput().get(4); //inpColU

	//parse input information
	final boolean wholeRowRange = rowLo instanceof LiteralOp
		&& HopRewriteUtils.getIntValueSafe((LiteralOp)rowLo)==1
		&& rowHi instanceof UnaryOp
		&& ((UnaryOp)rowHi).getOp() == OpOp1.NROW;
	final boolean wholeColRange = colLo instanceof LiteralOp
		&& HopRewriteUtils.getIntValueSafe((LiteralOp)colLo)==1
		&& colHi instanceof UnaryOp
		&& ((UnaryOp)colHi).getOp() == OpOp1.NCOL;
	final boolean literalRowRange = rowLo instanceof LiteralOp && rowHi instanceof LiteralOp;
	final boolean literalColRange = colLo instanceof LiteralOp && colHi instanceof LiteralOp;

	//set number of rows
	if( _rowLowerEqualsUpper ) {
		setDim1(1);
	}
	else if( wholeRowRange ) {
		setDim1(data.getDim1());
	}
	else if( literalRowRange ) {
		setDim1( HopRewriteUtils.getIntValueSafe((LiteralOp)rowHi)
			-HopRewriteUtils.getIntValueSafe((LiteralOp)rowLo)+1 );
	}
	else if( isBlockIndexingExpression(rowLo, rowHi) ) {
		setDim1(getBlockIndexingExpressionSize(rowLo, rowHi));
	}

	//set number of columns
	if( _colLowerEqualsUpper ) {
		setDim2(1);
	}
	else if( wholeColRange ) {
		setDim2(data.getDim2());
	}
	else if( literalColRange ) {
		setDim2( HopRewriteUtils.getIntValueSafe((LiteralOp)colHi)
			-HopRewriteUtils.getIntValueSafe((LiteralOp)colLo)+1 );
	}
	else if( isBlockIndexingExpression(colLo, colHi) ) {
		setDim2(getBlockIndexingExpressionSize(colLo, colHi));
	}
}
/**
 * Creates a copy of this hop, including the generic hop attributes and
 * the indexing-specific single row/column flags.
 *
 * @return the new IndexingOp instance
 * @throws CloneNotSupportedException declared by the overridden method
 */
@Override
public Object clone() throws CloneNotSupportedException
{
	IndexingOp ret = new IndexingOp();

	//copy generic attributes
	ret.clone(this, false);

	//copy specific attributes; without these the copy would lose the
	//single row/column information used for size inference and for
	//the selection of the indexing method
	ret._rowLowerEqualsUpper = _rowLowerEqualsUpper;
	ret._colLowerEqualsUpper = _colLowerEqualsUpper;

	return ret;
}
/**
 * Compares this hop with another hop. Two hops match if the other hop is
 * an IndexingOp with the same number of inputs and all five inputs
 * (matrix and the four bounds) are the identical objects.
 *
 * @param that the hop to compare with
 * @return true if both hops are indexing operations over identical inputs
 */
@Override
public boolean compare( Hop that )
{
	if( !(that instanceof IndexingOp) )
		return false;
	if( getInput().size() != that.getInput().size() )
		return false;

	//reference comparison of matrix, row lower/upper, col lower/upper
	for( int pos = 0; pos < 5; pos++ ) {
		if( getInput().get(pos) != that.getInput().get(pos) )
			return false;
	}

	return true;
}
}