blob: 7aa26bdcda7e0cd33e90d3fde8c8a102b20a1a1b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.LeftIndex.LixCacheType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
public class LeftIndexingOp extends Hop
{
public static LeftIndexingMethod FORCED_LEFT_INDEXING = null;
public enum LeftIndexingMethod {
SP_GLEFTINDEX, //general case
SP_MLEFTINDEX_R, //map-only left index, broadcast rhs
SP_MLEFTINDEX_L, //map-only left index, broadcast lhs
}
public static String OPSTRING = "lix"; //"LeftIndexing";
private boolean _rowLowerEqualsUpper = false;
private boolean _colLowerEqualsUpper = false;
private LeftIndexingOp() {
//default constructor for clone
}
public LeftIndexingOp(String l, DataType dt, ValueType vt, Hop inpMatrixLeft, Hop inpMatrixRight, Hop inpRowL, Hop inpRowU, Hop inpColL, Hop inpColU, boolean passedRowsLEU, boolean passedColsLEU) {
super(l, dt, vt);
getInput().add(0, inpMatrixLeft);
getInput().add(1, inpMatrixRight);
getInput().add(2, inpRowL);
getInput().add(3, inpRowU);
getInput().add(4, inpColL);
getInput().add(5, inpColU);
// create hops if one of them is null
inpMatrixLeft.getParent().add(this);
inpMatrixRight.getParent().add(this);
inpRowL.getParent().add(this);
inpRowU.getParent().add(this);
inpColL.getParent().add(this);
inpColU.getParent().add(this);
// set information whether left indexing operation involves row (n x 1) or column (1 x m) matrix
setRowLowerEqualsUpper(passedRowsLEU);
setColLowerEqualsUpper(passedColsLEU);
}
@Override
public void checkArity() {
HopsException.check(_input.size() == 6, this, "should have 6 inputs but has %d inputs", 6);
}
public boolean isRowLowerEqualsUpper(){
return _rowLowerEqualsUpper;
}
public boolean isColLowerEqualsUpper() {
return _colLowerEqualsUpper;
}
public void setRowLowerEqualsUpper(boolean passed){
_rowLowerEqualsUpper = passed;
}
public void setColLowerEqualsUpper(boolean passed) {
_colLowerEqualsUpper = passed;
}
@Override
public boolean isGPUEnabled() {
return false;
}
@Override
public Lop constructLops()
{
//return already created lops
if( getLops() != null )
return getLops();
try
{
ExecType et = optFindExecType();
if(et == ExecType.SPARK)
{
Hop left = getInput().get(0);
Hop right = getInput().get(1);
LeftIndexingMethod method = getOptMethodLeftIndexingMethod(
left.getDim1(), left.getDim2(), left.getBlocksize(), left.getNnz(),
right.getDim1(), right.getDim2(), right.getNnz(), right.getDataType() );
//insert cast to matrix if necessary (for reuse broadcast runtime)
Lop rightInput = right.constructLops();
if (isRightHandSideScalar()) {
rightInput = new UnaryCP(rightInput,
(left.getDataType()==DataType.MATRIX?OpOp1.CAST_AS_MATRIX:OpOp1.CAST_AS_FRAME),
left.getDataType(), right.getValueType());
long bsize = ConfigurationManager.getBlocksize();
rightInput.getOutputParameters().setDimensions( 1, 1, bsize, -1);
}
LeftIndex leftIndexLop = new LeftIndex(
left.constructLops(), rightInput,
getInput().get(2).constructLops(), getInput().get(3).constructLops(),
getInput().get(4).constructLops(), getInput().get(5).constructLops(),
getDataType(), getValueType(), et, getSpLixCacheType(method));
setOutputDimensions(leftIndexLop);
setLineNumbers(leftIndexLop);
setLops(leftIndexLop);
}
else
{
LeftIndex left = new LeftIndex(
getInput().get(0).constructLops(), getInput().get(1).constructLops(), getInput().get(2).constructLops(),
getInput().get(3).constructLops(), getInput().get(4).constructLops(), getInput().get(5).constructLops(),
getDataType(), getValueType(), et);
setOutputDimensions(left);
setLineNumbers(left);
setLops(left);
}
}
catch (Exception e) {
throw new HopsException(this.printErrorLocation() + "In LeftIndexingOp Hop, error in constructing Lops " , e);
}
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
return getLops();
}
/**
* @return true if the right hand side of the indexing operation is a
* literal.
*/
private boolean isRightHandSideScalar() {
Hop rightHandSide = getInput().get(1);
return (rightHandSide.getDataType() == DataType.SCALAR);
}
private static LixCacheType getSpLixCacheType(LeftIndexingMethod method) {
switch( method ) {
case SP_MLEFTINDEX_L: return LixCacheType.LEFT;
case SP_MLEFTINDEX_R: return LixCacheType.RIGHT;
default: return LixCacheType.NONE;
}
}
@Override
public String getOpString() {
String s = new String("");
s += OPSTRING;
return s;
}
@Override
public boolean allowsAllExecTypes() {
return false;
}
@Override
public void computeMemEstimate( MemoTable memo )
{
//overwrites default hops behavior
super.computeMemEstimate(memo);
//changed final estimate (infer and use input size)
Hop rhM = getInput().get(1);
DataCharacteristics dcRhM = memo.getAllInputStats(rhM);
//TODO also use worstcase estimate for output
if( dimsKnown() && !(rhM.dimsKnown()||dcRhM.dimsKnown()) )
{
// unless second input is single cell / row vector / column vector
// use worst-case memory estimate for second input (it cannot be larger than overall matrix)
double subSize = -1;
if( _rowLowerEqualsUpper && _colLowerEqualsUpper )
subSize = OptimizerUtils.estimateSize(1, 1);
else if( _rowLowerEqualsUpper )
subSize = OptimizerUtils.estimateSize(1, getDim2());
else if( _colLowerEqualsUpper )
subSize = OptimizerUtils.estimateSize(getDim1(), 1);
else
subSize = _outputMemEstimate; //worstcase
_memEstimate = getInputSize(0) //original matrix (left)
+ subSize // new submatrix (right)
+ _outputMemEstimate; //output size (output)
}
else if ( dimsKnown() && getNnz()<0 &&
_memEstimate>=OptimizerUtils.DEFAULT_SIZE)
{
//try a last attempt to infer a reasonable estimate wrt output sparsity
//(this is important for indexing sparse matrices into empty matrices).
DataCharacteristics dcM1 = memo.getAllInputStats(getInput().get(0));
DataCharacteristics dcM2 = memo.getAllInputStats(getInput().get(1));
if( dcM1.getNonZeros()>=0 && dcM2.getNonZeros()>=0
&& hasConstantIndexingRange() )
{
long lnnz = dcM1.getNonZeros() + dcM2.getNonZeros();
_outputMemEstimate = computeOutputMemEstimate(getDim1(), getDim2(), lnnz);
_memEstimate = getInputSize(0) //original matrix (left)
+ getInputSize(1) // new submatrix (right)
+ _outputMemEstimate; //output size (output)
}
}
}
@Override
protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
{
double sparsity = 1.0;
if( nnz < 0 ) //check for exactly known nnz
{
Hop input1 = getInput().get(0);
Hop input2 = getInput().get(1);
if( input1.dimsKnown() && hasConstantIndexingRange() ) {
sparsity = OptimizerUtils.getLeftIndexingSparsity(
input1.getDim1(), input1.getDim2(), input1.getNnz(),
input2.getDim1(), input2.getDim2(), input2.getNnz());
}
}
else {
sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
}
// The dimensions of the left indexing output is same as that of the first input i.e., getInput().get(0)
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
}
@Override
protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz )
{
return 0;
}
@Override
protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
{
DataCharacteristics ret = null;
Hop input1 = getInput().get(0); //original matrix
Hop input2 = getInput().get(1); //right matrix
DataCharacteristics dc1 = memo.getAllInputStats(input1);
DataCharacteristics dc2 = memo.getAllInputStats(input2);
if( dc1.dimsKnown() ) {
double sparsity = OptimizerUtils.getLeftIndexingSparsity(
dc1.getRows(), dc1.getCols(), dc1.getNonZeros(),
dc2.getRows(), dc2.getCols(), dc2.getNonZeros());
long lnnz = !hasConstantIndexingRange() ? -1 :
(long)(sparsity * dc1.getRows() * dc1.getCols());
ret = new MatrixCharacteristics(dc1.getRows(), dc1.getCols(), -1, lnnz);
}
return ret;
}
@Override
protected ExecType optFindExecType() {
checkAndSetForcedPlatform();
if( _etypeForced != null )
{
_etype = _etypeForced;
}
else
{
if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
_etype = findExecTypeByMemEstimate();
checkAndModifyRecompilationStatus();
}
else if ( getInput().get(0).areDimsBelowThreshold() )
{
_etype = ExecType.CP;
}
else
{
_etype = ExecType.SPARK;
}
//check for valid CP dimensions and matrix size
checkAndSetInvalidCPDimsAndSize();
}
if( getInput().get(0).getDataType()==DataType.LIST )
_etype = ExecType.CP;
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
return _etype;
}
private static LeftIndexingMethod getOptMethodLeftIndexingMethod(
long m1_dim1, long m1_dim2, long m1_blen, long m1_nnz,
long m2_dim1, long m2_dim2, long m2_nnz, DataType rhsDt)
{
if(FORCED_LEFT_INDEXING != null) {
return FORCED_LEFT_INDEXING;
}
// broadcast-based left indexing w/o shuffle for scalar rhs
if( rhsDt == DataType.SCALAR ) {
return LeftIndexingMethod.SP_MLEFTINDEX_R;
}
// broadcast-based left indexing w/o shuffle for small left/right inputs
if( m2_dim1 >= 1 && m2_dim2 >= 1 && m2_dim1 >= 1 && m2_dim2 >= 1 ) { //lhs/rhs known
boolean isAligned = (rhsDt == DataType.MATRIX) &&
((m1_dim1 == m2_dim1 && m1_dim2 <= m1_blen) || (m1_dim2 == m2_dim2 && m1_dim1 <= m1_blen));
boolean broadcastRhs = OptimizerUtils.checkSparkBroadcastMemoryBudget(m2_dim1, m2_dim2, m1_blen, m2_nnz);
double m1SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_dim1, m1_dim2, m1_blen, m1_nnz);
double m2SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_dim1, m2_dim2, m1_blen, m2_nnz);
if( broadcastRhs ) {
if( isAligned && m1SizeP<m2SizeP ) //e.g., sparse-dense lix
return LeftIndexingMethod.SP_MLEFTINDEX_L;
else //all other cases, where rhs smaller than lhs
return LeftIndexingMethod.SP_MLEFTINDEX_R;
}
}
// default general case
return LeftIndexingMethod.SP_GLEFTINDEX;
}
@Override
public void refreshSizeInformation()
{
Hop input1 = getInput().get(0); //original matrix
Hop input2 = getInput().get(1); //rhs matrix
//refresh output dimensions based on original matrix
setDim1( input1.getDim1() );
setDim2( input1.getDim2() );
//refresh output nnz if exactly known; otherwise later inference
//note: leveraging the nnz for estimating the output sparsity is
//only valid for constant index identifiers (e.g., after literal
//replacement during dynamic recompilation), otherwise this could
//lead to underestimation and hence OOMs in loops
if( input1.getNnz() == 0 && hasConstantIndexingRange() ) {
if( input2.getDataType()==DataType.SCALAR )
setNnz(1);
else
setNnz(input2.getNnz());
}
else
setNnz(-1);
}
private boolean hasConstantIndexingRange() {
return (getInput().get(2) instanceof LiteralOp
&& getInput().get(3) instanceof LiteralOp
&& getInput().get(4) instanceof LiteralOp
&& getInput().get(5) instanceof LiteralOp);
}
private void checkAndModifyRecompilationStatus()
{
// disable recompile for LIX and second input matrix (under certain conditions)
// if worst-case estimate (2 * original matrix size) was enough to already send it to CP
if( _etype == ExecType.CP )
{
_requiresRecompile = false;
Hop rInput = getInput().get(1);
if( (!rInput.dimsKnown()) && rInput instanceof DataOp )
{
//disable recompile for this dataop (we cannot set requiresRecompile directly
//because we use a top-down traversal for creating lops, hence it would be overwritten)
((DataOp)rInput).disableRecompileRead();
}
}
}
@Override
public Object clone() throws CloneNotSupportedException
{
LeftIndexingOp ret = new LeftIndexingOp();
//copy generic attributes
ret.clone(this, false);
//copy specific attributes
return ret;
}
@Override
public boolean compare( Hop that ) {
if( !(that instanceof LeftIndexingOp)
|| getInput().size() != that.getInput().size() )
{
return false;
}
return getInput().get(0) == that.getInput().get(0)
&& getInput().get(1) == that.getInput().get(1)
&& getInput().get(2) == that.getInput().get(2)
&& getInput().get(3) == that.getInput().get(3)
&& getInput().get(4) == that.getInput().get(4)
&& getInput().get(5) == that.getInput().get(5);
}
}