blob: 55870fd6182fa63c8300026f43e83aed1c82c9df [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
// For now, this hop only supports range-based indexing operations.
public class IndexingOp extends Hop
{
// Operator string of this hop; returned by getOpString().
public static String OPSTRING = "rix"; //"Indexing";
// True if the row lower and upper bounds denote a single row. Initialized from
// the constructor argument and recomputed in refreshSizeInformation() as
// reference equality of the two bound hops.
private boolean _rowLowerEqualsUpper = false;
// True if the column lower and upper bounds denote a single column; maintained
// the same way as the row flag.
private boolean _colLowerEqualsUpper = false;
// Indexing methods chosen by optFindIndexingMethod(); the result is used in
// constructLops() to decide the aggregation type of the Spark right index lop.
private enum IndexingMethod {
CP_RIX, //in-memory range index
MR_RIX, //general case range reindex
MR_VRIX, //vector (row/col) range index
}
/**
 * No-argument constructor, used by {@link #clone()} to create an empty
 * instance before the generic attributes are copied over.
 */
private IndexingOp() {
	// intentionally empty
}
//right indexing doesn't really need the dimensionality of the left matrix
//private static Lops dummy=new Data(null, Data.OperationTypes.READ, null, "-1", DataType.SCALAR, ValueType.INT, false);
public IndexingOp(String l, DataType dt, ValueType vt, Hop inpMatrix, Hop inpRowL, Hop inpRowU, Hop inpColL, Hop inpColU, boolean passedRowsLEU, boolean passedColsLEU) {
super(l, dt, vt);
getInput().add(0, inpMatrix);
getInput().add(1, inpRowL);
getInput().add(2, inpRowU);
getInput().add(3, inpColL);
getInput().add(4, inpColU);
// create hops if one of them is null
inpMatrix.getParent().add(this);
inpRowL.getParent().add(this);
inpRowU.getParent().add(this);
inpColL.getParent().add(this);
inpColU.getParent().add(this);
// set information whether left indexing operation involves row (n x 1) or column (1 x m) matrix
setRowLowerEqualsUpper(passedRowsLEU);
setColLowerEqualsUpper(passedColsLEU);
}
/**
 * Validates that this hop has exactly five inputs (matrix plus four index bounds);
 * violations are reported via {@code HopsException.check}.
 */
@Override
public void checkArity() {
	final int numInputs = _input.size();
	HopsException.check(numInputs == 5, this, "should have 5 inputs but has %d inputs", numInputs);
}
/**
 * @return true if the row lower and upper bounds are flagged as equal (single row)
 */
public boolean isRowLowerEqualsUpper() {
	return this._rowLowerEqualsUpper;
}

/**
 * @return true if the column lower and upper bounds are flagged as equal (single column)
 */
public boolean isColLowerEqualsUpper() {
	return this._colLowerEqualsUpper;
}

/**
 * @param passed new value of the "row lower equals upper" flag
 */
public void setRowLowerEqualsUpper(boolean passed) {
	this._rowLowerEqualsUpper = passed;
}

/**
 * @param passed new value of the "column lower equals upper" flag
 */
public void setColLowerEqualsUpper(boolean passed) {
	this._colLowerEqualsUpper = passed;
}
/**
 * Indexing is offered to the GPU only if the accelerator is enabled, the
 * output is a matrix, and the input memory estimate is below 2GB.
 *
 * The size limit exists because indexing is purely memory-bound and gains little
 * from the GPU, while pushing large matrices (e.g., a training dataset) onto the
 * device risks evicting parameters and gradients. For single precision, every
 * eviction additionally costs a conversion of the dataset to single precision.
 *
 * @return true if this operation may be executed on the GPU
 */
@Override
public boolean isGPUEnabled() {
	return DMLScript.USE_ACCELERATOR
		&& getDataType() == DataType.MATRIX
		&& getInputMemEstimate() < 2e+9;
}
/**
 * Constructs (or returns the already constructed) lops of this right indexing hop.
 * If the indexing is detected as unnecessary, the lops of the indexed input are
 * reused directly; otherwise a {@link RightIndex} lop is created for the selected
 * execution type.
 *
 * @return the lops of this hop
 */
@Override
public Lop constructLops()
{
	//lops are created at most once
	if( getLops() != null )
		return getLops();

	Hop input = getInput().get(0);

	if( HopRewriteUtils.isUnnecessaryRightIndexing(this) ) {
		//rewrite: drop the indexing and pass the input through
		setLops( input.constructLops() );
	}
	else {
		//actual lop construction, incl operator selection
		try {
			final ExecType et = optFindExecType();
			final boolean onSpark = (et == ExecType.SPARK);

			//the aggregation type is only relevant for spark instructions
			SparkAggType aggtype = null;
			if( onSpark ) {
				IndexingMethod method = optFindIndexingMethod( _rowLowerEqualsUpper, _colLowerEqualsUpper,
					input.getDim1(), input.getDim2(), getDim1(), getDim2());
				boolean noAggregation = (method==IndexingMethod.MR_VRIX || isBlockAligned());
				aggtype = noAggregation ? SparkAggType.NONE : SparkAggType.MULTI_BLOCK;
			}

			//placeholder literal passed for the two trailing lop inputs
			Lop dummy = Data.createLiteralLop(ValueType.INT64, Integer.toString(-1));
			Lop inLop = input.constructLops();
			Lop rowLowerLop = getInput().get(1).constructLops();
			Lop rowUpperLop = getInput().get(2).constructLops();
			Lop colLowerLop = getInput().get(3).constructLops();
			Lop colUpperLop = getInput().get(4).constructLops();

			RightIndex reindex = onSpark ?
				new RightIndex(inLop, rowLowerLop, rowUpperLop, colLowerLop, colUpperLop,
					dummy, dummy, getDataType(), getValueType(), aggtype, et) :
				new RightIndex(inLop, rowLowerLop, rowUpperLop, colLowerLop, colUpperLop,
					dummy, dummy, getDataType(), getValueType(), et); //CP or GPU

			setOutputDimensions(reindex);
			setLineNumbers(reindex);
			setLops(reindex);
		}
		catch (Exception e) {
			throw new HopsException(this.printErrorLocation() + "In IndexingOp Hop, error constructing Lops " , e);
		}
	}

	//add reblock/checkpoint lops if necessary
	constructAndSetLopsDataFlowProperties();

	return getLops();
}
/**
 * Returns the operator string of this hop.
 *
 * @return the current value of {@link #OPSTRING}
 */
@Override
public String getOpString() {
	// no need to allocate an empty string and concatenate onto it
	return OPSTRING;
}
/**
 * @return always true for this hop
 */
@Override
public boolean allowsAllExecTypes() {
	return true;
}
/**
 * Computes the memory estimates of this hop. After the default computation,
 * the output estimate is tightened if the output dimensions are known and the
 * number of non-zeros of the indexed input is available: the input nnz is
 * used as the worst-case output nnz.
 *
 * @param memo memo table providing the input statistics
 */
@Override
public void computeMemEstimate( MemoTable memo )
{
	//default behavior
	super.computeMemEstimate(memo);

	//try to infer via worstcase input statistics (for the case of dims known
	//but nnz initially unknown); the input stats are treated as optional,
	//consistent with the null check in inferOutputCharacteristics
	DataCharacteristics dcM1 = memo.getAllInputStats(getInput().get(0));
	if( dimsKnown() && dcM1 != null && dcM1.getNonZeros()>=0 ) {
		long lnnz = dcM1.getNonZeros(); //worst-case output nnz
		double lOutMemEst = computeOutputMemEstimate(getDim1(), getDim2(), lnnz);
		//only ever lower the estimate, never raise it
		if( lOutMemEst<_outputMemEstimate ) {
			_outputMemEstimate = lOutMemEst;
			_memEstimate = getInputOutputSize();
		}
	}
}
/**
 * Estimates the output memory for the given dimensions and number of non-zeros.
 * If the operation is GPU-enabled, a dense output (sparsity 1.0) is assumed
 * because only dense right indexing is supported on GPU.
 *
 * @param dim1 number of output rows
 * @param dim2 number of output columns
 * @param nnz number of output non-zeros
 * @return estimated output size
 */
@Override
protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
{
	final double sparsity;
	if( isGPUEnabled() )
		sparsity = 1.0; // only dense right indexing supported on GPU
	else
		sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);

	return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
}
/**
 * Right indexing reports no intermediate memory.
 *
 * @param dim1 number of rows (unused)
 * @param dim2 number of columns (unused)
 * @param nnz number of non-zeros (unused)
 * @return always 0
 */
@Override
protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz ) {
	return 0;
}
/**
 * Infers worst-case output characteristics from the statistics of the indexed
 * input. The starting point is the input size; rows/columns are then narrowed
 * to 1 for single row/column indexing, or to the block size for block
 * indexing expressions.
 *
 * @param memo memo table providing the input statistics
 * @return inferred output characteristics, or null if no input statistics are available
 */
