| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysml.hops.codegen; |
| |
| import java.util.ArrayList; |
| |
| import org.apache.sysml.hops.Hop; |
| import org.apache.sysml.hops.Hop.MultiThreadedHop; |
| import org.apache.sysml.hops.HopsException; |
| import org.apache.sysml.hops.MemoTable; |
| import org.apache.sysml.hops.OptimizerUtils; |
| import org.apache.sysml.lops.Lop; |
| import org.apache.sysml.lops.LopProperties.ExecType; |
| import org.apache.sysml.lops.LopsException; |
| import org.apache.sysml.lops.SpoofFused; |
| import org.apache.sysml.parser.Expression.DataType; |
| import org.apache.sysml.parser.Expression.ValueType; |
| import org.apache.sysml.runtime.codegen.SpoofRowwise; |
| import org.apache.sysml.runtime.matrix.MatrixCharacteristics; |
| |
| public class SpoofFusedOp extends Hop implements MultiThreadedHop |
| { |
| public enum SpoofOutputDimsType { |
| INPUT_DIMS, |
| INPUT_DIMS_CONST2, |
| ROW_DIMS, |
| COLUMN_DIMS_ROWS, |
| COLUMN_DIMS_COLS, |
| RANK_DIMS_COLS, |
| SCALAR, |
| MULTI_SCALAR, |
| ROW_RANK_DIMS, // right wdivmm, row mm |
| COLUMN_RANK_DIMS, // left wdivmm, row mm |
| COLUMN_RANK_DIMS_T, |
| VECT_CONST2; |
| } |
| |
| private Class<?> _class = null; |
| private boolean _distSupported = false; |
| private int _numThreads = -1; |
| private long _constDim2 = -1; |
| private SpoofOutputDimsType _dimsType; |
| |
| public SpoofFusedOp ( ) { |
| |
| } |
| |
| public SpoofFusedOp( String name, DataType dt, ValueType vt, Class<?> cla, boolean dist, SpoofOutputDimsType type ) { |
| super(name, dt, vt); |
| _class = cla; |
| _distSupported = dist; |
| _dimsType = type; |
| } |
| |
| @Override |
| public void checkArity() throws HopsException {} |
| |
| @Override |
| public void setMaxNumThreads(int k) { |
| _numThreads = k; |
| } |
| |
| @Override |
| public int getMaxNumThreads() { |
| return _numThreads; |
| } |
| |
| @Override |
| public boolean allowsAllExecTypes() { |
| return _distSupported; |
| } |
| |
| public void setConstDim2(long constDim2) { |
| _constDim2 = constDim2; |
| } |
| |
| @Override |
| protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { |
| return _class.getGenericSuperclass().equals(SpoofRowwise.class) ? |
| OptimizerUtils.estimateSize(dim1, dim2) : |
| OptimizerUtils.estimatePartitionedSizeExactSparsity( |
| dim1, dim2, getRowsInBlock(), getColsInBlock(), nnz); |
| } |
| |
| @Override |
| protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { |
| return 0; |
| } |
| |
| @Override |
| public Lop constructLops() throws HopsException, LopsException { |
| if( getLops() != null ) |
| return getLops(); |
| |
| ExecType et = optFindExecType(); |
| |
| ArrayList<Lop> inputs = new ArrayList<>(); |
| for( Hop c : getInput() ) |
| inputs.add(c.constructLops()); |
| |
| int k = OptimizerUtils.getConstrainedNumThreads(_numThreads); |
| SpoofFused lop = new SpoofFused(inputs, getDataType(), getValueType(), _class, k, et); |
| setOutputDimensions(lop); |
| setLineNumbers(lop); |
| setLops(lop); |
| |
| return lop; |
| } |
| |
| @Override |
| protected ExecType optFindExecType() throws HopsException { |
| |
| checkAndSetForcedPlatform(); |
| |
| if( _etypeForced != null ) { |
| _etype = _etypeForced; |
| } |
| else { |
| _etype = findExecTypeByMemEstimate(); |
| checkAndSetInvalidCPDimsAndSize(); |
| } |
| |
| //ensure valid execution plans |
| if( _etype == ExecType.MR ) |
| _etype = ExecType.CP; |
| |
| return _etype; |
| } |
| |
| @Override |
| public String getOpString() { |
| return "spoof("+_class.getSimpleName()+")"; |
| } |
| |
| @Override |
| protected long[] inferOutputCharacteristics( MemoTable memo ) |
| { |
| long[] ret = null; |
| |
| //get statistics of main input |
| MatrixCharacteristics mc = memo.getAllInputStats(getInput().get(0)); |
| |
| if( mc.dimsKnown() ) { |
| switch(_dimsType) |
| { |
| case ROW_DIMS: |
| ret = new long[]{mc.getRows(), 1, -1}; |
| break; |
| case COLUMN_DIMS_ROWS: |
| ret = new long[]{mc.getCols(), 1, -1}; |
| break; |
| case COLUMN_DIMS_COLS: |
| ret = new long[]{1, mc.getCols(), -1}; |
| break; |
| case RANK_DIMS_COLS: { |
| MatrixCharacteristics mc2 = memo.getAllInputStats(getInput().get(1)); |
| if( mc2.dimsKnown() ) |
| ret = new long[]{1, mc2.getCols(), -1}; |
| break; |
| } |
| case INPUT_DIMS: |
| ret = new long[]{mc.getRows(), mc.getCols(), -1}; |
| break; |
| case INPUT_DIMS_CONST2: |
| ret = new long[]{mc.getRows(), _constDim2, -1}; |
| break; |
| case VECT_CONST2: |
| ret = new long[]{1, _constDim2, -1}; |
| break; |
| case SCALAR: |
| ret = new long[]{0, 0, -1}; |
| break; |
| case MULTI_SCALAR: |
| //dim2 statically set from outside |
| ret = new long[]{1, _dim2, -1}; |
| break; |
| case ROW_RANK_DIMS: { |
| MatrixCharacteristics mc2 = memo.getAllInputStats(getInput().get(1)); |
| if( mc2.dimsKnown() ) |
| ret = new long[]{mc.getRows(), mc2.getCols(), -1}; |
| break; |
| } |
| case COLUMN_RANK_DIMS: { |
| MatrixCharacteristics mc2 = memo.getAllInputStats(getInput().get(1)); |
| if( mc2.dimsKnown() ) |
| ret = new long[]{mc.getCols(), mc2.getCols(), -1}; |
| break; |
| } |
| case COLUMN_RANK_DIMS_T: { |
| MatrixCharacteristics mc2 = memo.getAllInputStats(getInput().get(1)); |
| if( mc2.dimsKnown() ) |
| ret = new long[]{mc2.getCols(), mc.getCols(), -1}; |
| break; |
| } |
| default: |
| throw new RuntimeException("Failed to infer worst-case size information " |
| + "for type: "+_dimsType.toString()); |
| } |
| } |
| |
| return ret; |
| } |
| |
| @Override |
| public void refreshSizeInformation() { |
| switch(_dimsType) |
| { |
| case ROW_DIMS: |
| setDim1(getInput().get(0).getDim1()); |
| setDim2(1); |
| break; |
| case COLUMN_DIMS_ROWS: |
| setDim1(getInput().get(0).getDim2()); |
| setDim2(1); |
| break; |
| case COLUMN_DIMS_COLS: |
| setDim1(1); |
| setDim2(getInput().get(0).getDim2()); |
| break; |
| case RANK_DIMS_COLS: |
| setDim1(1); |
| setDim2(getInput().get(1).getDim2()); |
| break; |
| case INPUT_DIMS: |
| setDim1(getInput().get(0).getDim1()); |
| setDim2(getInput().get(0).getDim2()); |
| break; |
| case INPUT_DIMS_CONST2: |
| setDim1(getInput().get(0).getDim1()); |
| setDim2(_constDim2); |
| break; |
| case VECT_CONST2: |
| setDim1(1); |
| setDim2(_constDim2); |
| break; |
| case SCALAR: |
| setDim1(0); |
| setDim2(0); |
| break; |
| case MULTI_SCALAR: |
| setDim1(1); //row vector |
| //dim2 statically set from outside |
| break; |
| case ROW_RANK_DIMS: |
| setDim1(getInput().get(0).getDim1()); |
| setDim2(getInput().get(1).getDim2()); |
| break; |
| case COLUMN_RANK_DIMS: |
| setDim1(getInput().get(0).getDim2()); |
| setDim2(getInput().get(1).getDim2()); |
| break; |
| case COLUMN_RANK_DIMS_T: |
| setDim1(getInput().get(1).getDim2()); |
| setDim2(getInput().get(0).getDim2()); |
| break; |
| default: |
| throw new RuntimeException("Failed to refresh size information " |
| + "for type: "+_dimsType.toString()); |
| } |
| } |
| |
| @Override |
| public Object clone() throws CloneNotSupportedException |
| { |
| SpoofFusedOp ret = new SpoofFusedOp(); |
| |
| //copy generic attributes |
| ret.clone(this, false); |
| |
| //copy specific attributes |
| ret._class = _class; |
| ret._distSupported = _distSupported; |
| ret._numThreads = _numThreads; |
| ret._dimsType = _dimsType; |
| return ret; |
| } |
| |
| @Override |
| public boolean compare( Hop that ) |
| { |
| if( !(that instanceof SpoofFusedOp) ) |
| return false; |
| |
| SpoofFusedOp that2 = (SpoofFusedOp)that; |
| boolean ret = ( _class.equals(that2._class) |
| && _distSupported == that2._distSupported |
| && _numThreads == that2._numThreads |
| && getInput().size() == that2.getInput().size()); |
| |
| if( ret ) { |
| for( int i=0; i<getInput().size(); i++ ) |
| ret &= (getInput().get(i) == that2.getInput().get(i)); |
| } |
| |
| return ret; |
| } |
| |
| @Override |
| public boolean isGPUEnabled() { |
| return false; |
| } |
| } |