/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.hops;
import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpDnn;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.DnnTransform;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
public class DnnOp extends MultiThreadedHop {
private static final Log LOG = LogFactory.getLog(DnnOp.class.getName());
// -------------------------------------------------------------------------
// This flag allows us to compile plans with fewer unknowns and also serves as a stepping stone for future tensorblock integration.
// By default, both flags below are turned on.
// When this flag is turned on, we attempt to resolve unknown dimensions by checking the parent convolution/pooling hop.
// For example: in case of conv -> maxpool, the input channel/height/width of maxpool match the output channel/height/width of conv.
private static final boolean INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP = true;
// This guards us against cases where the user provides incorrect C,H,W parameters.
private static final boolean THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH = true;
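// Illustrative sketch (hypothetical shapes, not taken from any particular script):
// for out = conv2d(X, W, input_shape=[32,3,28,28], filter_shape=[64,3,5,5], stride=[1,1], padding=[0,0]),
// the conv output has [K,P,Q] = [64,24,24] since P = floor((28 + 2*0 - 5)/1 + 1) = 24.
// A subsequent max_pool(out, ...) with unknown C/H/W can thus infer [C,H,W] = [64,24,24].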
// -------------------------------------------------------------------------
// Specifies the type of this hop
private OpOpDnn op;
private DnnOp() {
//default constructor for clone
}
/**
* Creates a hop from the builtin expression
*
* @param l name of the hop
* @param dt datatype (only the matrix datatype is supported)
* @param vt valuetype of the matrix
* @param o type of this hop
* @param inp input hops
*/
public DnnOp(String l, DataType dt, ValueType vt, OpOpDnn o, ArrayList<Hop> inp)
{
super(l, dt, vt);
op = o;
for( int i=0; i<inp.size(); i++ ) {
Hop in = inp.get(i);
getInput().add(i, in);
in.getParent().add(this);
}
//compute unknown dims and nnz
refreshSizeInformation();
}
@Override
public void checkArity() {
HopsException.check(_input.size() >= 1, this, "should have at least one input but has %d inputs", _input.size());
}
public OpOpDnn getOp() {
return op;
}
@Override
public String getOpString() {
return op.toString();
}
private static boolean isEligibleForSpark() {
return false;
}
@Override
public boolean isGPUEnabled() {
if(!DMLScript.USE_ACCELERATOR)
return false;
return true;
}
@Override
public boolean isMultiThreadedOpType() {
return true;
}
@Override
public Lop constructLops()
{
//return already created lops
if( getLops() != null )
return getLops();
ExecType et = optFindExecType();
ArrayList<Hop> inputs = getInput();
switch( op )
{
case MAX_POOL:
case MAX_POOL_BACKWARD:
case AVG_POOL:
case AVG_POOL_BACKWARD:
case CONV2D:
case CONV2D_BACKWARD_DATA:
case CONV2D_BACKWARD_FILTER:
case BIASADD:
case BIASMULT: {
if(et == ExecType.CP || et == ExecType.GPU) {
setLops(constructDnnLops(et, inputs));
break;
}
throw new HopsException("Unimplemented DnnOp for execution type: " + et.name());
}
case BATCH_NORM2D_TEST:
case CHANNEL_SUMS:
case UPDATE_NESTEROV_X: {
if(et == ExecType.GPU) {
setLops(constructDnnLops(et, inputs));
break;
}
throw new HopsException("Unimplemented DnnOp for execution type: " + et.name());
}
default:
throw new HopsException("Unsupported lops construction for operation type '"+op+"'.");
}
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
return getLops();
}
public void setOp(OpOpDnn op) {
this.op = op;
}
private int getNumExpectedInputs() {
switch(op) {
case MAX_POOL_BACKWARD:
case AVG_POOL_BACKWARD:
case CONV2D:
case CONV2D_BACKWARD_FILTER:
case CONV2D_BACKWARD_DATA:
return 14;
case BIASADD:
case BIASMULT:
return 2;
case BATCH_NORM2D_TEST:
return 6;
case CHANNEL_SUMS:
return 3;
case UPDATE_NESTEROV_X:
return 4;
default:
return 13;
}
}
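// Note (derived from parseInput below): the 14-input ops carry two data inputs followed by
// 12 scalar parameters [stride_h, stride_w, pad_h, pad_w, input_shape(4), filter_shape(4)],
// whereas the default 13-input ops carry a single data input followed by the same 12 scalars.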
/**
* Returns the underlying matrix X if the given hop is a ReLU, i.e., max(X,0) or max(0,X); otherwise null.
* @param input input hop
* @return X if input is max(X,0) or max(0,X), null otherwise
*/
private static Hop isInputReLU(Hop input) {
if(HopRewriteUtils.isBinary(input, OpOp2.MAX)) {
if(HopRewriteUtils.isLiteralOfValue(input.getInput().get(0), 0)) {
return input.getInput().get(1);
}
else if(HopRewriteUtils.isLiteralOfValue(input.getInput().get(1), 0)) {
return input.getInput().get(0);
}
else
return null;
}
else
return null;
}
private static boolean isInputConv2d(Hop input) {
return HopRewriteUtils.isDnn(input, OpOpDnn.CONV2D);
}
/**
* Compares the pooling parameters of two operations.
*
* @param param1 first set of DNN parameters
* @param param2 second set of DNN parameters
* @return true if the following parameters are known and match: stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[N, C, H, W], pool_size=[R, S]
*/
private static boolean isPoolingParametersEqualAndKnown(DnnParameters param1, DnnParameters param2) {
return isEqualAndKnown(param1.stride_h, param2.stride_h) && isEqualAndKnown(param1.stride_w, param2.stride_w) &&
isEqualAndKnown(param1.pad_h, param2.pad_h) && isEqualAndKnown(param1.pad_w, param2.pad_w) &&
isEqualAndKnown(param1.R, param2.R) && isEqualAndKnown(param1.S, param2.S) &&
isEqualAndKnown(param1.N, param2.N) && isEqualAndKnown(param1.C, param2.C) &&
isEqualAndKnown(param1.H, param2.H) && isEqualAndKnown(param1.W, param2.W);
}
public boolean isStride1Pad0() {
DnnParameters tmp = parseInput();
return tmp.stride_h == 1 && tmp.stride_w == 1
&& tmp.pad_h == 0 && tmp.pad_w == 0;
}
private static boolean isEqualAndKnown(int val1, int val2) {
return val1 >= 0 && val2 >= 0 && val1 == val2;
}
/**
* Returns the output lop of the max_pool/avg_pool operation with the same parameters as this hop.
* If no corresponding output lop is found, or if this is not a pooling backward operation, this method returns null.
*
* @return output lop of the max_pool/avg_pool operation with the same parameters as this hop, or null
*/
private Lop getMaxPoolOutputLop() {
if(op == OpOpDnn.MAX_POOL_BACKWARD || op == OpOpDnn.AVG_POOL_BACKWARD) {
OpOpDnn opType = (op == OpOpDnn.MAX_POOL_BACKWARD) ? OpOpDnn.MAX_POOL : OpOpDnn.AVG_POOL;
Hop inputImage = getInput().get(0);
for(Hop tmpParent : inputImage.getParent()) {
if(!(tmpParent instanceof DnnOp))
continue;
DnnOp parent = (DnnOp) tmpParent;
if(parent.getOp() == opType && isPoolingParametersEqualAndKnown(parent._cachedParams, _cachedParams)) {
return parent.constructLops();
}
}
}
return null;
}
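// Example (hedged sketch): for dX = max_pool_backward(X, dout, ...) where some Y = max_pool(X, ...)
// exists with identical, known pooling parameters, the lop of Y is reused as an additional
// input on GPU (see the wiring of optionalMaxPoolOutput in constructDnnLops).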
public Lop constructDnnLops(ExecType et, ArrayList<Hop> inputs) {
if(inputs.size() != getNumExpectedInputs())
throw new HopsException("Incorrect number of inputs for " + op.name());
//TODO move these custom rewrites to the general hop rewrites
// ---------------------------------------------------------------
// Deal with fused operators and construct lhsInputLop/optionalRhsInputLop
Lop lhsInputLop = null; Lop optionalRhsInputLop = null;
ArrayList<Hop> inputsOfPotentiallyFusedOp = inputs;
OpOpDnn lopOp = op;
// RELU_MAX_POOL and RELU_MAX_POOL_BACKWARD are extremely useful for the CP backend
// as they avoid unnecessary sparse-to-dense-to-sparse conversions.
// For other backends, this fusion is not necessary as it merely eliminates an additional relu operator.
Hop parentReLU = isInputReLU(inputs.get(0));
if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && op == OpOpDnn.MAX_POOL && parentReLU != null) {
lhsInputLop = parentReLU.constructLops();
lopOp = OpOpDnn.RELU_MAX_POOL;
}
else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && op == OpOpDnn.MAX_POOL_BACKWARD && parentReLU != null) {
lhsInputLop = parentReLU.constructLops();
lopOp = OpOpDnn.RELU_MAX_POOL_BACKWARD;
}
else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && op == OpOpDnn.BIASADD && isInputConv2d(inputs.get(0))) {
lopOp = OpOpDnn.CONV2D_BIAS_ADD;
// the first lop is image
lhsInputLop = inputs.get(0).getInput().get(0).constructLops();
// the second lop is bias
optionalRhsInputLop = inputs.get(1).constructLops();
// Use the inputs from conv2d rather than bias_add
inputsOfPotentiallyFusedOp = inputs.get(0).getInput();
}
else {
lhsInputLop = inputs.get(0).constructLops();
}
// ---------------------------------------------------------------
// ---------------------------------------------------------------
// Compute intermediate memory budget that can be passed to GPU operators
// for better CuDNN operator selection at runtime
double intermediateMemEstimate = computeIntermediateMemEstimate(-1, -1, -1 );
if(et == ExecType.GPU && getDim1() >= 0 && getDim2() >= 0) {
// This enables us to compile more efficient matrix-matrix CuDNN operation instead of
// row-by-row invocation of multiple vector-matrix CuDNN operations.
// This is possible as the operations on GPU are single-threaded
double optimisticIntermediateMemEstimate = GPUContextPool.initialGPUMemBudget() - getOutputMemEstimate() - inputs.get(0).getOutputMemEstimate();
if(optionalRhsInputLop != null) {
optimisticIntermediateMemEstimate -= inputs.get(1).getOutputMemEstimate();
}
intermediateMemEstimate = Math.max(intermediateMemEstimate, optimisticIntermediateMemEstimate);
}
// ---------------------------------------------------------------
// Construct the lop
Lop optionalMaxPoolOutput = (et == ExecType.GPU) ? getMaxPoolOutputLop() : null;
Lop[] l2inputs = new Lop[inputsOfPotentiallyFusedOp.size()-1];
for( int i=1; i < inputsOfPotentiallyFusedOp.size(); i++ )
l2inputs[i-1] = inputsOfPotentiallyFusedOp.get(i).constructLops();
DnnTransform convolutionLop = new DnnTransform(
lhsInputLop, lopOp, getDataType(), getValueType(), et,
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), intermediateMemEstimate);
setOutputDimensions(convolutionLop);
setLineNumbers(convolutionLop);
// ---------------------------------------------------------------
// Add input/output for parent lops of convolutionLop
lhsInputLop.addOutput(convolutionLop);
if(optionalRhsInputLop != null) {
convolutionLop.addInput(optionalRhsInputLop);
optionalRhsInputLop.addOutput(convolutionLop);
}
for( int i=0; i < l2inputs.length; i++ ) {
convolutionLop.addInput(l2inputs[i]);
l2inputs[i].addOutput(convolutionLop);
}
// Only valid for MAX_POOL_BACKWARD/AVG_POOL_BACKWARD on GPU
if(optionalMaxPoolOutput != null) {
convolutionLop.addInput(optionalMaxPoolOutput);
optionalMaxPoolOutput.addOutput(convolutionLop);
}
convolutionLop.updateLopProperties();
// TODO double check that optionalMaxPoolOutput adheres to proper
// ID ordering of constructed lops (previously hidden by setLevel)
// ---------------------------------------------------------------
return convolutionLop;
}
@Override
protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
{
if(getOp() == OpOpDnn.BIASMULT) {
// In non-GPU mode, the worst-case size of the bias multiply output is the same as that of the input.
if(DMLScript.USE_ACCELERATOR)
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
else
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, getInput().get(0).getSparsity());
}
else {
double sparsity = 1.0;
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
}
}
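// e.g., with dim1=1000 and dim2=1000, the default path yields a dense estimate of
// roughly 1000*1000*8 bytes plus header, since non-BIASMULT ops conservatively
// assume a fully dense output (sparsity 1.0).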
// ---------------------------------------------------------------
// Utility methods to guard the computation of memory estimates in the presence of unknowns
private static class IntermediateDimensions {
int dim1; int dim2; double sp;
public IntermediateDimensions(DnnOp h, String dim1Str, String dim2Str, double sp) {
dim1 = (int) h.getDim(dim1Str);
dim2 = (int) h.getDim(dim2Str);
this.sp = sp;
}
public IntermediateDimensions(DnnOp h, String dim1Str, String dim2Str) {
dim1 = (int) h.getDim(dim1Str);
dim2 = (int) h.getDim(dim2Str);
sp = 1;
}
public IntermediateDimensions(DnnOp h, int dim1, String dim2Str) {
this.dim1 = dim1;
dim2 = (int) h.getDim(dim2Str);
sp = 1;
}
/**
* Adds two computed memory estimates, guarding against unknowns: if either
* estimate is unknown (negative), or the sum exceeds OptimizerUtils.DEFAULT_SIZE,
* DEFAULT_SIZE is returned.
*
* @param val1 memory estimate 1
* @param val2 memory estimate 2
* @return guarded sum of memory estimates
*/
static double guardedAdd(double val1, double val2) {
if(val1 < 0 || val2 < 0) return OptimizerUtils.DEFAULT_SIZE;
double ret = val1 + val2;
if(ret >= OptimizerUtils.DEFAULT_SIZE) return OptimizerUtils.DEFAULT_SIZE;
else return ret;
}
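// e.g., guardedAdd(1e3, 2e3) -> 3e3, whereas guardedAdd(-1, 2e3) -> OptimizerUtils.DEFAULT_SIZE
// because a negative estimate denotes an unknown.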
/**
* Compute memory estimates for given intermediate matrices
*
* @param intermediates list of intermediates
* @param numWorkers number of workers
* @return memory estimate
*/
public static double addEstimateSizes(ArrayList<IntermediateDimensions> intermediates, int numWorkers) {
double memBudget = 0;
for(int i = 0; i < intermediates.size(); i++) {
memBudget = guardedAdd(memBudget, OptimizerUtils.estimateSizeExactSparsity(
intermediates.get(i).dim1, intermediates.get(i).dim2, intermediates.get(i).sp)*numWorkers);
}
return memBudget;
}
/**
* Computes the maximum of two memory estimates, with the same guards against unknowns as guardedAdd.
* @param val1 memory estimate 1
* @param val2 memory estimate 2
* @return max of memory estimates
*/
public static double guardedMax(double val1, double val2) {
if(val1 < 0 || val2 < 0) return OptimizerUtils.DEFAULT_SIZE;
double ret = Math.max(val1, val2);
if(ret >= OptimizerUtils.DEFAULT_SIZE) return OptimizerUtils.DEFAULT_SIZE;
else return ret;
}
}
/**
* Helper utility to compute intermediate memory estimate
*
* @param gpuIntermediates intermediates for GPU
* @param cpIntermediates intermediates for CP
* @return memory estimates
*/
private double computeIntermediateMemEstimateHelper(
ArrayList<IntermediateDimensions> gpuIntermediates,
ArrayList<IntermediateDimensions> cpIntermediates) {
// Since CP operators use row-level parallelism by default
int numWorkers = (int) Math.min(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), Math.max(getDim("N"), 1));
if(DMLScript.USE_ACCELERATOR) {
// Account for potential sparse-to-dense conversion
double gpuMemBudget = IntermediateDimensions.addEstimateSizes(gpuIntermediates, 1);
double cpMemoryBudget = IntermediateDimensions.addEstimateSizes(cpIntermediates, numWorkers);
if(cpMemoryBudget > gpuMemBudget) {
double oneThreadCPMemBudget = IntermediateDimensions.addEstimateSizes(cpIntermediates, 1);
if(oneThreadCPMemBudget <= gpuMemBudget) {
// Why limit CP? In order to give more opportunity to compile GPU operators
cpMemoryBudget = oneThreadCPMemBudget;
}
}
// Finally, use the maximum of CP and GPU memory budget
return IntermediateDimensions.guardedMax(cpMemoryBudget, gpuMemBudget);
}
else {
// When -gpu flag is not provided, the memory estimates for CP are not affected.
return IntermediateDimensions.addEstimateSizes(cpIntermediates, numWorkers);
}
}
@Override
protected double computeIntermediateMemEstimate( long ignoreDim1, long ignoreDim2, long ignoreNnz )
{
ArrayList<IntermediateDimensions> gpuIntermediates = new ArrayList<>();
ArrayList<IntermediateDimensions> cpIntermediates = new ArrayList<>();
if(getOp() == OpOpDnn.CONV2D) {
// Assumption: To compile a GPU conv2d operator, the following should fit on the GPU:
// 1. output in dense format (i.e. computeOutputMemEstimate)
// 2. input in any format
// 3. at least one input row in dense format
// 4. filter in dense format
// Account for potential sparse-to-dense conversion of at least 1 input row and the filter
gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
gpuIntermediates.add(new IntermediateDimensions(this, "K", "CRS"));
// im2col operation preserves the worst-case sparsity of the input.
cpIntermediates.add(new IntermediateDimensions(this, "CRS", "PQ", getInput().get(0).getSparsity()));
}
else if(getOp() == OpOpDnn.CONV2D_BACKWARD_DATA) {
// Assumption: To compile a GPU conv2d_backward_data operator, the following should fit on the GPU:
// 1. output in dense format (i.e. computeOutputMemEstimate)
// 2. dout in any format
// 3. at least one dout row in dense format
// 4. filter in dense format
// Account for potential sparse-to-dense conversion of at least 1 dout row and the filter
gpuIntermediates.add(new IntermediateDimensions(this, 1, "KPQ"));
gpuIntermediates.add(new IntermediateDimensions(this, "K", "CRS"));
// There are 2 intermediates: rotate180 and input to col2im for conv2d_backward_data
// rotate180 preserves the "exact" sparsity of the dout matrix
cpIntermediates.add(new IntermediateDimensions(this, "PQ", "K", getInput().get(1).getSparsity()));
// Note: worst-case sparsity for the input of col2im (of size NPQ x CRS where N is determined by degree of parallelism)
cpIntermediates.add(new IntermediateDimensions(this, "PQ", "CRS"));
}
else if(getOp() == OpOpDnn.CONV2D_BACKWARD_FILTER) {
// Assumption: To compile a GPU conv2d_backward_filter operator, the following should fit on the GPU:
// 1. output in dense format (i.e. computeOutputMemEstimate)
// 2. dout in any format
// 3. at least one dout and input row in dense format
// Account for potential sparse-to-dense conversion of at least 1 input + dout row
gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
gpuIntermediates.add(new IntermediateDimensions(this, 1, "KPQ"));
// There are 2 intermediates: im2col and rotate180 for conv2d_backward_filter
// rotate180 preserves the "exact" sparsity of the dout matrix
cpIntermediates.add(new IntermediateDimensions(this, "PQ", "K", getInput().get(1).getSparsity()));
// im2col operation preserves the worst-case sparsity of the input.
cpIntermediates.add(new IntermediateDimensions(this, "CRS", "PQ", getInput().get(0).getSparsity()));
}
else if(getOp() == OpOpDnn.MAX_POOL || getOp() == OpOpDnn.AVG_POOL) {
// Account for potential sparse-to-dense conversion of at least 1 input row
gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
}
else if(getOp() == OpOpDnn.MAX_POOL_BACKWARD || getOp() == OpOpDnn.AVG_POOL_BACKWARD) {
// Account for potential sparse-to-dense conversion of at least 1 input + dout row
gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
gpuIntermediates.add(new IntermediateDimensions(this, 1, "CPQ"));
}
if(gpuIntermediates.size() > 0 || cpIntermediates.size() > 0)
return computeIntermediateMemEstimateHelper(gpuIntermediates, cpIntermediates);
else
return 0;
}
@Override
protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
{
// [numRows, numCols, NNZ]
DataCharacteristics ret = new MatrixCharacteristics();
if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
op == OpOpDnn.UPDATE_NESTEROV_X) {
// Same dimension as the first input
DataCharacteristics[] mc = memo.getAllInputStats(getInput());
ret = new MatrixCharacteristics(
mc[0].rowsKnown() ? mc[0].getRows() : -1,
mc[0].colsKnown() ? mc[0].getCols() : -1, -1, -1);
return ret.dimsKnown() ? ret : null;
}
else if(op == OpOpDnn.CHANNEL_SUMS) {
long numChannels = Hop.computeSizeInformation(getInput().get(1));
return new MatrixCharacteristics(numChannels, 1, -1, -1);
}
//safe return (create entry only if at least dims known)
refreshSizeInformation();
ret = _dc;
return ret.dimsKnown() ? ret : null;
}
@Override
public boolean allowsAllExecTypes()
{
return true;
}
@Override
protected ExecType optFindExecType() {
checkAndSetForcedPlatform();
if( _etypeForced != null ) {
_etype = _etypeForced;
}
else {
if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
_etype = findExecTypeByMemEstimate();
}
else {
_etype = ExecType.SPARK;
}
//check for valid CP dimensions and matrix size
checkAndSetInvalidCPDimsAndSize();
}
// TODO: Fix this after adding remaining spark instructions
_etype = !isEligibleForSpark() && _etype == ExecType.SPARK ? ExecType.CP : _etype;
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
return _etype;
}
// Parameters recomputed in refreshSizeInformation and passed across many calls of getDim
private DnnParameters _cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads);
@SuppressWarnings("null")
// stride1, stride2, padding1, padding2
// input_shape1, input_shape2, input_shape3, input_shape4,
// filter_shape1, filter_shape2, filter_shape3, filter_shape4
DnnParameters parseInput() {
Hop imageHeightHop = null; Hop filterHeightHop = null;
if(op == OpOpDnn.MAX_POOL_BACKWARD || op == OpOpDnn.AVG_POOL_BACKWARD
|| op == OpOpDnn.CONV2D
|| op == OpOpDnn.CONV2D_BACKWARD_FILTER
|| op == OpOpDnn.CONV2D_BACKWARD_DATA) {
_cachedParams.setIfUnknown(
getInput().get(6), // N
getInput().get(7), // C
getInput().get(8), // H
getInput().get(9), // W
getInput().get(10), // K
getInput().get(12), // R
getInput().get(13), // S
getInput().get(2), // stride_h
getInput().get(3), // stride_w
getInput().get(4), // pad_h
getInput().get(5), // pad_w
_maxNumThreads);
}
else {
_cachedParams.setIfUnknown(
getInput().get(5), // N
getInput().get(6), // C
getInput().get(7), // H
getInput().get(8), // W
getInput().get(9), // K
getInput().get(11), // R
getInput().get(12), // S
getInput().get(1), // stride_h
getInput().get(2), // stride_w
getInput().get(3), // pad_h
getInput().get(4), // pad_w
_maxNumThreads);
}
if(INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP) {
boolean isPool = (getOp() == OpOpDnn.MAX_POOL || getOp() == OpOpDnn.AVG_POOL);
boolean isConv = getOp() == OpOpDnn.CONV2D;
boolean unknownCHWPQ = _cachedParams.C < 0 || _cachedParams.H < 0 || _cachedParams.W < 0 || _cachedParams.P < 0 || _cachedParams.Q < 0;
if((isPool || isConv) && unknownCHWPQ) {
// Only infer input shape for convolution and maxpool
inferCHWPQFromParentOp();
}
}
if(imageHeightHop == filterHeightHop && _cachedParams.R < 0 && _cachedParams.H > 0) {
// Unknown R, but known H and both are equal
// This happens for one-dimensional conv2d where H=R and H can be inferred from the parent hop
_cachedParams.R = _cachedParams.H;
}
// Compute P and Q if unknown. At script level, they are computed using following script:
// P = as.integer(floor((H + 2*pad_h - R)/stride_h + 1))
// Q = as.integer(floor((W + 2*pad_w - S)/stride_w + 1))
if(_cachedParams.P < 0 && _cachedParams.H >= 0 && _cachedParams.R >= 0 && _cachedParams.stride_h >= 0 && _cachedParams.pad_h >= 0) {
_cachedParams.P = (int) org.apache.sysds.runtime.util.DnnUtils.getP(_cachedParams.H, _cachedParams.R, _cachedParams.stride_h, _cachedParams.pad_h);
}
if(_cachedParams.Q < 0 && _cachedParams.W >= 0 && _cachedParams.S >= 0 && _cachedParams.stride_w >= 0 && _cachedParams.pad_w >= 0) {
_cachedParams.Q = (int) org.apache.sysds.runtime.util.DnnUtils.getQ(_cachedParams.W, _cachedParams.S, _cachedParams.stride_w, _cachedParams.pad_w);
}
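// Worked example (illustrative numbers only): H=28, R=5, stride_h=1, pad_h=0
// gives P = floor((28 + 2*0 - 5)/1 + 1) = 24; Q is computed analogously from W, S, stride_w, pad_w.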
return _cachedParams;
}
/**
* Utility method to check if the given hop is a BIAS_ADD hop
*
* @param hop the given hop
* @return true if the given hop is BIAS_ADD
*/
private static boolean isInputBiasAdd(Hop hop) {
return HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD);
}
/**
* Utility method to check if an inferred shape matches the given shape, with a guard for unknowns
* (negative dimensions are ignored)
*
* @param dim1 inferred shape
* @param dim2 given shape
* @param paramType parameter name used for pretty printing of the error message
*/
private static void throwExceptionIfNotEqual(int dim1, int dim2, String paramType) {
if(dim1 >= 0 && dim2 >= 0 && dim1 != dim2) {
throw new DMLRuntimeException("Inferred " + paramType + " from parent does not match the given " + paramType + ": " + dim1 + " != " + dim2);
}
}
/**
* Gets the values for the parameters C, H, W, P, Q from parent hops
*/
private void inferCHWPQFromParentOp() {
Hop tmp = getInput().get(0);
// Skip bias_add and go to its parent
tmp = isInputBiasAdd(tmp) ? tmp.getInput().get(0) : tmp;
Hop parentReLU = isInputReLU(tmp);
// Skip ReLU and go to its parent
tmp = (parentReLU != null) ? parentReLU : tmp;
// Cast tmp as parent
DnnOp parentOp = (tmp instanceof DnnOp) ? ((DnnOp) tmp) : null;
if(parentOp == null)
return;
else if(parentOp.getOp() == OpOpDnn.MAX_POOL || parentOp.getOp() == OpOpDnn.AVG_POOL) {
DnnParameters parentParam = parentOp.parseInput();
int prevC = _cachedParams.C; int prevH = _cachedParams.H; int prevW = _cachedParams.W;
// [C, P, Q] from maxpool becomes [C, H, W] of next op
_cachedParams.C = (_cachedParams.C < 0) ? parentParam.C : _cachedParams.C;
_cachedParams.H = (_cachedParams.H < 0) ? parentParam.P : _cachedParams.H;
_cachedParams.W = (_cachedParams.W < 0) ? parentParam.Q : _cachedParams.W;
if(LOG.isDebugEnabled()) {
LOG.debug("Inferring [C,H,W] from maxpool parent: [" + prevC + "," + prevH + "," + prevW + "]-> [" + _cachedParams.C + "," + _cachedParams.H + "," + _cachedParams.W + "]");
}
if(THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
throwExceptionIfNotEqual(prevC, _cachedParams.C, "C");
throwExceptionIfNotEqual(prevH, _cachedParams.H, "H");
throwExceptionIfNotEqual(prevW, _cachedParams.W, "W");
}
}
else if(parentOp.getOp() == OpOpDnn.CONV2D) {
DnnParameters parentParam = parentOp.parseInput();
int prevC = _cachedParams.C; int prevH = _cachedParams.H; int prevW = _cachedParams.W;
// [K, P, Q] from convolution becomes [C, H, W] of next op
_cachedParams.C = (_cachedParams.C < 0) ? parentParam.K : _cachedParams.C;
_cachedParams.H = (_cachedParams.H < 0) ? parentParam.P : _cachedParams.H;
_cachedParams.W = (_cachedParams.W < 0) ? parentParam.Q : _cachedParams.W;
if(LOG.isDebugEnabled()) {
LOG.debug("Inferring [C,H,W] from maxpool parent: [" + prevC + "," + prevH + "," + prevW + "]-> [" + _cachedParams.C + "," + _cachedParams.H + "," + _cachedParams.W + "]");
}
if(THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
throwExceptionIfNotEqual(prevC, _cachedParams.C, "C");
throwExceptionIfNotEqual(prevH, _cachedParams.H, "H");
throwExceptionIfNotEqual(prevW, _cachedParams.W, "W");
}
}
}
@Override
public void refreshSizeInformation()
{
if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT
|| op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) {
// Same dimension as the first input
Hop input1 = getInput().get(0);
setDim1(input1.getDim1());
setDim2(input1.getDim2());
setNnz(-1); // cannot infer stats
return;
}
else if(op == OpOpDnn.CHANNEL_SUMS) {
long numChannels = Hop.computeSizeInformation(getInput().get(1));
setDim1(numChannels);
setDim2(1);
setNnz(-1); // cannot infer stats
return;
}
// Reset the _cachedParams to avoid incorrect sizes
_cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads);
switch(op)
{
case MAX_POOL:
case AVG_POOL: {
setDim1(getDim("N"));
setDim2(getDim("CPQ"));
setNnz(-1); // cannot infer stats
break;
}
case MAX_POOL_BACKWARD:
case AVG_POOL_BACKWARD: {
setDim1(getDim("N"));
setDim2(getDim("CHW"));
setNnz(-1);
break;
}
case CONV2D: {
setDim1(getDim("N"));
setDim2(getDim("KPQ"));
setNnz(-1); // cannot infer stats
break;
}
case CONV2D_BACKWARD_DATA: {
setDim1(getDim("N"));
setDim2(getDim("CHW"));
setNnz(-1); // cannot infer stats
break;
}
case CONV2D_BACKWARD_FILTER: {
setDim1(getDim("K"));
setDim2(getDim("CRS"));
setNnz(-1); // cannot infer stats
break;
}
default:
throw new RuntimeException("The sizes are not refreshed for " + op.name());
}
}
@Override
public Object clone() throws CloneNotSupportedException {
DnnOp ret = new DnnOp();
//copy generic attributes
ret.clone(this, false);
//copy specific attributes
ret.op = op;
ret._maxNumThreads = _maxNumThreads;
return ret;
}
@Override
public boolean compare( Hop that )
{
if( !(that instanceof DnnOp) )
return false;
DnnOp that2 = (DnnOp)that;
boolean ret = (op == that2.op)
&& (getInput().size()==that.getInput().size())
&& _maxNumThreads == that2._maxNumThreads;
//compare all children
if( ret ) //sizes matched
for( int i=0; i<_input.size(); i++ )
ret &= getInput().get(i) == that2.getInput().get(i);
return ret;
}
// ------------------------------------------------------------------------------------------------------
// Utility methods to get the dimensions taking into account unknown dimensions
/**
* Convenience method to get the dimensions required by DnnOp.
*
* @param dimString can be K, CRS, N, CHW, KPQ, PQ, CPQ
* @return either -1 or the value associated with the dimString
*/
private long getDim(String dimString) {
if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT
|| op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
op == OpOpDnn.UPDATE_NESTEROV_X) {
throw new RuntimeException("getDim method should not be invoked for " + op.name());
}
try {
parseInput();
} catch (DMLRuntimeException e) {
throw new RuntimeException(e);
}
Hop filter = null; // shape: K x CRS
Hop input = null; // shape: N x CHW
Hop dout = null; // shape: N x KPQ
Hop dout1 = null; // shape: N x CPQ
if(getOp() == OpOpDnn.CONV2D) {
input = getInput().get(0);
filter = getInput().get(1);
}
else if(getOp() == OpOpDnn.CONV2D_BACKWARD_DATA) {
filter = getInput().get(0);
dout = getInput().get(1);
}
else if(getOp() == OpOpDnn.CONV2D_BACKWARD_FILTER) {
input = getInput().get(0);
dout = getInput().get(1);
}
else if(getOp() == OpOpDnn.MAX_POOL || getOp() == OpOpDnn.AVG_POOL) {
input = getInput().get(0);
}
else if(getOp() == OpOpDnn.MAX_POOL_BACKWARD || getOp() == OpOpDnn.AVG_POOL_BACKWARD) {
input = getInput().get(0);
dout1 = getInput().get(1);
}
long ret = -1;
if(dimString.equals("K") && filter != null) {
ret = getNonNegative(ret, getNonNegative(_cachedParams.K, filter.getDim1()));
}
else if(dimString.equals("CRS") && filter != null) {
ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.R, _cachedParams.S), filter.getDim2()));
}
else if(dimString.equals("N") && input != null) {
ret = getNonNegative(ret, getNonNegative(_cachedParams.N, input.getDim1()));
}
else if(dimString.equals("CHW") && input != null) {
ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.H, _cachedParams.W), input.getDim2()));
}
else if(dimString.equals("N") && dout != null) {
ret = getNonNegative(ret, getNonNegative(_cachedParams.N, dout.getDim1()));
}
else if(dimString.equals("KPQ") && dout != null) {
ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.K, _cachedParams.P, _cachedParams.Q), dout.getDim2()));
}
else if(dimString.equals("N") && dout1 != null) {
ret = getNonNegative(ret, getNonNegative(_cachedParams.N, dout1.getDim1()));
}
else if(dimString.equals("CPQ") && dout1 != null) {
ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.P, _cachedParams.Q), dout1.getDim2()));
}
else if(dimString.equals("K")) {
ret = getNonNegative(ret, _cachedParams.K >= 0 ? _cachedParams.K : -1);
}
else if(dimString.equals("CRS")) {
ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.R, _cachedParams.S));
}
else if(dimString.equals("N")) {
ret = getNonNegative(ret, _cachedParams.N >= 0 ? _cachedParams.N : -1);
}
else if(dimString.equals("CHW")) {
ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.H, _cachedParams.W));
}
else if(dimString.equals("KPQ")) {
ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.K, _cachedParams.P, _cachedParams.Q));
}
else if(dimString.equals("PQ")) {
ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.P, _cachedParams.Q));
}
else if(dimString.equals("CPQ")) {
ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.P, _cachedParams.Q));
}
else {
throw new RuntimeException("Unsupported dimension:" + dimString + " for operator " + getOp().name());
}
if(LOG.isDebugEnabled() && ret < 0) {
LOG.debug("Unknown dimension " + dimString + " for DnnOp:" + op.name() +
" img_dim=[" + _cachedParams.N + " " + _cachedParams.C + " " + _cachedParams.H + " " + _cachedParams.W + "]" +
" filter_dim=[" + _cachedParams.K + " " + _cachedParams.C + " " + _cachedParams.R + " " + _cachedParams.S + "]" +
" output_feature_map=[" + _cachedParams.P + " " + _cachedParams.Q + "] stride=[" + _cachedParams.stride_h + " " + _cachedParams.stride_w + "]" +
" pad=[" + _cachedParams.pad_h + " " + _cachedParams.pad_w + "]");
}
return ret;
}
private static long nonNegativeMultiply(long val1, long val2, long val3) {
if(val1 >= 0 && val2 >= 0 && val3 >= 0) {
return val1 * val2 * val3;
}
else return -1;
}
private static long nonNegativeMultiply(long val1, long val2) {
if(val1 >= 0 && val2 >= 0) {
return val1 * val2;
}
else return -1;
}
private static long getNonNegative(long val1, long val2) {
if(val1 >= 0 && val2 >= 0) {
if(val1 == val2) return val1;
else throw new RuntimeException("Incorrect dimensions in DnnOp: " + val1 + " != " + val2);
}
else if(val1 >= 0) return val1;
else if(val2 >= 0) return val2;
else return -1;
}
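// e.g., getNonNegative(-1, 5) -> 5, getNonNegative(5, 5) -> 5, and getNonNegative(3, 5)
// throws an exception because two known dimensions disagree.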
// ------------------------------------------------------------------------------------------------------
}