blob: 8025fc055bbdadc1e98a0db314ee3b2fa5a95254 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.lops.Binary;
import org.apache.sysml.lops.Group;
import org.apache.sysml.lops.LeftIndex;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.RangeBasedReIndex;
import org.apache.sysml.lops.UnaryCP;
import org.apache.sysml.lops.ZeroOut;
import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.lops.UnaryCP.OperationTypes;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
public class LeftIndexingOp extends Hop
{
public static LeftIndexingMethod FORCED_LEFT_INDEXING = null;
public enum LeftIndexingMethod {
SP_GLEFTINDEX, // general case
SP_MLEFTINDEX //map-only left index where we broadcast right hand side matrix
}
public static String OPSTRING = "lix"; //"LeftIndexing";
private boolean _rowLowerEqualsUpper = false;
private boolean _colLowerEqualsUpper = false;
private LeftIndexingOp() {
//default constructor for clone
}
public LeftIndexingOp(String l, DataType dt, ValueType vt, Hop inpMatrixLeft, Hop inpMatrixRight, Hop inpRowL, Hop inpRowU, Hop inpColL, Hop inpColU, boolean passedRowsLEU, boolean passedColsLEU) {
super(l, dt, vt);
getInput().add(0, inpMatrixLeft);
getInput().add(1, inpMatrixRight);
getInput().add(2, inpRowL);
getInput().add(3, inpRowU);
getInput().add(4, inpColL);
getInput().add(5, inpColU);
// create hops if one of them is null
inpMatrixLeft.getParent().add(this);
inpMatrixRight.getParent().add(this);
inpRowL.getParent().add(this);
inpRowU.getParent().add(this);
inpColL.getParent().add(this);
inpColU.getParent().add(this);
// set information whether left indexing operation involves row (n x 1) or column (1 x m) matrix
setRowLowerEqualsUpper(passedRowsLEU);
setColLowerEqualsUpper(passedColsLEU);
}
public boolean getRowLowerEqualsUpper(){
return _rowLowerEqualsUpper;
}
public boolean getColLowerEqualsUpper() {
return _colLowerEqualsUpper;
}
public void setRowLowerEqualsUpper(boolean passed){
_rowLowerEqualsUpper = passed;
}
public void setColLowerEqualsUpper(boolean passed) {
_colLowerEqualsUpper = passed;
}
@Override
public Lop constructLops()
throws HopsException, LopsException
{
//return already created lops
if( getLops() != null )
return getLops();
try
{
ExecType et = optFindExecType();
if(et == ExecType.MR)
{
//the right matrix is reindexed
Lop top=getInput().get(2).constructLops();
Lop bottom=getInput().get(3).constructLops();
Lop left=getInput().get(4).constructLops();
Lop right=getInput().get(5).constructLops();
/*
//need to creat new lops for converting the index ranges
//original range is (a, b) --> (c, d)
//newa=2-a, newb=2-b
Lops two=new Data(null, Data.OperationTypes.READ, null, "2", Expression.DataType.SCALAR, Expression.ValueType.INT, false);
Lops newTop=new Binary(two, top, HopsOpOp2LopsB.get(Hops.OpOp2.MINUS), Expression.DataType.SCALAR, Expression.ValueType.INT, et);
Lops newLeft=new Binary(two, left, HopsOpOp2LopsB.get(Hops.OpOp2.MINUS), Expression.DataType.SCALAR, Expression.ValueType.INT, et);
//newc=leftmatrix.row-a+1, newd=leftmatrix.row
*/
//right hand matrix
Lop nrow=new UnaryCP(getInput().get(0).constructLops(),
OperationTypes.NROW, DataType.SCALAR, ValueType.INT);
Lop ncol=new UnaryCP(getInput().get(0).constructLops(),
OperationTypes.NCOL, DataType.SCALAR, ValueType.INT);
Lop rightInput = null;
if (isRightHandSideScalar()) {
//insert cast to matrix if necessary (for reuse MR runtime)
rightInput = new UnaryCP(getInput().get(1).constructLops(),
OperationTypes.CAST_AS_MATRIX,
DataType.MATRIX, ValueType.DOUBLE);
rightInput.getOutputParameters().setDimensions( (long)1, (long)1,
(long)ConfigurationManager.getBlocksize(),
(long)ConfigurationManager.getBlocksize(),
(long)-1);
}
else
rightInput = getInput().get(1).constructLops();
RangeBasedReIndex reindex = new RangeBasedReIndex(
rightInput, top, bottom,
left, right, nrow, ncol,
getDataType(), getValueType(), et, true);
reindex.getOutputParameters().setDimensions(getInput().get(0).getDim1(), getInput().get(0).getDim2(),
getRowsInBlock(), getColsInBlock(), getNnz());
setLineNumbers(reindex);
Group group1 = new Group(
reindex, Group.OperationTypes.Sort, DataType.MATRIX,
getValueType());
group1.getOutputParameters().setDimensions(getInput().get(0).getDim1(), getInput().get(0).getDim2(),
getRowsInBlock(), getColsInBlock(), getNnz());
setLineNumbers(group1);
//the left matrix is zeroed out
ZeroOut zeroout = new ZeroOut(
getInput().get(0).constructLops(), top, bottom,
left, right, getInput().get(0).getDim1(), getInput().get(0).getDim2(),
getDataType(), getValueType(), et);
zeroout.getOutputParameters().setDimensions(getInput().get(0).getDim1(), getInput().get(0).getDim2(),
getRowsInBlock(), getColsInBlock(), getNnz());
setLineNumbers(zeroout);
Group group2 = new Group(
zeroout, Group.OperationTypes.Sort, DataType.MATRIX,
getValueType());
group2.getOutputParameters().setDimensions(getInput().get(0).getDim1(), getInput().get(0).getDim2(),
getRowsInBlock(), getColsInBlock(), getNnz());
setLineNumbers(group2);
Binary binary = new Binary(group1, group2, HopsOpOp2LopsB.get(Hop.OpOp2.PLUS),
getDataType(), getValueType(), et);
binary.getOutputParameters().setDimensions(getInput().get(0).getDim1(), getInput().get(0).getDim2(),
getRowsInBlock(), getColsInBlock(), getNnz());
setLineNumbers(binary);
setLops(binary);
}
else if(et == ExecType.SPARK)
{
Hop left = getInput().get(0);
Hop right = getInput().get(1);
LeftIndexingMethod method = getOptMethodLeftIndexingMethod( right.getDim1(), right.getDim2(),
right.getRowsInBlock(), right.getColsInBlock(), right.getNnz(), getDataType()==DataType.SCALAR );
boolean isBroadcast = (method == LeftIndexingMethod.SP_MLEFTINDEX);
//insert cast to matrix if necessary (for reuse broadcast runtime)
Lop rightInput = right.constructLops();
if (isRightHandSideScalar()) {
rightInput = new UnaryCP(rightInput, (left.getDataType()==DataType.MATRIX?OperationTypes.CAST_AS_MATRIX:OperationTypes.CAST_AS_FRAME),
left.getDataType(), right.getValueType());
long bsize = ConfigurationManager.getBlocksize();
rightInput.getOutputParameters().setDimensions( 1, 1, bsize, bsize, -1);
}
LeftIndex leftIndexLop = new LeftIndex(
left.constructLops(), rightInput,
getInput().get(2).constructLops(), getInput().get(3).constructLops(),
getInput().get(4).constructLops(), getInput().get(5).constructLops(),
getDataType(), getValueType(), et, isBroadcast);
setOutputDimensions(leftIndexLop);
setLineNumbers(leftIndexLop);
setLops(leftIndexLop);
}
else
{
LeftIndex left = new LeftIndex(
getInput().get(0).constructLops(), getInput().get(1).constructLops(), getInput().get(2).constructLops(),
getInput().get(3).constructLops(), getInput().get(4).constructLops(), getInput().get(5).constructLops(),
getDataType(), getValueType(), et);
setOutputDimensions(left);
setLineNumbers(left);
setLops(left);
}
}
catch (Exception e) {
throw new HopsException(this.printErrorLocation() + "In LeftIndexingOp Hop, error in constructing Lops " , e);
}
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
return getLops();
}
/**
* @return true if the right hand side of the indexing operation is a
* literal.
*/
private boolean isRightHandSideScalar() {
Hop rightHandSide = getInput().get(1);
return (rightHandSide.getDataType() == DataType.SCALAR);
}
@Override
public String getOpString() {
String s = new String("");
s += OPSTRING;
return s;
}
public void printMe() throws HopsException {
if (getVisited() != VisitStatus.DONE) {
super.printMe();
for (Hop h : getInput()) {
h.printMe();
}
;
}
setVisited(VisitStatus.DONE);
}
@Override
public boolean allowsAllExecTypes()
{
return false;
}
@Override
public void computeMemEstimate( MemoTable memo )
{
//overwrites default hops behavior
super.computeMemEstimate(memo);
//changed final estimate (infer and use input size)
Hop rhM = getInput().get(1);
MatrixCharacteristics mcRhM = memo.getAllInputStats(rhM);
//TODO also use worstcase estimate for output
if( dimsKnown() && !(rhM.dimsKnown()||mcRhM.dimsKnown()) )
{
// unless second input is single cell / row vector / column vector
// use worst-case memory estimate for second input (it cannot be larger than overall matrix)
double subSize = -1;
if( _rowLowerEqualsUpper && _colLowerEqualsUpper )
subSize = OptimizerUtils.estimateSize(1, 1);
else if( _rowLowerEqualsUpper )
subSize = OptimizerUtils.estimateSize(1, _dim2);
else if( _colLowerEqualsUpper )
subSize = OptimizerUtils.estimateSize(_dim1, 1);
else
subSize = _outputMemEstimate; //worstcase
_memEstimate = getInputSize(0) //original matrix (left)
+ subSize // new submatrix (right)
+ _outputMemEstimate; //output size (output)
}
else if ( dimsKnown() && _nnz<0 &&
_memEstimate>=OptimizerUtils.DEFAULT_SIZE)
{
//try a last attempt to infer a reasonable estimate wrt output sparsity
//(this is important for indexing sparse matrices into empty matrices).
MatrixCharacteristics mcM1 = memo.getAllInputStats(getInput().get(0));
MatrixCharacteristics mcM2 = memo.getAllInputStats(getInput().get(1));
if( mcM1.getNonZeros()>=0 && mcM2.getNonZeros()>=0 ) {
long lnnz = mcM1.getNonZeros() + mcM2.getNonZeros();
_outputMemEstimate = computeOutputMemEstimate( _dim1, _dim2, lnnz );
_memEstimate = getInputSize(0) //original matrix (left)
+ getInputSize(1) // new submatrix (right)
+ _outputMemEstimate; //output size (output)
}
}
}
@Override
protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
{
double sparsity = 1.0;
if( nnz < 0 ) //check for exactly known nnz
{
Hop input1 = getInput().get(0);
Hop input2 = getInput().get(1);
if( input1.dimsKnown() ) {
sparsity = OptimizerUtils.getLeftIndexingSparsity(
input1.getDim1(), input1.getDim2(), input1.getNnz(),
input2.getDim1(), input2.getDim2(), input2.getNnz());
}
}
else
{
sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
}
// The dimensions of the left indexing output is same as that of the first input i.e., getInput().get(0)
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
}
@Override
protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz )
{
return 0;
}
@Override
protected long[] inferOutputCharacteristics( MemoTable memo )
{
long[] ret = null;
Hop input1 = getInput().get(0); //original matrix
Hop input2 = getInput().get(1); //right matrix
MatrixCharacteristics mc1 = memo.getAllInputStats(input1);
MatrixCharacteristics mc2 = memo.getAllInputStats(input2);
if( mc1.dimsKnown() ) {
double sparsity = OptimizerUtils.getLeftIndexingSparsity(
mc1.getRows(), mc1.getCols(), mc1.getNonZeros(),
mc2.getRows(), mc2.getCols(), mc2.getNonZeros());
long lnnz = (long)(sparsity * mc1.getRows() * mc1.getCols());
ret = new long[]{mc1.getRows(), mc1.getCols(), lnnz};
}
return ret;
}
@Override
protected ExecType optFindExecType() throws HopsException {
checkAndSetForcedPlatform();
ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
if( _etypeForced != null )
{
_etype = _etypeForced;
}
else
{
if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
_etype = findExecTypeByMemEstimate();
checkAndModifyRecompilationStatus();
}
else if ( getInput().get(0).areDimsBelowThreshold() )
{
_etype = ExecType.CP;
}
else
{
_etype = REMOTE;
}
//check for valid CP dimensions and matrix size
checkAndSetInvalidCPDimsAndSize();
}
//mark for recompile (forever)
if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && _etype==REMOTE )
setRequiresRecompile();
return _etype;
}
/**
*
* @param m1_dim1
* @param m1_dim2
* @param m1_rpb
* @param m1_cpb
* @param m2_dim1
* @param m2_dim2
* @return
*/
private LeftIndexingMethod getOptMethodLeftIndexingMethod( long m2_dim1, long m2_dim2,
long m2_rpb, long m2_cpb, long m2_nnz, boolean isScalar)
{
if(FORCED_LEFT_INDEXING != null) {
return FORCED_LEFT_INDEXING;
}
// broadcast-based left indexing has memory constraints but is more efficient
// since it does not require shuffle
if( isScalar || m2_dim1 >= 1 && m2_dim2 >= 1 // rhs dims known
&& OptimizerUtils.checkSparkBroadcastMemoryBudget(m2_dim1, m2_dim2, m2_rpb, m2_cpb, m2_nnz) )
{
return LeftIndexingMethod.SP_MLEFTINDEX;
}
// default general case
return LeftIndexingMethod.SP_GLEFTINDEX;
}
@Override
public void refreshSizeInformation()
{
Hop input1 = getInput().get(0); //original matrix
Hop input2 = getInput().get(1); //rhs matrix
//refresh output dimensions based on original matrix
setDim1( input1.getDim1() );
setDim2( input1.getDim2() );
//refresh output nnz if exactly known; otherwise later inference
if( input1.getNnz() == 0 ) {
if( input2.getDataType()==DataType.SCALAR )
setNnz(1);
else
setNnz(input2.getNnz());
}
else
setNnz(-1);
}
/**
*
*/
private void checkAndModifyRecompilationStatus()
{
// disable recompile for LIX and second input matrix (under certain conditions)
// if worst-case estimate (2 * original matrix size) was enough to already send it to CP
if( _etype == ExecType.CP )
{
_requiresRecompile = false;
Hop rInput = getInput().get(1);
if( (!rInput.dimsKnown()) && rInput instanceof DataOp )
{
//disable recompile for this dataop (we cannot set requiresRecompile directly
//because we use a top-down traversal for creating lops, hence it would be overwritten)
((DataOp)rInput).disableRecompileRead();
}
}
}
@Override
public Object clone() throws CloneNotSupportedException
{
LeftIndexingOp ret = new LeftIndexingOp();
//copy generic attributes
ret.clone(this, false);
//copy specific attributes
return ret;
}
@Override
public boolean compare( Hop that )
{
if( !(that instanceof LeftIndexingOp)
|| getInput().size() != that.getInput().size() )
{
return false;
}
return ( getInput().get(0) == that.getInput().get(0)
&& getInput().get(1) == that.getInput().get(1)
&& getInput().get(2) == that.getInput().get(2)
&& getInput().get(3) == that.getInput().get(3)
&& getInput().get(4) == that.getInput().get(4)
&& getInput().get(5) == that.getInput().get(5));
}
}