| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops; |
| |
| import org.apache.sysds.api.DMLScript; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.common.Types.OpOp1; |
| import org.apache.sysds.common.Types.OpOp2; |
| import org.apache.sysds.common.Types.ValueType; |
| import org.apache.sysds.hops.AggBinaryOp.SparkAggType; |
| import org.apache.sysds.hops.rewrite.HopRewriteUtils; |
| import org.apache.sysds.lops.Data; |
| import org.apache.sysds.lops.Lop; |
| import org.apache.sysds.lops.LopProperties.ExecType; |
| import org.apache.sysds.lops.RightIndex; |
| import org.apache.sysds.runtime.meta.DataCharacteristics; |
| import org.apache.sysds.runtime.meta.MatrixCharacteristics; |
| |
| //for now only works for range based indexing op |
| public class IndexingOp extends Hop |
| { |
| public static String OPSTRING = "rix"; //"Indexing"; |
| |
| private boolean _rowLowerEqualsUpper = false; |
| private boolean _colLowerEqualsUpper = false; |
| |
| private enum IndexingMethod { |
| CP_RIX, //in-memory range index |
| MR_RIX, //general case range reindex |
| MR_VRIX, //vector (row/col) range index |
| } |
| |
| |
| private IndexingOp() { |
| //default constructor for clone |
| } |
| |
| //right indexing doesn't really need the dimensionality of the left matrix |
| //private static Lops dummy=new Data(null, Data.OperationTypes.READ, null, "-1", DataType.SCALAR, ValueType.INT, false); |
| public IndexingOp(String l, DataType dt, ValueType vt, Hop inpMatrix, Hop inpRowL, Hop inpRowU, Hop inpColL, Hop inpColU, boolean passedRowsLEU, boolean passedColsLEU) { |
| super(l, dt, vt); |
| |
| getInput().add(0, inpMatrix); |
| getInput().add(1, inpRowL); |
| getInput().add(2, inpRowU); |
| getInput().add(3, inpColL); |
| getInput().add(4, inpColU); |
| |
| // create hops if one of them is null |
| inpMatrix.getParent().add(this); |
| inpRowL.getParent().add(this); |
| inpRowU.getParent().add(this); |
| inpColL.getParent().add(this); |
| inpColU.getParent().add(this); |
| |
| // set information whether left indexing operation involves row (n x 1) or column (1 x m) matrix |
| setRowLowerEqualsUpper(passedRowsLEU); |
| setColLowerEqualsUpper(passedColsLEU); |
| } |
| |
| @Override |
| public void checkArity() { |
| HopsException.check(_input.size() == 5, this, "should have 5 inputs but has %d inputs", _input.size()); |
| } |
| |
| public boolean isRowLowerEqualsUpper(){ |
| return _rowLowerEqualsUpper; |
| } |
| |
| public boolean isColLowerEqualsUpper() { |
| return _colLowerEqualsUpper; |
| } |
| |
| public void setRowLowerEqualsUpper(boolean passed){ |
| _rowLowerEqualsUpper = passed; |
| } |
| |
| public void setColLowerEqualsUpper(boolean passed) { |
| _colLowerEqualsUpper = passed; |
| } |
| |
| @Override |
| public boolean isGPUEnabled() { |
| if(!DMLScript.USE_ACCELERATOR) { |
| return false; |
| } |
| else { |
| // Indexing is only supported on GPU if: |
| // 1. the input is of type matrix AND |
| // 2. the input is less than 2GB. |
| // The second condition is added for following reason: |
| // 1. Indexing is a purely memory-bound operation and doesnot benefit drastically from pushing down to GPU. |
| // 2. By forcing larger matrices to GPU (for example: training dataset), we run into risk of unnecessary evictions of |
| // parameters and the gradients. For single precision, there is additional overhead of converting training dataset |
| // to single precision every single time it is evicted. |
| return (getDataType() == DataType.MATRIX) && getInputMemEstimate() < 2e+9; |
| } |
| } |
| |
| @Override |
| public Lop constructLops() |
| { |
| //return already created lops |
| if( getLops() != null ) |
| return getLops(); |
| |
| Hop input = getInput().get(0); |
| |
| //rewrite remove unnecessary right indexing |
| if( HopRewriteUtils.isUnnecessaryRightIndexing(this) ) { |
| setLops( input.constructLops() ); |
| } |
| //actual lop construction, incl operator selection |
| else |
| { |
| try { |
| ExecType et = optFindExecType(); |
| |
| if( et == ExecType.SPARK ) |
| { |
| IndexingMethod method = optFindIndexingMethod( _rowLowerEqualsUpper, _colLowerEqualsUpper, |
| input.getDim1(), input.getDim2(), getDim1(), getDim2()); |
| SparkAggType aggtype = (method==IndexingMethod.MR_VRIX || isBlockAligned()) ? |
| SparkAggType.NONE : SparkAggType.MULTI_BLOCK; |
| |
| Lop dummy = Data.createLiteralLop(ValueType.INT64, Integer.toString(-1)); |
| RightIndex reindex = new RightIndex( |
| input.constructLops(), getInput().get(1).constructLops(), getInput().get(2).constructLops(), |
| getInput().get(3).constructLops(), getInput().get(4).constructLops(), dummy, dummy, |
| getDataType(), getValueType(), aggtype, et); |
| |
| setOutputDimensions(reindex); |
| setLineNumbers(reindex); |
| setLops(reindex); |
| } |
| else //CP or GPU |
| { |
| Lop dummy = Data.createLiteralLop(ValueType.INT64, Integer.toString(-1)); |
| RightIndex reindex = new RightIndex( |
| input.constructLops(), getInput().get(1).constructLops(), getInput().get(2).constructLops(), |
| getInput().get(3).constructLops(), getInput().get(4).constructLops(), dummy, dummy, |
| getDataType(), getValueType(), et); |
| |
| setOutputDimensions(reindex); |
| setLineNumbers(reindex); |
| setLops(reindex); |
| } |
| } catch (Exception e) { |
| throw new HopsException(this.printErrorLocation() + "In IndexingOp Hop, error constructing Lops " , e); |
| } |
| } |
| |
| //add reblock/checkpoint lops if necessary |
| constructAndSetLopsDataFlowProperties(); |
| |
| return getLops(); |
| } |
| |
| @Override |
| public String getOpString() { |
| String s = new String(""); |
| s += OPSTRING; |
| return s; |
| } |
| |
| @Override |
| public boolean allowsAllExecTypes() |
| { |
| return true; |
| } |
| |
| @Override |
| public void computeMemEstimate( MemoTable memo ) |
| { |
| //default behavior |
| super.computeMemEstimate(memo); |
| |
| //try to infer via worstcase input statistics (for the case of dims known |
| //but nnz initially unknown) |
| DataCharacteristics dcM1 = memo.getAllInputStats(getInput().get(0)); |
| if( dimsKnown() && dcM1.getNonZeros()>=0 ){ |
| long lnnz = dcM1.getNonZeros(); //worst-case output nnz |
| double lOutMemEst = computeOutputMemEstimate(getDim1(), getDim2(), lnnz); |
| if( lOutMemEst<_outputMemEstimate ){ |
| _outputMemEstimate = lOutMemEst; |
| _memEstimate = getInputOutputSize(); |
| } |
| } |
| } |
| |
| @Override |
| protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) |
| { |
| // only dense right indexing supported on GPU |
| double sparsity = isGPUEnabled() ? 1.0 : OptimizerUtils.getSparsity(dim1, dim2, nnz); |
| return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); |
| } |
| |
| @Override |
| protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz ) |
| { |
| return 0; |
| } |
| |
| @Override |
| protected DataCharacteristics inferOutputCharacteristics( MemoTable memo ) |
| { |
| DataCharacteristics ret = null; |
| |
| Hop input = getInput().get(0); //original matrix |
| DataCharacteristics dc = memo.getAllInputStats(input); |
| if( dc != null ) |
| { |
| long lnnz = dc.dimsKnown()?Math.min(dc.getRows()*dc.getCols(), dc.getNonZeros()):-1; |
| //worst-case is input size, but dense |
| ret = new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, lnnz); |
| |
| //exploit column/row indexing information |
| if( _rowLowerEqualsUpper ) ret.setRows(1); |
| if( _colLowerEqualsUpper ) ret.setCols(1); |
| |
| //infer tight block indexing size |
| Hop rl = getInput().get(1); |
| Hop ru = getInput().get(2); |
| Hop cl = getInput().get(3); |
| Hop cu = getInput().get(4); |
| if( isBlockIndexingExpression(rl, ru) ) |
| ret.setRows(getBlockIndexingExpressionSize(rl, ru)); |
| if( isBlockIndexingExpression(cl, cu) ) |
| ret.setCols(getBlockIndexingExpressionSize(cl, cu)); |
| } |
| |
| return ret; |
| } |
| |
| /** |
| * Indicates if the lbound:rbound expressions is of the form |
| * "(c * (i - 1) + 1) : (c * i)", where we could use c as a tight size estimate. |
| * |
| * @param lbound lower bound high-level operator |
| * @param ubound uppser bound high-level operator |
| * @return true if block indexing expression |
| */ |
| private static boolean isBlockIndexingExpression(Hop lbound, Hop ubound) |
| { |
| boolean ret = false; |
| LiteralOp constant = null; |
| DataOp var = null; |
| |
| //handle lower bound |
| if( lbound instanceof BinaryOp && ((BinaryOp)lbound).getOp()==OpOp2.PLUS |
| && lbound.getInput().get(1) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValueSafe((LiteralOp)lbound.getInput().get(1))==1 |
| && lbound.getInput().get(0) instanceof BinaryOp) |
| { |
| BinaryOp lmult = (BinaryOp)lbound.getInput().get(0); |
| if( lmult.getOp()==OpOp2.MULT && lmult.getInput().get(0) instanceof LiteralOp |
| && lmult.getInput().get(1) instanceof BinaryOp ) |
| { |
| BinaryOp lminus = (BinaryOp)lmult.getInput().get(1); |
| if( lminus.getOp()==OpOp2.MINUS && lminus.getInput().get(1) instanceof LiteralOp |
| && HopRewriteUtils.getDoubleValueSafe((LiteralOp)lminus.getInput().get(1))==1 |
| && lminus.getInput().get(0) instanceof DataOp ) |
| { |
| constant = (LiteralOp)lmult.getInput().get(0); |
| var = (DataOp) lminus.getInput().get(0); |
| } |
| } |
| } |
| |
| //handle upper bound |
| if( var != null && constant != null && ubound instanceof BinaryOp |
| && ubound.getInput().get(0) instanceof LiteralOp |
| && ubound.getInput().get(1) instanceof DataOp |
| && ubound.getInput().get(1).getName().equals(var.getName()) ) |
| { |
| LiteralOp constant2 = (LiteralOp)ubound.getInput().get(0); |
| ret = ( HopRewriteUtils.getDoubleValueSafe(constant) == |
| HopRewriteUtils.getDoubleValueSafe(constant2) ); |
| } |
| |
| return ret; |
| } |
| |
| /** |
| * Indicates if the right indexing ranging is block aligned, i.e., it does not require |
| * aggregation across blocks due to shifting. |
| * |
| * @return true if block aligned |
| */ |
| private boolean isBlockAligned() { |
| Hop input1 = getInput().get(0); //original matrix |
| Hop input2 = getInput().get(1); //inpRowL |
| Hop input3 = getInput().get(2); //inpRowU |
| Hop input4 = getInput().get(3); //inpColL |
| Hop input5 = getInput().get(4); //inpRowU |
| |
| long rl = (input2 instanceof LiteralOp) ? (HopRewriteUtils.getIntValueSafe((LiteralOp)input2)) : -1; |
| long ru = (input3 instanceof LiteralOp) ? (HopRewriteUtils.getIntValueSafe((LiteralOp)input3)) : -1; |
| long cl = (input4 instanceof LiteralOp) ? (HopRewriteUtils.getIntValueSafe((LiteralOp)input4)) : -1; |
| long cu = (input5 instanceof LiteralOp) ? (HopRewriteUtils.getIntValueSafe((LiteralOp)input5)) : -1; |
| int blen = input1.getBlocksize(); |
| |
| return OptimizerUtils.isIndexingRangeBlockAligned(rl, ru, cl, cu, blen); |
| } |
| |
| private static long getBlockIndexingExpressionSize(Hop lbound, Hop ubound) { |
| //NOTE: ensure consistency with isBlockIndexingExpression |
| LiteralOp c = (LiteralOp) ubound.getInput().get(0); //(c*i) |
| return HopRewriteUtils.getIntValueSafe(c); |
| } |
| |
| @Override |
| protected ExecType optFindExecType() { |
| |
| checkAndSetForcedPlatform(); |
| |
| if( _etypeForced != null ) |
| { |
| _etype = _etypeForced; |
| } |
| else |
| { |
| if ( OptimizerUtils.isMemoryBasedOptLevel() ) { |
| _etype = findExecTypeByMemEstimate(); |
| } |
| else if ( getInput().get(0).areDimsBelowThreshold() ) |
| { |
| _etype = ExecType.CP; |
| } |
| else |
| { |
| _etype = ExecType.SPARK; |
| } |
| |
| //check for valid CP dimensions and matrix size |
| checkAndSetInvalidCPDimsAndSize(); |
| } |
| |
| if( getInput().get(0).getDataType()==DataType.LIST ) |
| _etype = ExecType.CP; |
| |
| //mark for recompile (forever) |
| setRequiresRecompileIfNecessary(); |
| |
| return _etype; |
| } |
| |
| private static IndexingMethod optFindIndexingMethod( boolean singleRow, boolean singleCol, long m1_dim1, long m1_dim2, long m2_dim1, long m2_dim2 ) |
| { |
| if( singleRow && m1_dim2 == m2_dim2 && m2_dim2!=-1 |
| || singleCol && m1_dim1 == m2_dim1 && m2_dim1!=-1 ) |
| { |
| return IndexingMethod.MR_VRIX; |
| } |
| |
| return IndexingMethod.MR_RIX; //general case |
| } |
| |
| @Override |
| public void refreshSizeInformation() |
| { |
| Hop input1 = getInput().get(0); //matrix |
| Hop input2 = getInput().get(1); //inpRowL |
| Hop input3 = getInput().get(2); //inpRowU |
| Hop input4 = getInput().get(3); //inpColL |
| Hop input5 = getInput().get(4); //inpColU |
| |
| //update single row/column flags (depends on CSE) |
| _rowLowerEqualsUpper = (input2 == input3); |
| _colLowerEqualsUpper = (input4 == input5); |
| |
| //parse input information |
| boolean allRows = isAllRows(); |
| boolean allCols = isAllCols(); |
| boolean constRowRange = (input2 instanceof LiteralOp && input3 instanceof LiteralOp); |
| boolean constColRange = (input4 instanceof LiteralOp && input5 instanceof LiteralOp); |
| |
| //set dimension information |
| if( _rowLowerEqualsUpper ) //ROWS |
| setDim1(1); |
| else if( allRows ) { |
| setDim1(input1.getDim1()); |
| } |
| else if( constRowRange ) { |
| setDim1( HopRewriteUtils.getIntValueSafe((LiteralOp)input3) |
| -HopRewriteUtils.getIntValueSafe((LiteralOp)input2)+1 ); |
| } |
| else if( isBlockIndexingExpression(input2, input3) ) { |
| setDim1(getBlockIndexingExpressionSize(input2, input3)); |
| } |
| else { |
| //for reset (e.g., on reconcile after loops) |
| setDim1(-1); |
| } |
| |
| if( _colLowerEqualsUpper ) //COLS |
| setDim2(1); |
| else if( allCols ) { |
| setDim2(input1.getDim2()); |
| } |
| else if( constColRange ) { |
| setDim2( HopRewriteUtils.getIntValueSafe((LiteralOp)input5) |
| -HopRewriteUtils.getIntValueSafe((LiteralOp)input4)+1 ); |
| } |
| else if( isBlockIndexingExpression(input4, input5) ) { |
| setDim2(getBlockIndexingExpressionSize(input4, input5)); |
| } |
| else { |
| //for reset (e.g., on reconcile after loops) |
| setDim2(-1); |
| } |
| } |
| |
| public boolean isAllRows() { |
| Hop input1 = getInput().get(0); |
| Hop input2 = getInput().get(1); |
| Hop input3 = getInput().get(2); |
| return HopRewriteUtils.isLiteralOfValue(input2, 1) |
| && ((HopRewriteUtils.isUnary(input3, OpOp1.NROW) && input3.getInput().get(0) == input1 ) |
| || HopRewriteUtils.isLiteralOfValue(input3, input1.getDim1())); |
| } |
| |
| public boolean isAllCols() { |
| Hop input1 = getInput().get(0); |
| Hop input4 = getInput().get(3); |
| Hop input5 = getInput().get(4); |
| return HopRewriteUtils.isLiteralOfValue(input4, 1) |
| && ((HopRewriteUtils.isUnary(input5, OpOp1.NCOL) && input5.getInput().get(0) == input1 ) |
| || HopRewriteUtils.isLiteralOfValue(input5, input1.getDim2())); |
| } |
| |
| @Override |
| public Object clone() throws CloneNotSupportedException { |
| IndexingOp ret = new IndexingOp(); |
| //copy generic attributes |
| ret.clone(this, false); |
| //copy specific attributes |
| return ret; |
| } |
| |
| @Override |
| public boolean compare( Hop that ) |
| { |
| if( !(that instanceof IndexingOp) |
| || getInput().size() != that.getInput().size() ) |
| { |
| return false; |
| } |
| |
| return getInput().get(0) == that.getInput().get(0) |
| && getInput().get(1) == that.getInput().get(1) |
| && getInput().get(2) == that.getInput().get(2) |
| && getInput().get(3) == that.getInput().get(3) |
| && getInput().get(4) == that.getInput().get(4); |
| } |
| } |