@Override
protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
{
	Hop matrix = getInput().get(0); //original matrix
	DataCharacteristics inStats = memo.getAllInputStats(matrix);
	if( inStats == null )
		return null;

	//worst-case nnz: input nnz, capped at the number of input cells
	long worstNnz = -1;
	if( inStats.dimsKnown() )
		worstNnz = Math.min(inStats.getRows()*inStats.getCols(), inStats.getNonZeros());

	//worst-case is input size
	DataCharacteristics outStats = new MatrixCharacteristics(inStats.getRows(), inStats.getCols(), -1, worstNnz);

	//exploit column/row indexing information
	if( _rowLowerEqualsUpper )
		outStats.setRows(1);
	if( _colLowerEqualsUpper )
		outStats.setCols(1);

	//infer tight block indexing size
	Hop rowLower = getInput().get(1);
	Hop rowUpper = getInput().get(2);
	Hop colLower = getInput().get(3);
	Hop colUpper = getInput().get(4);
	if( isBlockIndexingExpression(rowLower, rowUpper) )
		outStats.setRows(getBlockIndexingExpressionSize(rowLower, rowUpper));
	if( isBlockIndexingExpression(colLower, colUpper) )
		outStats.setCols(getBlockIndexingExpressionSize(colLower, colUpper));

	return outStats;
}
/**
 * Indicates if the lbound:rbound expressions is of the form
 * "(c * (i - 1) + 1) : (c * i)", where we could use c as a tight size estimate.
 *
 * The lower bound must be a PLUS of [MULT of literal c and (data op i MINUS 1)]
 * and the literal 1. The upper bound must be a MULT of a literal with the same
 * value c and a data op with the same name as i.
 *
 * @param lbound lower bound high-level operator
 * @param ubound upper bound high-level operator
 * @return true if block indexing expression
 */
