| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops; |
| |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.common.Types.OpOp1; |
| import org.apache.sysds.common.Types.ValueType; |
| import org.apache.sysds.conf.ConfigurationManager; |
| import org.apache.sysds.lops.LeftIndex; |
| import org.apache.sysds.lops.LeftIndex.LixCacheType; |
| import org.apache.sysds.lops.Lop; |
| import org.apache.sysds.lops.LopProperties.ExecType; |
| import org.apache.sysds.lops.UnaryCP; |
| import org.apache.sysds.runtime.meta.DataCharacteristics; |
| import org.apache.sysds.runtime.meta.MatrixCharacteristics; |
| |
| public class LeftIndexingOp extends Hop |
| { |
| public static LeftIndexingMethod FORCED_LEFT_INDEXING = null; |
| |
| public enum LeftIndexingMethod { |
| SP_GLEFTINDEX, //general case |
| SP_MLEFTINDEX_R, //map-only left index, broadcast rhs |
| SP_MLEFTINDEX_L, //map-only left index, broadcast lhs |
| } |
| |
| public static String OPSTRING = "lix"; //"LeftIndexing"; |
| |
| private boolean _rowLowerEqualsUpper = false; |
| private boolean _colLowerEqualsUpper = false; |
| |
| private LeftIndexingOp() { |
| //default constructor for clone |
| } |
| |
| public LeftIndexingOp(String l, DataType dt, ValueType vt, Hop inpMatrixLeft, Hop inpMatrixRight, Hop inpRowL, Hop inpRowU, Hop inpColL, Hop inpColU, boolean passedRowsLEU, boolean passedColsLEU) { |
| super(l, dt, vt); |
| |
| getInput().add(0, inpMatrixLeft); |
| getInput().add(1, inpMatrixRight); |
| getInput().add(2, inpRowL); |
| getInput().add(3, inpRowU); |
| getInput().add(4, inpColL); |
| getInput().add(5, inpColU); |
| |
| // create hops if one of them is null |
| inpMatrixLeft.getParent().add(this); |
| inpMatrixRight.getParent().add(this); |
| inpRowL.getParent().add(this); |
| inpRowU.getParent().add(this); |
| inpColL.getParent().add(this); |
| inpColU.getParent().add(this); |
| |
| // set information whether left indexing operation involves row (n x 1) or column (1 x m) matrix |
| setRowLowerEqualsUpper(passedRowsLEU); |
| setColLowerEqualsUpper(passedColsLEU); |
| } |
| |
| @Override |
| public void checkArity() { |
| HopsException.check(_input.size() == 6, this, "should have 6 inputs but has %d inputs", 6); |
| } |
| |
| public boolean isRowLowerEqualsUpper(){ |
| return _rowLowerEqualsUpper; |
| } |
| |
| public boolean isColLowerEqualsUpper() { |
| return _colLowerEqualsUpper; |
| } |
| |
| public void setRowLowerEqualsUpper(boolean passed){ |
| _rowLowerEqualsUpper = passed; |
| } |
| |
| public void setColLowerEqualsUpper(boolean passed) { |
| _colLowerEqualsUpper = passed; |
| } |
| |
| @Override |
| public boolean isGPUEnabled() { |
| return false; |
| } |
| |
| @Override |
| public Lop constructLops() |
| { |
| //return already created lops |
| if( getLops() != null ) |
| return getLops(); |
| |
| try |
| { |
| ExecType et = optFindExecType(); |
| |
| if(et == ExecType.SPARK) |
| { |
| Hop left = getInput().get(0); |
| Hop right = getInput().get(1); |
| |
| LeftIndexingMethod method = getOptMethodLeftIndexingMethod( |
| left.getDim1(), left.getDim2(), left.getBlocksize(), left.getNnz(), |
| right.getDim1(), right.getDim2(), right.getNnz(), right.getDataType() ); |
| |
| //insert cast to matrix if necessary (for reuse broadcast runtime) |
| Lop rightInput = right.constructLops(); |
| if (isRightHandSideScalar()) { |
| rightInput = new UnaryCP(rightInput, |
| (left.getDataType()==DataType.MATRIX?OpOp1.CAST_AS_MATRIX:OpOp1.CAST_AS_FRAME), |
| left.getDataType(), right.getValueType()); |
| long bsize = ConfigurationManager.getBlocksize(); |
| rightInput.getOutputParameters().setDimensions( 1, 1, bsize, -1); |
| } |
| |
| LeftIndex leftIndexLop = new LeftIndex( |
| left.constructLops(), rightInput, |
| getInput().get(2).constructLops(), getInput().get(3).constructLops(), |
| getInput().get(4).constructLops(), getInput().get(5).constructLops(), |
| getDataType(), getValueType(), et, getSpLixCacheType(method)); |
| |
| setOutputDimensions(leftIndexLop); |
| setLineNumbers(leftIndexLop); |
| setLops(leftIndexLop); |
| } |
| else |
| { |
| LeftIndex left = new LeftIndex( |
| getInput().get(0).constructLops(), getInput().get(1).constructLops(), getInput().get(2).constructLops(), |
| getInput().get(3).constructLops(), getInput().get(4).constructLops(), getInput().get(5).constructLops(), |
| getDataType(), getValueType(), et); |
| |
| setOutputDimensions(left); |
| setLineNumbers(left); |
| setLops(left); |
| } |
| } |
| catch (Exception e) { |
| throw new HopsException(this.printErrorLocation() + "In LeftIndexingOp Hop, error in constructing Lops " , e); |
| } |
| |
| //add reblock/checkpoint lops if necessary |
| constructAndSetLopsDataFlowProperties(); |
| |
| return getLops(); |
| } |
| |
| /** |
| * @return true if the right hand side of the indexing operation is a |
| * literal. |
| */ |
| private boolean isRightHandSideScalar() { |
| Hop rightHandSide = getInput().get(1); |
| return (rightHandSide.getDataType() == DataType.SCALAR); |
| } |
| |
| private static LixCacheType getSpLixCacheType(LeftIndexingMethod method) { |
| switch( method ) { |
| case SP_MLEFTINDEX_L: return LixCacheType.LEFT; |
| case SP_MLEFTINDEX_R: return LixCacheType.RIGHT; |
| default: return LixCacheType.NONE; |
| } |
| } |
| |
| @Override |
| public String getOpString() { |
| String s = new String(""); |
| s += OPSTRING; |
| return s; |
| } |
| |
| @Override |
| public boolean allowsAllExecTypes() { |
| return false; |
| } |
| |
| @Override |
| public void computeMemEstimate( MemoTable memo ) |
| { |
| //overwrites default hops behavior |
| super.computeMemEstimate(memo); |
| |
| //changed final estimate (infer and use input size) |
| Hop rhM = getInput().get(1); |
| DataCharacteristics dcRhM = memo.getAllInputStats(rhM); |
| //TODO also use worstcase estimate for output |
| if( dimsKnown() && !(rhM.dimsKnown()||dcRhM.dimsKnown()) ) |
| { |
| // unless second input is single cell / row vector / column vector |
| // use worst-case memory estimate for second input (it cannot be larger than overall matrix) |
| double subSize = -1; |
| if( _rowLowerEqualsUpper && _colLowerEqualsUpper ) |
| subSize = OptimizerUtils.estimateSize(1, 1); |
| else if( _rowLowerEqualsUpper ) |
| subSize = OptimizerUtils.estimateSize(1, getDim2()); |
| else if( _colLowerEqualsUpper ) |
| subSize = OptimizerUtils.estimateSize(getDim1(), 1); |
| else |
| subSize = _outputMemEstimate; //worstcase |
| |
| _memEstimate = getInputSize(0) //original matrix (left) |
| + subSize // new submatrix (right) |
| + _outputMemEstimate; //output size (output) |
| } |
| else if ( dimsKnown() && getNnz()<0 && |
| _memEstimate>=OptimizerUtils.DEFAULT_SIZE) |
| { |
| //try a last attempt to infer a reasonable estimate wrt output sparsity |
| //(this is important for indexing sparse matrices into empty matrices). |
| DataCharacteristics dcM1 = memo.getAllInputStats(getInput().get(0)); |
| DataCharacteristics dcM2 = memo.getAllInputStats(getInput().get(1)); |
| if( dcM1.getNonZeros()>=0 && dcM2.getNonZeros()>=0 |
| && hasConstantIndexingRange() ) |
| { |
| long lnnz = dcM1.getNonZeros() + dcM2.getNonZeros(); |
| _outputMemEstimate = computeOutputMemEstimate(getDim1(), getDim2(), lnnz); |
| _memEstimate = getInputSize(0) //original matrix (left) |
| + getInputSize(1) // new submatrix (right) |
| + _outputMemEstimate; //output size (output) |
| } |
| } |
| } |
| |
| @Override |
| protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) |
| { |
| double sparsity = 1.0; |
| if( nnz < 0 ) //check for exactly known nnz |
| { |
| Hop input1 = getInput().get(0); |
| Hop input2 = getInput().get(1); |
| if( input1.dimsKnown() && hasConstantIndexingRange() ) { |
| sparsity = OptimizerUtils.getLeftIndexingSparsity( |
| input1.getDim1(), input1.getDim2(), input1.getNnz(), |
| input2.getDim1(), input2.getDim2(), input2.getNnz()); |
| } |
| } |
| else { |
| sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); |
| } |
| |
| // The dimensions of the left indexing output is same as that of the first input i.e., getInput().get(0) |
| return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); |
| } |
| |
| @Override |
| protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz ) |
| { |
| return 0; |
| } |
| |
| @Override |
| protected DataCharacteristics inferOutputCharacteristics( MemoTable memo ) |
| { |
| DataCharacteristics ret = null; |
| |
| Hop input1 = getInput().get(0); //original matrix |
| Hop input2 = getInput().get(1); //right matrix |
| DataCharacteristics dc1 = memo.getAllInputStats(input1); |
| DataCharacteristics dc2 = memo.getAllInputStats(input2); |
| |
| if( dc1.dimsKnown() ) { |
| double sparsity = OptimizerUtils.getLeftIndexingSparsity( |
| dc1.getRows(), dc1.getCols(), dc1.getNonZeros(), |
| dc2.getRows(), dc2.getCols(), dc2.getNonZeros()); |
| long lnnz = !hasConstantIndexingRange() ? -1 : |
| (long)(sparsity * dc1.getRows() * dc1.getCols()); |
| ret = new MatrixCharacteristics(dc1.getRows(), dc1.getCols(), -1, lnnz); |
| } |
| |
| return ret; |
| } |
| |
| |
| @Override |
| protected ExecType optFindExecType() { |
| |
| checkAndSetForcedPlatform(); |
| |
| if( _etypeForced != null ) |
| { |
| _etype = _etypeForced; |
| } |
| else |
| { |
| if ( OptimizerUtils.isMemoryBasedOptLevel() ) { |
| _etype = findExecTypeByMemEstimate(); |
| checkAndModifyRecompilationStatus(); |
| } |
| else if ( getInput().get(0).areDimsBelowThreshold() ) |
| { |
| _etype = ExecType.CP; |
| } |
| else |
| { |
| _etype = ExecType.SPARK; |
| } |
| |
| //check for valid CP dimensions and matrix size |
| checkAndSetInvalidCPDimsAndSize(); |
| } |
| |
| if( getInput().get(0).getDataType()==DataType.LIST ) |
| _etype = ExecType.CP; |
| |
| //mark for recompile (forever) |
| setRequiresRecompileIfNecessary(); |
| |
| return _etype; |
| } |
| |
| private static LeftIndexingMethod getOptMethodLeftIndexingMethod( |
| long m1_dim1, long m1_dim2, long m1_blen, long m1_nnz, |
| long m2_dim1, long m2_dim2, long m2_nnz, DataType rhsDt) |
| { |
| if(FORCED_LEFT_INDEXING != null) { |
| return FORCED_LEFT_INDEXING; |
| } |
| |
| // broadcast-based left indexing w/o shuffle for scalar rhs |
| if( rhsDt == DataType.SCALAR ) { |
| return LeftIndexingMethod.SP_MLEFTINDEX_R; |
| } |
| |
| // broadcast-based left indexing w/o shuffle for small left/right inputs |
| if( m2_dim1 >= 1 && m2_dim2 >= 1 && m2_dim1 >= 1 && m2_dim2 >= 1 ) { //lhs/rhs known |
| boolean isAligned = (rhsDt == DataType.MATRIX) && |
| ((m1_dim1 == m2_dim1 && m1_dim2 <= m1_blen) || (m1_dim2 == m2_dim2 && m1_dim1 <= m1_blen)); |
| boolean broadcastRhs = OptimizerUtils.checkSparkBroadcastMemoryBudget(m2_dim1, m2_dim2, m1_blen, m2_nnz); |
| double m1SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_dim1, m1_dim2, m1_blen, m1_nnz); |
| double m2SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_dim1, m2_dim2, m1_blen, m2_nnz); |
| |
| if( broadcastRhs ) { |
| if( isAligned && m1SizeP<m2SizeP ) //e.g., sparse-dense lix |
| return LeftIndexingMethod.SP_MLEFTINDEX_L; |
| else //all other cases, where rhs smaller than lhs |
| return LeftIndexingMethod.SP_MLEFTINDEX_R; |
| } |
| } |
| |
| // default general case |
| return LeftIndexingMethod.SP_GLEFTINDEX; |
| } |
| |
| |
| @Override |
| public void refreshSizeInformation() |
| { |
| Hop input1 = getInput().get(0); //original matrix |
| Hop input2 = getInput().get(1); //rhs matrix |
| |
| //refresh output dimensions based on original matrix |
| setDim1( input1.getDim1() ); |
| setDim2( input1.getDim2() ); |
| |
| //refresh output nnz if exactly known; otherwise later inference |
| //note: leveraging the nnz for estimating the output sparsity is |
| //only valid for constant index identifiers (e.g., after literal |
| //replacement during dynamic recompilation), otherwise this could |
| //lead to underestimation and hence OOMs in loops |
| if( input1.getNnz() == 0 && hasConstantIndexingRange() ) { |
| if( input2.getDataType()==DataType.SCALAR ) |
| setNnz(1); |
| else |
| setNnz(input2.getNnz()); |
| } |
| else |
| setNnz(-1); |
| } |
| |
| private boolean hasConstantIndexingRange() { |
| return (getInput().get(2) instanceof LiteralOp |
| && getInput().get(3) instanceof LiteralOp |
| && getInput().get(4) instanceof LiteralOp |
| && getInput().get(5) instanceof LiteralOp); |
| } |
| |
| private void checkAndModifyRecompilationStatus() |
| { |
| // disable recompile for LIX and second input matrix (under certain conditions) |
| // if worst-case estimate (2 * original matrix size) was enough to already send it to CP |
| |
| if( _etype == ExecType.CP ) |
| { |
| _requiresRecompile = false; |
| |
| Hop rInput = getInput().get(1); |
| if( (!rInput.dimsKnown()) && rInput instanceof DataOp ) |
| { |
| //disable recompile for this dataop (we cannot set requiresRecompile directly |
| //because we use a top-down traversal for creating lops, hence it would be overwritten) |
| |
| ((DataOp)rInput).disableRecompileRead(); |
| } |
| } |
| } |
| |
| @Override |
| public Object clone() throws CloneNotSupportedException |
| { |
| LeftIndexingOp ret = new LeftIndexingOp(); |
| |
| //copy generic attributes |
| ret.clone(this, false); |
| |
| //copy specific attributes |
| |
| return ret; |
| } |
| |
| @Override |
| public boolean compare( Hop that ) { |
| if( !(that instanceof LeftIndexingOp) |
| || getInput().size() != that.getInput().size() ) |
| { |
| return false; |
| } |
| |
| return getInput().get(0) == that.getInput().get(0) |
| && getInput().get(1) == that.getInput().get(1) |
| && getInput().get(2) == that.getInput().get(2) |
| && getInput().get(3) == that.getInput().get(3) |
| && getInput().get(4) == that.getInput().get(4) |
| && getInput().get(5) == that.getInput().get(5); |
| } |
| |
| } |