blob: 3aca219c3dcfcccdfb9a76cc562350ec976932d4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops.codegen;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.MemoTable;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import java.util.ArrayList;
public class SpoofFusedOp extends MultiThreadedHop
{
public enum SpoofOutputDimsType {
INPUT_DIMS,
INPUT_DIMS_CONST2,
ROW_DIMS,
COLUMN_DIMS_ROWS,
COLUMN_DIMS_COLS,
RANK_DIMS_COLS,
SCALAR,
MULTI_SCALAR,
ROW_RANK_DIMS, // right wdivmm, row mm
COLUMN_RANK_DIMS, // left wdivmm, row mm
COLUMN_RANK_DIMS_T,
VECT_CONST2;
}
private Class<?> _class = null;
private boolean _distSupported = false;
private long _constDim2 = -1;
private SpoofOutputDimsType _dimsType;
public SpoofFusedOp ( ) {
}
public SpoofFusedOp( String name, DataType dt, ValueType vt, Class<?> cla, boolean dist, SpoofOutputDimsType type ) {
super(name, dt, vt);
_class = cla;
_distSupported = dist;
_dimsType = type;
}
@Override
public void checkArity() {}
@Override
public boolean allowsAllExecTypes() {
return _distSupported;
}
public void setConstDim2(long constDim2) {
_constDim2 = constDim2;
}
@Override
public boolean isGPUEnabled() {
return false;
}
@Override
public boolean isMultiThreadedOpType() {
return true;
}
@Override
protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
return _class.getGenericSuperclass().equals(SpoofRowwise.class) ?
OptimizerUtils.estimateSize(dim1, dim2) :
OptimizerUtils.estimatePartitionedSizeExactSparsity(
dim1, dim2, getBlocksize(), nnz);
}
@Override
protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
return 0;
}
@Override
public Lop constructLops() {
if( getLops() != null )
return getLops();
ExecType et = optFindExecType();
ArrayList<Lop> inputs = new ArrayList<>();
for( Hop c : getInput() )
inputs.add(c.constructLops());
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
SpoofFused lop = new SpoofFused(inputs, getDataType(), getValueType(), _class, k, et);
setOutputDimensions(lop);
setLineNumbers(lop);
setLops(lop);
return lop;
}
@Override
protected ExecType optFindExecType() {
checkAndSetForcedPlatform();
if( _etypeForced != null ) {
_etype = _etypeForced;
}
else {
_etype = findExecTypeByMemEstimate();
checkAndSetInvalidCPDimsAndSize();
}
return _etype;
}
@Override
public String getOpString() {
return "spoof("+_class.getSimpleName()+")";
}
public String getClassName() {
return _class.getName();
}
@Override
protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
{
//get statistics of main input
DataCharacteristics mc = memo.getAllInputStats(getInput().get(0));
DataCharacteristics ret = null;
if( mc.dimsKnown() ) {
switch(_dimsType)
{
case ROW_DIMS:
ret = new MatrixCharacteristics(mc.getRows(), 1, -1, -1);
break;
case COLUMN_DIMS_ROWS:
ret = new MatrixCharacteristics(mc.getCols(), 1, -1, -1);
break;
case COLUMN_DIMS_COLS:
ret = new MatrixCharacteristics(1, mc.getCols(), -1, -1);
break;
case RANK_DIMS_COLS: {
DataCharacteristics dc2 = memo.getAllInputStats(getInput().get(1));
if( dc2.dimsKnown() )
ret = new MatrixCharacteristics(1, dc2.getCols(), -1, -1);
break;
}
case INPUT_DIMS:
ret = new MatrixCharacteristics(mc.getRows(), mc.getCols(), -1, -1);
break;
case INPUT_DIMS_CONST2:
ret = new MatrixCharacteristics(mc.getRows(), _constDim2, -1, -1);
break;
case VECT_CONST2:
ret = new MatrixCharacteristics(1, _constDim2, -1, -1);
break;
case SCALAR:
ret = new MatrixCharacteristics(0, 0, -1, -1);
break;
case MULTI_SCALAR:
//dim2 statically set from outside
ret = new MatrixCharacteristics(1, _dc.getCols(), -1, -1);
break;
case ROW_RANK_DIMS: {
DataCharacteristics dc2 = memo.getAllInputStats(getInput().get(1));
if( dc2.dimsKnown() )
ret = new MatrixCharacteristics(mc.getRows(), dc2.getCols(), -1, -1);
break;
}
case COLUMN_RANK_DIMS: {
DataCharacteristics dc2 = memo.getAllInputStats(getInput().get(1));
if( dc2.dimsKnown() )
ret = new MatrixCharacteristics(mc.getCols(), dc2.getCols(), -1, -1);
break;
}
case COLUMN_RANK_DIMS_T: {
DataCharacteristics dc2 = memo.getAllInputStats(getInput().get(1));
if( dc2.dimsKnown() )
ret = new MatrixCharacteristics(dc2.getCols(), mc.getCols(), -1, -1);
break;
}
default:
throw new RuntimeException("Failed to infer worst-case size information "
+ "for type: "+_dimsType.toString());
}
}
return ret;
}
@Override
public void refreshSizeInformation() {
switch(_dimsType)
{
case ROW_DIMS:
setDim1(getInput().get(0).getDim1());
setDim2(1);
break;
case COLUMN_DIMS_ROWS:
setDim1(getInput().get(0).getDim2());
setDim2(1);
break;
case COLUMN_DIMS_COLS:
setDim1(1);
setDim2(getInput().get(0).getDim2());
break;
case RANK_DIMS_COLS:
setDim1(1);
setDim2(getInput().get(1).getDim2());
break;
case INPUT_DIMS:
setDim1(getInput().get(0).getDim1());
setDim2(getInput().get(0).getDim2());
break;
case INPUT_DIMS_CONST2:
setDim1(getInput().get(0).getDim1());
setDim2(_constDim2);
break;
case VECT_CONST2:
setDim1(1);
setDim2(_constDim2);
break;
case SCALAR:
setDim1(0);
setDim2(0);
break;
case MULTI_SCALAR:
setDim1(1); //row vector
//dim2 statically set from outside
break;
case ROW_RANK_DIMS:
setDim1(getInput().get(0).getDim1());
setDim2(getInput().get(1).getDim2());
break;
case COLUMN_RANK_DIMS:
setDim1(getInput().get(0).getDim2());
setDim2(getInput().get(1).getDim2());
break;
case COLUMN_RANK_DIMS_T:
setDim1(getInput().get(1).getDim2());
setDim2(getInput().get(0).getDim2());
break;
default:
throw new RuntimeException("Failed to refresh size information "
+ "for type: "+_dimsType.toString());
}
}
@Override
public Object clone() throws CloneNotSupportedException
{
SpoofFusedOp ret = new SpoofFusedOp();
//copy generic attributes
ret.clone(this, false);
//copy specific attributes
ret._class = _class;
ret._distSupported = _distSupported;
ret._maxNumThreads = _maxNumThreads;
ret._constDim2 = _constDim2;
ret._dimsType = _dimsType;
return ret;
}
@Override
public boolean compare( Hop that )
{
if( !(that instanceof SpoofFusedOp) )
return false;
SpoofFusedOp that2 = (SpoofFusedOp)that;
//note: class implies dims type as well
boolean ret = ( _class.equals(that2._class)
&& _distSupported == that2._distSupported
&& _maxNumThreads == that2._maxNumThreads
&& _constDim2 == that2._constDim2
&& getInput().size() == that2.getInput().size());
if( ret ) {
for( int i=0; i<getInput().size(); i++ )
ret &= (getInput().get(i) == that2.getInput().get(i));
}
return ret;
}
}