private static boolean isBlockIndexingExpression(Hop lbound, Hop ubound)
{
	LiteralOp constant = null;
	DataOp indexVar = null;

	//handle lower bound: (c * (i - 1)) + 1
	if( lbound instanceof BinaryOp && ((BinaryOp)lbound).getOp()==OpOp2.PLUS
		&& lbound.getInput().get(1) instanceof LiteralOp
		&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)lbound.getInput().get(1))==1
		&& lbound.getInput().get(0) instanceof BinaryOp)
	{
		BinaryOp lmult = (BinaryOp)lbound.getInput().get(0);
		if( lmult.getOp()==OpOp2.MULT && lmult.getInput().get(0) instanceof LiteralOp
			&& lmult.getInput().get(1) instanceof BinaryOp )
		{
			BinaryOp lminus = (BinaryOp)lmult.getInput().get(1);
			if( lminus.getOp()==OpOp2.MINUS && lminus.getInput().get(1) instanceof LiteralOp
				&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)lminus.getInput().get(1))==1
				&& lminus.getInput().get(0) instanceof DataOp )
			{
				constant = (LiteralOp)lmult.getInput().get(0);
				indexVar = (DataOp) lminus.getInput().get(0);
			}
		}
	}

	//lower bound did not match the pattern
	if( indexVar == null || constant == null )
		return false;

	//handle upper bound: c * i; the operator must be a multiplication, otherwise
	//expressions such as (c + i) would be mistaken for a block upper bound
	if( !(ubound instanceof BinaryOp) || ((BinaryOp)ubound).getOp()!=OpOp2.MULT )
		return false;
	if( !(ubound.getInput().get(0) instanceof LiteralOp)
		|| !(ubound.getInput().get(1) instanceof DataOp)
		|| !ubound.getInput().get(1).getName().equals(indexVar.getName()) )
	{
		return false;
	}

	//both bounds must use the same constant c
	LiteralOp constant2 = (LiteralOp)ubound.getInput().get(0);
	return HopRewriteUtils.getDoubleValueSafe(constant) ==
		HopRewriteUtils.getDoubleValueSafe(constant2);
}
/**
 * Indicates if the right indexing ranging is block aligned, i.e., it does not require
 * aggregation across blocks due to shifting. Index bounds that are not literals are
 * passed on as -1.
 *
 * @return true if block aligned
 */
