/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sysml.hops;

import java.util.ArrayList;

import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.Hop.MultiThreadedHop;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.ConvolutionTransform;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.ConvolutionParameters;
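
/**
 * Hop for convolution and pooling operations (im2col, reshape_col, rotate180,
 * col2im, max pooling and its backward pass, direct conv2d and its backward
 * passes w.r.t. data and filter), compiled into a ConvolutionTransform lop.
 * Besides the matrix input(s), every operation carries twelve scalar inputs
 * describing stride, padding, input_shape, and filter_shape; see
 * constructConvolutionLops for the exact layout.
 *
 * Illustrative DML usage (a sketch; names and values are examples only):
 *
 *   out = conv2d(X, W, stride=[1,1], padding=[0,0],
 *                input_shape=[N,C,H,W], filter_shape=[K,C,R,S])
 */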
public class ConvolutionOp extends Hop implements MultiThreadedHop
{
	private Hop.ConvOp op;
	private int _maxNumThreads = -1; //-1 for unlimited

	private ConvolutionOp() {
		//default constructor for clone
	}

	public ConvolutionOp(String l, DataType dt, ValueType vt, ConvOp o, Hop inp)
	{
		super(l, dt, vt);
		op = o;
		getInput().add(0, inp);
		inp.getParent().add(this);

		//compute unknown dims and nnz
		refreshSizeInformation();
	}
	public ConvolutionOp(String l, DataType dt, ValueType vt, ConvOp o, ArrayList<Hop> inp)
	{
		super(l, dt, vt);
		op = o;

		for( int i=0; i<inp.size(); i++ ) {
			Hop in = inp.get(i);
			getInput().add(i, in);
			in.getParent().add(this);
		}

		//compute unknown dims and nnz
		refreshSizeInformation();
	}

	public ConvOp getOp() {
		return op;
	}

	@Override
	public String getOpString() {
		return String.valueOf(HopsConv2Lops.get(op));
	}
	@Override
	public Lop constructLops()
		throws HopsException, LopsException
	{
		//return already created lops
		if( getLops() != null )
			return getLops();

		ExecType et = optFindExecType();
		ArrayList<Hop> inputs = getInput();
		switch( op )
		{
			case IM2COL:
			case RESHAPE_COL:
			case ROTATE180:
			case COL2IM:
			case MAX_POOLING:
			case MAX_POOLING_BACKWARD:
			case DIRECT_CONV2D:
			case DIRECT_CONV2D_BACKWARD_DATA:
			case DIRECT_CONV2D_BACKWARD_FILTER:
			{
				if( et == ExecType.CP ) {
					setLops(constructConvolutionLops(et, inputs));
					break;
				}
				else {
					// TODO: add support for the SPARK/MR backends once the performance
					// of the single-node LeNet script is satisfactory.
					throw new HopsException("Unimplemented ConvolutionOp for execution type: " + et.name());
				}
			}
			default:
				throw new HopsException("Unsupported lops construction for operation type '"+op+"'.");
		}

		//add reblock/checkpoint lops if necessary
		constructAndSetLopsDataFlowProperties();

		return getLops();
	}
	public void setOp(ConvOp op) {
		this.op = op;
	}
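
	/**
	 * Constructs the ConvolutionTransform lop. Input 0 is the data matrix;
	 * for the conv2d and backward variants, input 1 is a second matrix
	 * (e.g., the filter or the output gradient), followed by the twelve
	 * scalar shape inputs listed below, for a total of 13 or 14 inputs.
	 */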
	public Lop constructConvolutionLops(ExecType et, ArrayList<Hop> inputs) throws HopsException, LopsException {
		int expectedNumInputs = 13;
		if(op == ConvOp.MAX_POOLING_BACKWARD
			|| op == ConvOp.DIRECT_CONV2D
			|| op == ConvOp.DIRECT_CONV2D_BACKWARD_FILTER
			|| op == ConvOp.DIRECT_CONV2D_BACKWARD_DATA) {
			expectedNumInputs = 14;
		}
		if(inputs.size() != expectedNumInputs) {
			throw new HopsException("Incorrect number of inputs for " + op.name());
		}

		Lop in = inputs.get(0).constructLops();
		int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
		ConvolutionTransform transform1 = new ConvolutionTransform( in,
			HopsConv2Lops.get(op), getDataType(), getValueType(), et, k);
		setOutputDimensions(transform1);
		setLineNumbers(transform1);

		// stride1, stride2, padding1, padding2
		// input_shape1, input_shape2, input_shape3, input_shape4,
		// filter_shape1, filter_shape2, filter_shape3, filter_shape4
		for( int i=1; i < expectedNumInputs; i++ )
		{
			Lop ltmp = inputs.get(i).constructLops();
			transform1.addInput(ltmp);
			ltmp.addOutput(transform1);
		}
		transform1.setLevel(); //force order of added lops
		return transform1;
	}
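
	/**
	 * Output memory estimate: RESHAPE_COL and ROTATE180 merely rearrange
	 * cells, so the input sparsity carries over exactly; all other ops
	 * conservatively assume a dense output.
	 */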
	@Override
	protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
	{
		double sparsity = 1.0;
		switch(op)
		{
			case RESHAPE_COL:
			case ROTATE180:
			{
				sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
				break;
			}
			case IM2COL:
			case COL2IM:
			case MAX_POOLING:
			case MAX_POOLING_BACKWARD:
			case DIRECT_CONV2D:
			case DIRECT_CONV2D_BACKWARD_FILTER:
			case DIRECT_CONV2D_BACKWARD_DATA:
				sparsity = 1.0; // worst-case estimate
				break;
		}
		return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
	}
	@Override
	protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz )
	{
		//default: no intermediate memory requirements
		return 0;
	}
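
	/**
	 * Infers [numRows, numCols, nnz] of the output. Only RESHAPE_COL and
	 * ROTATE180 are handled here; for all other ops this returns null,
	 * i.e., the output characteristics remain unknown.
	 */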
	@Override
	protected long[] inferOutputCharacteristics( MemoTable memo )
	{
		// [numRows, numCols, NNZ]
		long[] ret = null;

		Hop input1 = getInput().get(0);
		MatrixCharacteristics mc = memo.getAllInputStats(input1);

		ConvolutionParameters params;
		try {
			params = parseInput();
		} catch (DMLRuntimeException e) {
			throw new RuntimeException(e);
		}

		switch(op)
		{
			case RESHAPE_COL:
			{
				ret = new long[3];
				ret[0] = params.N;
				ret[1] = getExtractedVal(params.K, params.P, params.Q);
				ret[2] = mc.getNonZeros(); // exact estimates
				break;
			}
			case ROTATE180:
			{
				ret = new long[3];
				ret[0] = getExtractedVal(params.N, params.P, params.Q);
				ret[1] = params.K;
				ret[2] = mc.getNonZeros(); // exact estimates
				break;
			}
			case IM2COL:
			case COL2IM:
			case MAX_POOLING:
			case MAX_POOLING_BACKWARD:
			case DIRECT_CONV2D:
			case DIRECT_CONV2D_BACKWARD_FILTER:
			case DIRECT_CONV2D_BACKWARD_DATA:
				break;
		}

		return ret;
	}

	@Override
	public boolean allowsAllExecTypes() {
		return true;
	}
	@Override
	protected ExecType optFindExecType() throws HopsException {
		checkAndSetForcedPlatform();

		ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;

		if( _etypeForced != null ) {
			_etype = _etypeForced;
		}
		else
		{
			// TODO: revisit this selection once the SPARK backend is added;
			// CP is currently forced below regardless of the outcome.
			if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
				_etype = findExecTypeByMemEstimate();
			}
			// choose CP if the input dimensions are below threshold or if the input is a vector
			else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector() ) {
				_etype = ExecType.CP;
			}
			else {
				_etype = REMOTE;
			}

			//check for valid CP dimensions and matrix size
			checkAndSetInvalidCPDimsAndSize();
		}

		//mark for recompile (forever)
		if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && _etype==REMOTE )
			setRequiresRecompile();

		//force CP until the SPARK/MR backends are supported (see constructLops)
		_etype = ExecType.CP;
		return _etype;
	}
	// stride1, stride2, padding1, padding2
	// input_shape1, input_shape2, input_shape3, input_shape4,
	// filter_shape1, filter_shape2, filter_shape3, filter_shape4
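	// note: filter_shape2 (input index 10) is not extracted below because it
	// denotes the number of channels C, which is already covered by
	// input_shape2 (input index 6)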
	ConvolutionParameters parseInput() throws DMLRuntimeException {
		ConvolutionParameters params = new ConvolutionParameters(
			extractValue(getInput().get(5)),  // N (input_shape1)
			extractValue(getInput().get(6)),  // C (input_shape2)
			extractValue(getInput().get(7)),  // H (input_shape3)
			extractValue(getInput().get(8)),  // W (input_shape4)
			extractValue(getInput().get(9)),  // K (filter_shape1)
			extractValue(getInput().get(11)), // R (filter_shape3)
			extractValue(getInput().get(12)), // S (filter_shape4)
			extractValue(getInput().get(1)),  // stride1
			extractValue(getInput().get(2)),  // stride2
			extractValue(getInput().get(3)),  // padding1
			extractValue(getInput().get(4)),  // padding2
			_maxNumThreads);
		return params;
	}
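
	/**
	 * Multiplies the given dimension values, propagating -1 (unknown):
	 * if any factor is unknown, the product is unknown as well.
	 */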
	long getExtractedVal(long val1, long val2) {
		if(val1 == -1 || val2 == -1) {
			return -1;
		}
		return val1*val2;
	}

	long getExtractedVal(long val1, long val2, long val3) {
		if(val1 == -1 || val2 == -1 || val3 == -1) {
			return -1;
		}
		return val1*val2*val3;
	}
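
	/**
	 * Recomputes _dim1, _dim2 and _nnz from the scalar shape inputs, with
	 * N, C, H, W denoting the input dimensions, K, R, S the filter
	 * dimensions, and P, Q the output height and width.
	 */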
	@Override
	public void refreshSizeInformation()
	{
		Hop input1 = getInput().get(0);

		ConvolutionParameters params;
		try {
			params = parseInput();
		} catch (DMLRuntimeException e) {
			throw new RuntimeException(e);
		}

		switch(op)
		{
			case IM2COL:
			{
				_dim1 = getExtractedVal(params.C, params.R, params.S);
				_dim2 = getExtractedVal(params.N, params.P, params.Q);
				_nnz = -1; // cannot infer stats
				break;
			}
			case COL2IM:
			{
				_dim1 = params.N;
				_dim2 = getExtractedVal(params.C, params.H, params.W);
				_nnz = -1; // cannot infer stats
				break;
			}
			case RESHAPE_COL:
			{
				_dim1 = params.N;
				_dim2 = getExtractedVal(params.K, params.P, params.Q);
				_nnz = input1.getNnz(); // exact estimates
				break;
			}
			case ROTATE180:
			{
				_dim1 = getExtractedVal(params.N, params.P, params.Q);
				_dim2 = params.K;
				_nnz = input1.getNnz(); // exact estimates
				break;
			}
			case MAX_POOLING:
			{
				_dim1 = params.N;
				_dim2 = getExtractedVal(params.C, params.P, params.Q);
				_nnz = -1; // cannot infer stats
				break;
			}
			case MAX_POOLING_BACKWARD:
			{
				_dim1 = params.N;
				_dim2 = getExtractedVal(params.C, params.H, params.W);
				_nnz = -1; // cannot infer stats
				break;
			}
			case DIRECT_CONV2D:
			{
				_dim1 = params.N;
				_dim2 = getExtractedVal(params.K, params.P, params.Q);
				_nnz = -1; // cannot infer stats
				break;
			}
			case DIRECT_CONV2D_BACKWARD_DATA:
			{
				_dim1 = params.N;
				_dim2 = getExtractedVal(params.C, params.H, params.W);
				_nnz = -1; // cannot infer stats
				break;
			}
			case DIRECT_CONV2D_BACKWARD_FILTER:
			{
				_dim1 = params.K;
				_dim2 = getExtractedVal(params.C, params.R, params.S);
				_nnz = -1; // cannot infer stats
				break;
			}
			default:
				throw new RuntimeException("Size information cannot be refreshed for operation type: " + op.name());
		}
	}
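
	/**
	 * Returns the value of a literal input as a long, or -1 (unknown) if the
	 * input is not a literal, i.e., not known at compilation time.
	 */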
	private long extractValue(Hop hop) {
		if(hop instanceof LiteralOp)
			return (long) HopRewriteUtils.getDoubleValueSafe((LiteralOp)hop);
		return -1;
	}
	@Override
	public Object clone() throws CloneNotSupportedException
	{
		ConvolutionOp ret = new ConvolutionOp();

		//copy generic attributes
		ret.clone(this, false);

		//copy specific attributes
		ret.op = op;
		ret._maxNumThreads = _maxNumThreads;
		return ret;
	}

	@Override
	public boolean compare( Hop that )
	{
		if( !(that instanceof ConvolutionOp) )
			return false;

		ConvolutionOp that2 = (ConvolutionOp)that;
		boolean ret = (op == that2.op)
			&& (getInput().size()==that.getInput().size())
			&& _maxNumThreads == that2._maxNumThreads;

		//compare all child inputs
		if( ret ) //sizes matched
			for( int i=0; i<_input.size(); i++ )
				ret &= getInput().get(i) == that2.getInput().get(i);

		return ret;
	}
	@Override
	public void printMe() throws HopsException {
		if (LOG.isDebugEnabled()){
			if (getVisited() != VisitStatus.DONE) {
				super.printMe();
				LOG.debug("  Operation: " + op);
				for (Hop h : getInput()) {
					h.printMe();
				}
			}
			setVisited(VisitStatus.DONE);
		}
	}

	@Override
	public void setMaxNumThreads( int k ) {
		_maxNumThreads = k;
	}

	@Override
	public int getMaxNumThreads() {
		return _maxNumThreads;
	}
}