private boolean isBlockAligned() {
	long rowLower = literalIndexOrUnknown(getInput().get(1)); //inpRowL
	long rowUpper = literalIndexOrUnknown(getInput().get(2)); //inpRowU
	long colLower = literalIndexOrUnknown(getInput().get(3)); //inpColL
	long colUpper = literalIndexOrUnknown(getInput().get(4)); //inpColU
	int blen = getInput().get(0).getBlocksize(); //block size of the original matrix
	return OptimizerUtils.isIndexingRangeBlockAligned(rowLower, rowUpper, colLower, colUpper, blen);
}

/**
 * Returns the integer value of the given index bound if it is a literal, and -1 otherwise.
 *
 * @param bound index bound high-level operator
 * @return literal value, or -1 if the bound is not a literal
 */
private static long literalIndexOrUnknown(Hop bound) {
	if( bound instanceof LiteralOp )
		return HopRewriteUtils.getIntValueSafe((LiteralOp)bound);
	return -1;
}
/**
 * Returns the block size c of a block indexing expression
 * "(c * (i - 1) + 1) : (c * i)", read from the first input of the upper bound.
 * Must only be called if isBlockIndexingExpression returned true for the same bounds.
 *
 * @param lbound lower bound high-level operator (unused)
 * @param ubound upper bound high-level operator of the form (c * i)
 * @return the constant c
 */
private static long getBlockIndexingExpressionSize(Hop lbound, Hop ubound) {
	//NOTE: ensure consistency with isBlockIndexingExpression
	return HopRewriteUtils.getIntValueSafe((LiteralOp) ubound.getInput().get(0)); //(c*i)
}
/**
 * Selects the execution type of this hop. A forced platform takes precedence;
 * otherwise the type is chosen by memory estimate (memory-based optimization
 * levels) or by the dimension threshold of the indexed input, followed by the
 * check for valid CP dimensions and size. List inputs always run in CP.
 *
 * @return the selected execution type
 */
@Override
protected ExecType optFindExecType() {
	checkAndSetForcedPlatform();

	if( _etypeForced == null ) {
		if( OptimizerUtils.isMemoryBasedOptLevel() )
			_etype = findExecTypeByMemEstimate();
		else
			_etype = getInput().get(0).areDimsBelowThreshold() ? ExecType.CP : ExecType.SPARK;

		//check for valid CP dimensions and matrix size
		checkAndSetInvalidCPDimsAndSize();
	}
	else {
		_etype = _etypeForced;
	}

	//indexing on lists is always executed in CP
	if( getInput().get(0).getDataType()==DataType.LIST )
		_etype = ExecType.CP;

	//mark for recompile (forever)
	setRequiresRecompileIfNecessary();

	return _etype;
}
/**
 * Selects the indexing method: vector indexing (MR_VRIX) if a single row is
 * selected over all (known) columns, or a single column over all (known) rows;
 * otherwise the general range reindex (MR_RIX).
 *
 * @param singleRow true if row lower equals row upper
 * @param singleCol true if column lower equals column upper
 * @param m1_dim1 number of input rows
 * @param m1_dim2 number of input columns
 * @param m2_dim1 number of output rows
 * @param m2_dim2 number of output columns
 * @return the selected indexing method
 */
private static IndexingMethod optFindIndexingMethod( boolean singleRow, boolean singleCol, long m1_dim1, long m1_dim2, long m2_dim1, long m2_dim2 )
{
	final boolean fullRowVector = singleRow && m1_dim2 == m2_dim2 && m2_dim2 != -1;
	final boolean fullColVector = singleCol && m1_dim1 == m2_dim1 && m2_dim1 != -1;

	return (fullRowVector || fullColVector) ?
		IndexingMethod.MR_VRIX :
		IndexingMethod.MR_RIX; //general case
}
/**
 * Recomputes the single row/column flags and the output dimensions from the
 * current inputs. A dimension that cannot be inferred is reset to -1.
 */
@Override
public void refreshSizeInformation()
{
	Hop matrix = getInput().get(0);   //matrix
	Hop rowLower = getInput().get(1); //inpRowL
	Hop rowUpper = getInput().get(2); //inpRowU
	Hop colLower = getInput().get(3); //inpColL
	Hop colUpper = getInput().get(4); //inpColU

	//update single row/column flags (depends on CSE)
	_rowLowerEqualsUpper = (rowLower == rowUpper);
	_colLowerEqualsUpper = (colLower == colUpper);

	//parse input information
	final boolean allRows = isAllRows();
	final boolean allCols = isAllCols();

	//set dimension information
	setDim1( inferIndexedDim(_rowLowerEqualsUpper, allRows, matrix.getDim1(), rowLower, rowUpper) ); //ROWS
	setDim2( inferIndexedDim(_colLowerEqualsUpper, allCols, matrix.getDim2(), colLower, colUpper) ); //COLS
}

/**
 * Infers the size of one output dimension of the indexing operation.
 *
 * @param singleIndex true if lower and upper bound are the same hop
 * @param fullRange true if the range covers the entire input dimension
 * @param inputDim size of the input in this dimension
 * @param lower lower bound high-level operator
 * @param upper upper bound high-level operator
 * @return the inferred size, or -1 if it cannot be inferred
 */
private static long inferIndexedDim(boolean singleIndex, boolean fullRange, long inputDim, Hop lower, Hop upper) {
	if( singleIndex )
		return 1;
	if( fullRange )
		return inputDim;
	//constant range: upper - lower + 1
	if( lower instanceof LiteralOp && upper instanceof LiteralOp ) {
		return HopRewriteUtils.getIntValueSafe((LiteralOp)upper)
			- HopRewriteUtils.getIntValueSafe((LiteralOp)lower) + 1;
	}
	if( isBlockIndexingExpression(lower, upper) )
		return getBlockIndexingExpressionSize(lower, upper);
	//for reset (e.g., on reconcile after loops)
	return -1;
}
/**
 * Indicates if the row range covers all rows of the input, i.e., the lower
 * bound is the literal 1 and the upper bound is either nrow of the same input
 * hop or a literal equal to the input's number of rows.
 *
 * @return true if all rows are selected
 */
public boolean isAllRows() {
	Hop matrix = getInput().get(0);
	Hop rowLower = getInput().get(1);
	Hop rowUpper = getInput().get(2);

	if( !HopRewriteUtils.isLiteralOfValue(rowLower, 1) )
		return false;
	if( HopRewriteUtils.isUnary(rowUpper, OpOp1.NROW) && rowUpper.getInput().get(0) == matrix )
		return true;
	return HopRewriteUtils.isLiteralOfValue(rowUpper, matrix.getDim1());
}
/**
 * Indicates if the column range covers all columns of the input, i.e., the lower
 * bound is the literal 1 and the upper bound is either ncol of the same input
 * hop or a literal equal to the input's number of columns.
 *
 * @return true if all columns are selected
 */
public boolean isAllCols() {
	Hop matrix = getInput().get(0);
	Hop colLower = getInput().get(3);
	Hop colUpper = getInput().get(4);

	if( !HopRewriteUtils.isLiteralOfValue(colLower, 1) )
		return false;
	if( HopRewriteUtils.isUnary(colUpper, OpOp1.NCOL) && colUpper.getInput().get(0) == matrix )
		return true;
	return HopRewriteUtils.isLiteralOfValue(colUpper, matrix.getDim2());
}
/**
 * Creates a copy of this hop: a new empty instance that receives the generic
 * hop attributes via {@code clone(this, false)}. No indexing-specific
 * attributes are copied.
 *
 * @return the cloned indexing hop
 * @throws CloneNotSupportedException declared by the overridden method
 */
@Override
public Object clone() throws CloneNotSupportedException {
	final IndexingOp copy = new IndexingOp();
	//copy generic attributes (there are no specific attributes to copy)
	copy.clone(this, false);
	return copy;
}
/**
 * Compares this hop with another hop: both must be indexing hops with the same
 * number of inputs, and all five inputs must be the identical hop objects
 * (reference equality).
 *
 * @param that hop to compare with
 * @return true if both hops index the same inputs
 */
@Override
public boolean compare( Hop that )
{
	if( !(that instanceof IndexingOp) )
		return false;
	if( getInput().size() != that.getInput().size() )
		return false;

	//matrix, row lower/upper, column lower/upper
	for( int pos = 0; pos < 5; pos++ ) {
		if( getInput().get(pos) != that.getInput().get(pos) )
			return false;
	}
	return true;
}
}