// blob: ba17e3c3c6d1e822ebd145ae029839bd5e9537f1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.cp;
import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN.PoolingType;
import org.apache.sysds.runtime.matrix.data.LibMatrixNative;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.utils.NativeHelper;
public class DnnCPInstruction extends UnaryCPInstruction {
// Logger used for the one-time under-utilization warning in resetNumThreads.
private static final Log LOG = LogFactory.getLog(DnnCPInstruction.class.getName());
// Set to true once resetNumThreads has reduced the thread count, so the warning is logged only once.
private static boolean warnedUnderUtilitization = false;
// Additional input operands beyond input1. Which ones are non-null and what they hold
// depends on the opcode (e.g., _in2 is dout for relu_backward, bias for bias_add, and
// the filter for conv2d); _in4.._in8 are only populated by the batch-norm constructor.
private final CPOperand _in2;
private final CPOperand _in3;
private final CPOperand _in4;
private final CPOperand _in5;
private final CPOperand _in6;
private final CPOperand _in7;
private final CPOperand _in8;
// Additional output operands beyond the primary output; only populated by the batch-norm constructor.
private final CPOperand _out2;
private final CPOperand _out3;
private final CPOperand _out4;
private final CPOperand _out5;
// Scalar operands for the convolution/pooling geometry; null for opcodes that do not use them.
// processInstruction reads _input_shape as [N, C, H, W] and _filter_shape indices 0, 2, 3 as K, R, S.
private final ArrayList<CPOperand> _input_shape;
private final ArrayList<CPOperand> _filter_shape;
// Read by processInstruction as [stride_h, stride_w] and [pad_h, pad_w], respectively.
private final ArrayList<CPOperand> _stride;
private final ArrayList<CPOperand> _padding;
// Thread count handed to the LibMatrixDNN calls and to DnnParameters (0 for batch-norm instructions).
private final int _numThreads;
// Memory budget used by resetNumThreads to limit the degree of parallelism.
private final double _intermediateMemoryBudget;
/**
 * General constructor for instructions with up to three inputs, one output,
 * and optional stride/padding/shape operand lists. All operands that are
 * specific to the batch-norm opcodes are set to null.
 *
 * @param in first input operand
 * @param in2 second input operand (may be null)
 * @param in3 third input operand (may be null)
 * @param out output operand
 * @param stride stride operands (may be null)
 * @param padding padding operands (may be null)
 * @param input_shape input shape operands (may be null)
 * @param filter_shape filter shape operands (may be null)
 * @param numThreads number of threads
 * @param intermediateMemoryBudget memory budget for intermediates
 * @param opcode instruction opcode
 * @param istr instruction string
 */
public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out,
	ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
	ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget, String opcode, String istr) {
	super(CPType.Dnn, null, in, out, opcode, istr);

	// inputs provided by the caller
	_in2 = in2;
	_in3 = in3;

	// operands that only the batch-norm constructor populates
	_in4 = null;
	_in5 = null;
	_in6 = null;
	_in7 = null;
	_in8 = null;
	_out2 = null;
	_out3 = null;
	_out4 = null;
	_out5 = null;

	// geometry of the convolution/pooling operation
	_input_shape = input_shape;
	_filter_shape = filter_shape;
	_stride = stride;
	_padding = padding;

	// execution configuration
	_numThreads = numThreads;
	_intermediateMemoryBudget = intermediateMemoryBudget;
}
public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, int numThreads, double intermediateMemoryBudget) {
this(in, in2, null, out, null, null, null, null, numThreads, intermediateMemoryBudget, opcode, istr);
if( !(opcode.equals("bias_add") || opcode.equals("relu_backward") || opcode.equals("bias_multiply") ) ) {
throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + opcode);
}
}
/**
 * Constructor for single-input opcodes with geometry operands; delegates
 * to the general constructor with the second and third input set to null.
 */
private DnnCPInstruction(CPOperand in, CPOperand out, String opcode, String istr,
	ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
	ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
	this(in, null, null, out,
		stride, padding, input_shape, filter_shape,
		numThreads, intermediateMemoryBudget, opcode, istr);
}
/**
 * Constructor for two-input opcodes with geometry operands; delegates
 * to the general constructor with the third input set to null.
 */
public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode,
	String istr, ArrayList<CPOperand> stride,
	ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
	ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
	this(in, in2, null, out,
		stride, padding, input_shape, filter_shape,
		numThreads, intermediateMemoryBudget, opcode, istr);
}
/**
 * Constructor for three-input opcodes with geometry operands; delegates
 * to the general constructor, only reordering the arguments.
 */
public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode,
	String istr, ArrayList<CPOperand> stride,
	ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
	ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
	this(in, in2, in3, out,
		stride, padding, input_shape, filter_shape,
		numThreads, intermediateMemoryBudget, opcode, istr);
}
/**
 * Constructor for the multi-input/multi-output opcodes (used by parseInstruction
 * for batch_norm2d and batch_norm2d_backward). The geometry operand lists are
 * set to null and the number of threads is set to 0.
 *
 * @param in1 first input operand
 * @param in2 second input operand
 * @param in3 third input operand
 * @param in4 fourth input operand
 * @param in5 fifth input operand
 * @param in6 sixth input operand
 * @param in7 seventh input operand (may be null)
 * @param in8 eighth input operand (may be null)
 * @param out first output operand
 * @param out2 second output operand
 * @param out3 third output operand
 * @param out4 fourth output operand (may be null)
 * @param out5 fifth output operand (may be null)
 * @param opcode instruction opcode
 * @param istr instruction string
 * @param intermediateMemoryBudget memory budget for intermediates
 * @throws DMLRuntimeException declared by the signature; not thrown by this body directly
 */
public DnnCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5,
	CPOperand in6, CPOperand in7, CPOperand in8,
	CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr,
	double intermediateMemoryBudget) throws DMLRuntimeException {
	super(CPType.Dnn, null, in1, out, opcode, istr);

	// no geometry operands and no explicit thread count for these opcodes
	_input_shape = null;
	_filter_shape = null;
	_stride = null;
	_padding = null;
	_numThreads = 0;
	_intermediateMemoryBudget = intermediateMemoryBudget;

	// remaining inputs (the first one is held by the super class)
	_in2 = in2;
	_in3 = in3;
	_in4 = in4;
	_in5 = in5;
	_in6 = in6;
	_in7 = in7;
	_in8 = in8;

	// remaining outputs (the first one is held by the super class)
	_out2 = out2;
	_out3 = out3;
	_out4 = out4;
	_out5 = out5;
}
/**
 * Parses the string representation of a DNN instruction and creates the
 * corresponding DnnCPInstruction. Opcodes are matched case-insensitively.
 *
 * @param str instruction string
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is unknown
 */
public static DnnCPInstruction parseInstruction(String str) {
	String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
	String opcode = parts[0];
	if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling") ||
		opcode.equalsIgnoreCase("avgpooling")) {
		InstructionUtils.checkNumFields(parts, 16);
		// stride1, stride2, padding1, padding2
		// input_shape1, input_shape2, input_shape3, input_shape4,
		// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
		CPOperand in = new CPOperand(parts[1]);
		CPOperand out = new CPOperand(parts[14]);
		ArrayList<CPOperand> stride = parseOperandList(parts, 2, 2);
		ArrayList<CPOperand> padding = parseOperandList(parts, 4, 2);
		ArrayList<CPOperand> input_shape = parseOperandList(parts, 6, 4);
		ArrayList<CPOperand> filter_shape = parseOperandList(parts, 10, 4);
		int k = Integer.parseInt(parts[15]);
		return new DnnCPInstruction(in, out, opcode, str, stride,
			padding, input_shape, filter_shape, k, Double.parseDouble(parts[16]));
	}
	else if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("relu_maxpooling_backward")
		|| opcode.equalsIgnoreCase("avgpooling_backward")
		|| opcode.equalsIgnoreCase("conv2d")
		|| opcode.equalsIgnoreCase("conv2d_backward_filter")
		|| opcode.equalsIgnoreCase("conv2d_backward_data")) {
		InstructionUtils.checkNumFields(parts, 17);
		// dout, stride1, stride2, padding1, padding2
		// input_shape1, input_shape2, input_shape3, input_shape4,
		// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand out = new CPOperand(parts[15]);
		ArrayList<CPOperand> stride = parseOperandList(parts, 3, 2);
		ArrayList<CPOperand> padding = parseOperandList(parts, 5, 2);
		ArrayList<CPOperand> input_shape = parseOperandList(parts, 7, 4);
		ArrayList<CPOperand> filter_shape = parseOperandList(parts, 11, 4);
		int k = Integer.parseInt(parts[16]);
		return new DnnCPInstruction(in, in2, out, opcode, str, stride,
			padding, input_shape, filter_shape, k, Double.parseDouble(parts[17]));
	}
	else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
		InstructionUtils.checkNumFields(parts, 18);
		// dout, stride1, stride2, padding1, padding2
		// input_shape1, input_shape2, input_shape3, input_shape4,
		// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand out = new CPOperand(parts[16]);
		ArrayList<CPOperand> stride = parseOperandList(parts, 4, 2);
		ArrayList<CPOperand> padding = parseOperandList(parts, 6, 2);
		ArrayList<CPOperand> input_shape = parseOperandList(parts, 8, 4);
		ArrayList<CPOperand> filter_shape = parseOperandList(parts, 12, 4);
		int k = Integer.parseInt(parts[17]);
		return new DnnCPInstruction(in, in2, in3, out, opcode, str, stride,
			padding, input_shape, filter_shape, k, Double.parseDouble(parts[18]));
	}
	// relu_backward is matched case-insensitively like every other opcode in this method
	else if (opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") ) {
		InstructionUtils.checkNumFields(parts, 5);
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand out = new CPOperand(parts[3]);
		int k = Integer.parseInt(parts[4]);
		return new DnnCPInstruction(in, in2, out, opcode, str, k, Double.parseDouble(parts[5]));
	}
	else if (opcode.equalsIgnoreCase("batch_norm2d")) {
		InstructionUtils.checkNumFields(parts, 13);
		CPOperand in1 = new CPOperand(parts[1]); // image
		CPOperand in2 = new CPOperand(parts[2]); // scale
		CPOperand in3 = new CPOperand(parts[3]); // bias
		CPOperand in4 = new CPOperand(parts[4]); // runningMean
		CPOperand in5 = new CPOperand(parts[5]); // runningVar
		CPOperand in6 = new CPOperand(parts[6]); // mode
		CPOperand in7 = new CPOperand(parts[7]); // epsilon
		CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor
		CPOperand out = new CPOperand(parts[9]); // ret
		CPOperand out2 = new CPOperand(parts[10]); // retRunningMean
		CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
		CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
		CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
		return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
	}
	else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
		InstructionUtils.checkNumFields(parts, 9);
		CPOperand in1 = new CPOperand(parts[1]); // image
		CPOperand in2 = new CPOperand(parts[2]); // dout
		CPOperand in3 = new CPOperand(parts[3]); // scale
		CPOperand in4 = new CPOperand(parts[4]); // epsilon
		CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean
		CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance
		CPOperand out = new CPOperand(parts[7]); // dX
		CPOperand out2 = new CPOperand(parts[8]); // dScale
		CPOperand out3 = new CPOperand(parts[9]); // dBias
		return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
	}
	else {
		throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str);
	}
}

/**
 * Creates operands for the consecutive instruction parts
 * parts[start], ..., parts[start+count-1], in that order.
 *
 * @param parts instruction parts
 * @param start index of the first part to convert
 * @param count number of consecutive parts to convert
 * @return list of operands
 */
private static ArrayList<CPOperand> parseOperandList(String[] parts, int start, int count) {
	ArrayList<CPOperand> ret = new ArrayList<>(count);
	for (int i = start; i < start + count; i++)
		ret.add(new CPOperand(parts[i]));
	return ret;
}
/**
 * Reads the scalar operand at the given list position and returns its
 * long value narrowed to int.
 *
 * @param ec execution context
 * @param aL list of scalar operands
 * @param index position of the operand in the list
 * @return scalar value as int
 */
private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) {
	CPOperand operand = aL.get(index);
	long value = ec.getScalarInput(operand).getLongValue();
	return (int) value;
}
/**
 * Executes relu_backward, i.e., (X > 0) * dout, where X is the first
 * input and dout the second input.
 *
 * @param ec execution context
 */
public void processReluBackwardInstruction(ExecutionContext ec) {
	MatrixBlock X = ec.getMatrixInput(input1.getName());
	MatrixBlock dY = ec.getMatrixInput(_in2.getName());

	// the output is created in sparse format if either input is sparse
	boolean sparseOut = X.isInSparseFormat() || dY.isInSparseFormat();
	MatrixBlock result = new MatrixBlock(X.getNumRows(), X.getNumColumns(), sparseOut);

	// sparse-safe: if either input is empty, the unallocated output is returned as is
	boolean anyEmpty = X.isEmpty() || dY.isEmpty();
	if( !anyEmpty ) {
		result.allocateBlock();
		LibMatrixDNN.reluBackward(X, dY, result, _numThreads);
	}

	// release inputs and bind the output
	ec.releaseMatrixInput(input1.getName());
	ec.releaseMatrixInput(_in2.getName());
	ec.setMatrixOutput(getOutputVariableName(), result);
}
/**
 * Executes bias_add on the first input and the single-column bias
 * matrix given as second input.
 *
 * @param ec execution context
 * @throws DMLRuntimeException if the bias matrix does not have exactly one column
 */
public void processBiasAddInstruction(ExecutionContext ec) {
	MatrixBlock X = ec.getMatrixInput(input1.getName());
	MatrixBlock b = ec.getMatrixInput(_in2.getName());

	int biasCols = b.getNumColumns();
	if(biasCols != 1) {
		throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + b.getNumColumns());
	}

	boolean emptyInput = X.isEmpty();
	boolean emptyBias = b.isEmpty();

	MatrixBlock result;
	if(!emptyBias) {
		// dense output because the output is always filled with the bias first
		result = new MatrixBlock(X.getNumRows(), X.getNumColumns(), false);
		result.allocateDenseBlock();
		LibMatrixDNN.biasAdd(X, b, result, _numThreads);
	}
	else if(emptyInput) {
		// both empty: empty sparse output
		result = new MatrixBlock(X.getNumRows(), X.getNumColumns(), true);
	}
	else {
		// only the bias is empty: output is a copy of the input
		result = new MatrixBlock(X);
	}

	// release inputs and bind the output
	ec.releaseMatrixInput(input1.getName());
	ec.releaseMatrixInput(_in2.getName());
	ec.setMatrixOutput(getOutputVariableName(), result);
}
/**
 * Executes bias_multiply on the first input and the single-column bias
 * matrix given as second input.
 *
 * @param ec execution context
 * @throws DMLRuntimeException if the bias matrix does not have exactly one column
 */
public void processBiasMultiplyInstruction(ExecutionContext ec) {
	MatrixBlock X = ec.getMatrixInput(input1.getName());
	MatrixBlock b = ec.getMatrixInput(_in2.getName());

	int biasCols = b.getNumColumns();
	if(biasCols != 1) {
		throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + b.getNumColumns());
	}

	MatrixBlock result;
	if(!b.isEmpty()) {
		// output is allocated in the same format (sparse/dense) as the input
		result = new MatrixBlock(X.getNumRows(), X.getNumColumns(),
			X.isInSparseFormat()).allocateBlock();
		LibMatrixDNN.biasMultiply(X, b, result, _numThreads);
	}
	else {
		// anything multiplied by zero is zero: empty sparse output
		result = new MatrixBlock(X.getNumRows(), X.getNumColumns(), true);
	}

	// release inputs and bind the output
	ec.releaseMatrixInput(input1.getName());
	ec.releaseMatrixInput(_in2.getName());
	ec.setMatrixOutput(getOutputVariableName(), result);
}
/**
 * Executes batch_norm2d: reads five matrix inputs and three scalar inputs,
 * invokes LibMatrixDNN.batchNorm2D, and binds the five dense outputs.
 *
 * @param ec execution context
 */
public void processBatchNorm2dInstruction(ExecutionContext ec) {
	// matrix inputs
	MatrixBlock X = ec.getMatrixInput(input1.getName());
	MatrixBlock gamma = ec.getMatrixInput(_in2.getName());
	MatrixBlock beta = ec.getMatrixInput(_in3.getName());
	MatrixBlock rMean = ec.getMatrixInput(_in4.getName());
	MatrixBlock rVar = ec.getMatrixInput(_in5.getName());

	// scalar inputs
	String mode = ec.getScalarInput(_in6).getStringValue();
	double eps = ec.getScalarInput(_in7).getDoubleValue();
	double avgFactor = ec.getScalarInput(_in8.getName(), _in8.getValueType(), _in8.isLiteral()).getDoubleValue();

	// dense outputs: the result has the shape of the image, the mean outputs
	// the shape of runningMean, and the variance outputs the shape of runningVar
	MatrixBlock out = new MatrixBlock(X.getNumRows(), X.getNumColumns(), false).allocateBlock();
	MatrixBlock outRunMean = new MatrixBlock(rMean.getNumRows(), rMean.getNumColumns(), false).allocateBlock();
	MatrixBlock outRunVar = new MatrixBlock(rVar.getNumRows(), rVar.getNumColumns(), false).allocateBlock();
	MatrixBlock outSaveMean = new MatrixBlock(rMean.getNumRows(), rMean.getNumColumns(), false).allocateBlock();
	MatrixBlock outSaveInvVar = new MatrixBlock(rVar.getNumRows(), rVar.getNumColumns(), false).allocateBlock();

	LibMatrixDNN.batchNorm2D(X, gamma, beta, rMean, rVar, mode, eps, avgFactor,
		out, outRunMean, outRunVar, outSaveMean, outSaveInvVar);

	// release inputs and bind the outputs
	ec.releaseMatrixInput(input1.getName(), _in2.getName(),
		_in3.getName(), _in4.getName(), _in5.getName());
	ec.setMatrixOutput(output.getName(), out);
	ec.setMatrixOutput(_out2.getName(), outRunMean);
	ec.setMatrixOutput(_out3.getName(), outRunVar);
	ec.setMatrixOutput(_out4.getName(), outSaveMean);
	ec.setMatrixOutput(_out5.getName(), outSaveInvVar);
}
/**
 * Executes batch_norm2d_backward: reads five matrix inputs and the scalar
 * epsilon, invokes LibMatrixDNN.batchNorm2DBackward, and binds the three
 * dense outputs (dX, dScale, dBias).
 *
 * @param ec execution context
 */
public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) {
	// inputs, acquired in operand order
	MatrixBlock X = ec.getMatrixInput(input1.getName());
	MatrixBlock dY = ec.getMatrixInput(_in2.getName());
	MatrixBlock gamma = ec.getMatrixInput(_in3.getName());
	double eps = ec.getScalarInput(_in4).getDoubleValue();
	MatrixBlock savedMean = ec.getMatrixInput(_in5.getName());
	MatrixBlock savedInvVar = ec.getMatrixInput(_in6.getName());

	// dense outputs: dX has the shape of the image, dScale and dBias the shape of scale
	MatrixBlock gradX = new MatrixBlock(X.getNumRows(), X.getNumColumns(), false).allocateBlock();
	MatrixBlock gradScale = new MatrixBlock(gamma.getNumRows(), gamma.getNumColumns(), false).allocateBlock();
	MatrixBlock gradBias = new MatrixBlock(gamma.getNumRows(), gamma.getNumColumns(), false).allocateBlock();

	LibMatrixDNN.batchNorm2DBackward(X, dY, gamma, eps, savedMean, savedInvVar,
		gradX, gradScale, gradBias);

	// release inputs and bind the outputs
	ec.releaseMatrixInput(input1.getName(), _in2.getName(),
		_in3.getName(), _in5.getName(), _in6.getName());
	ec.setMatrixOutput(output.getName(), gradX);
	ec.setMatrixOutput(_out2.getName(), gradScale);
	ec.setMatrixOutput(_out3.getName(), gradBias);
}
/**
 * Checks whether the given filter is in sparse format, after first converting
 * small sparse filters to dense format in place.
 *
 * Assumption: enableNative &amp;&amp; NativeHelper.isNativeLibraryLoaded() is true.
 * The conversion increases the number of native calls, for example in the cases
 * where the filter is sparse but the input is dense.
 *
 * @param filter filter matrix (may be converted to dense format as a side effect)
 * @return true if the filter is still in sparse format after the optional conversion
 */
private static boolean isFilterSparse(MatrixBlock filter) {
	// widen before multiplying: an int product can overflow (even to a negative value)
	// and would then wrongly pass the size check below
	long numElems = (long) filter.getNumRows() * filter.getNumColumns();
	// convert only if the filter has fewer than 10e+6 cells, which covers almost all
	// cases; a much smaller threshold would still be sufficient for common CNNs
	if(filter.isInSparseFormat() && numElems < 10e+6)
		filter.sparseToDense();
	return filter.isInSparseFormat();
}
/**
 * Executes this DNN instruction. Opcodes without geometry operands
 * (bias_add, bias_multiply, relu_backward, batch_norm2d, batch_norm2d_backward)
 * are dispatched to their dedicated methods; pooling and convolution opcodes
 * are handled here based on the stride, padding, input shape, and filter
 * shape operands.
 *
 * @param ec execution context
 * @throws DMLRuntimeException if the opcode is unsupported or, for conv2d_bias_add,
 *         if the bias matrix does not have shape [K, 1]
 */
@Override
public void processInstruction(ExecutionContext ec) {
	if (instOpcode.equalsIgnoreCase("bias_add")) {
		processBiasAddInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("bias_multiply")) {
		processBiasMultiplyInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("relu_backward")) {
		processReluBackwardInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
		processBatchNorm2dInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
		processBatchNorm2dBackwardInstruction(ec);
		return;
	}
	// acquire inputs (avgpooling_backward does not read its first input)
	MatrixBlock outputBlock = null;
	MatrixBlock matBlock = instOpcode.equalsIgnoreCase("avgpooling_backward") ? null : ec.getMatrixInput(input1.getName());
	int pad_h = getScalarInput(ec, _padding, 0);
	int pad_w = getScalarInput(ec, _padding, 1);
	int stride_h = getScalarInput(ec, _stride, 0);
	int stride_w = getScalarInput(ec, _stride, 1);
	int N = getScalarInput(ec, _input_shape, 0);
	int C = getScalarInput(ec, _input_shape, 1);
	int H = getScalarInput(ec, _input_shape, 2);
	int W = getScalarInput(ec, _input_shape, 3);
	int K = getScalarInput(ec, _filter_shape, 0);
	int R = getScalarInput(ec, _filter_shape, 2);
	int S = getScalarInput(ec, _filter_shape, 3);
	int P = (int) DnnUtils.getP(H, R, stride_h, pad_h);
	int Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w);
	DnnParameters params = new DnnParameters(N, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, _numThreads);
	params.enableNative = NativeHelper.isNativeLibraryLoaded();
	if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling") ||
		instOpcode.equalsIgnoreCase("avgpooling")) {
		if(matBlock.isEmpty()) {
			outputBlock = new MatrixBlock(N, C*P*Q, true);
		}
		else {
			outputBlock = new MatrixBlock(N, C*P*Q, false).allocateBlock();
			PoolingType poolType = (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) ? PoolingType.MAX : PoolingType.AVG;
			if(instOpcode.equalsIgnoreCase("relu_maxpooling"))
				params.minValForMaxPoolOperations = 0;
			LibMatrixDNN.pooling(matBlock, outputBlock, params, poolType);
		}
	}
	else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("relu_maxpooling_backward") ||
		instOpcode.equalsIgnoreCase("avgpooling_backward")) {
		MatrixBlock dout = ec.getMatrixInput(_in2.getName());
		boolean isEmpty = instOpcode.equalsIgnoreCase("avgpooling_backward") ? dout.isEmpty() : (matBlock.isEmpty() || dout.isEmpty());
		if(isEmpty) {
			outputBlock = new MatrixBlock(N, C*H*W, true);
		}
		else {
			outputBlock = new MatrixBlock(N, C*H*W, false).allocateBlock();
			PoolingType poolType = (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("relu_maxpooling_backward")) ? PoolingType.MAX : PoolingType.AVG;
			boolean performReLUBackward = instOpcode.equalsIgnoreCase("relu_maxpooling_backward");
			if(performReLUBackward)
				params.minValForMaxPoolOperations = 0;
			LibMatrixDNN.poolingBackward(matBlock, dout, outputBlock, params, performReLUBackward, poolType);
		}
		ec.releaseMatrixInput(_in2.getName());
	}
	else if (instOpcode.equalsIgnoreCase("conv2d")) {
		resetNumThreads(params, C*R*S, P*Q, inputSparsity(matBlock));
		MatrixBlock filter = ec.getMatrixInput(_in2.getName());
		if(filter.isEmpty() || matBlock.isEmpty()) {
			outputBlock = new MatrixBlock(N, K*P*Q, true);
		}
		else {
			boolean sparse = matBlock.isUltraSparse(false) && params.bias == null
				&& matBlock.getInMemorySize() < MatrixBlock.estimateSizeDenseInMemory(N, K*P*Q);
			outputBlock = new MatrixBlock(N, K*P*Q, sparse).allocateBlock();
			if(params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat())
				LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
			else
				LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
		}
		ec.releaseMatrixInput(_in2.getName());
	}
	else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
		resetNumThreads(params, C*R*S, P*Q, inputSparsity(matBlock));
		// note: for this opcode the filter is the third and the bias the second input
		MatrixBlock filter = ec.getMatrixInput(_in3.getName());
		MatrixBlock bias = ec.getMatrixInput(_in2.getName());
		if(bias.getNumRows() != params.K || bias.getNumColumns() != 1) {
			throw new DMLRuntimeException("Incorrect shape of bias matrix: [" + bias.getNumRows() + " " + bias.getNumColumns() + "]. "
				+ "Expected: [" + params.K + ", 1]");
		}
		boolean isOutputConvEmpty = filter.isEmpty() || matBlock.isEmpty();
		if(isOutputConvEmpty && bias.isEmpty()) {
			// bias_add(empty mb, empty mb) = empty mb
			outputBlock = new MatrixBlock(N, K*P*Q, true);
		}
		else if(isOutputConvEmpty && !bias.isEmpty()) {
			// Add bias to empty output block
			// bias_add(empty mb, bias)
			outputBlock = new MatrixBlock(N, K*P*Q, false).allocateBlock();
			for(int n = 0; n < params.N; n++)
				DnnUtils.fillBias(bias, outputBlock.getDenseBlockValues(),
					n, n+1, params.N, params.K, params.P*params.Q);
		}
		else {
			outputBlock = new MatrixBlock(N, K*P*Q, false).allocateBlock();
			if(!bias.isEmpty()) {
				// Handle situation where both input and filter are non empty, but bias is empty
				params.bias = bias;
			}
			if(params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat())
				LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
			else
				LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
		}
		ec.releaseMatrixInput(_in3.getName(), _in2.getName());
	}
	else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
		MatrixBlock dout = ec.getMatrixInput(_in2.getName());
		if(dout.isEmpty() || matBlock.isEmpty()) {
			outputBlock = new MatrixBlock(K, C*R*S, true);
		}
		else {
			outputBlock = new MatrixBlock(K, C*R*S, false).allocateBlock();
			if(params.enableNative && !matBlock.isInSparseFormat() && !dout.isInSparseFormat())
				LibMatrixNative.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
			else
				LibMatrixDNN.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
		}
		ec.releaseMatrixInput(_in2.getName());
	}
	else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
		MatrixBlock dout = ec.getMatrixInput(_in2.getName());
		if(dout.isEmpty() || matBlock.isEmpty()) {
			outputBlock = new MatrixBlock(N, C * H * W, true);
		}
		else {
			outputBlock = new MatrixBlock(N, C * H * W, false).allocateBlock();
			// the first input takes the role of the filter for this opcode
			if(params.enableNative && !isFilterSparse(matBlock) && !dout.isInSparseFormat())
				LibMatrixNative.conv2dBackwardData(matBlock, dout, outputBlock, params);
			else
				LibMatrixDNN.conv2dBackwardData(matBlock, dout, outputBlock, params);
		}
		ec.releaseMatrixInput(_in2.getName());
	}
	else {
		throw new DMLRuntimeException("Unsupported op code " + instOpcode);
	}
	// release inputs/outputs
	if(!instOpcode.equalsIgnoreCase("avgpooling_backward"))
		ec.releaseMatrixInput(input1.getName());
	ec.setMatrixOutput(getOutputVariableName(), outputBlock);
}

/**
 * Computes the fraction of non-zero cells of the given matrix in floating-point
 * arithmetic. The cell count is computed as long to avoid int overflow, and the
 * division is performed on doubles so that the result is not truncated to 0 or 1.
 *
 * @param mb matrix block
 * @return number of non-zeros divided by the number of cells, or 1.0 if the
 *         matrix has no cells (treated as dense to keep the estimate conservative)
 */
private static double inputSparsity(MatrixBlock mb) {
	long cells = (long) mb.getNumRows() * mb.getNumColumns();
	return (cells > 0) ? ((double) mb.getNonZeros()) / cells : 1.0;
}
/**
 * Reset the number of threads to respect the intermediate CP memory budget.
 * The thread count is only ever reduced, and never below one thread.
 * The limit is only applied if DMLScript.USE_ACCELERATOR is set.
 *
 * @param params convolution parameters
 * @param numRows number of rows of intermediate matrix used per thread
 * @param numCols number of columns of intermediate matrix used per thread
 * @param sparsity sparsity of intermediate matrix used per thread
 */
private void resetNumThreads(DnnParameters params, int numRows, int numCols, double sparsity) {
	if(DMLScript.USE_ACCELERATOR) {
		double memBudget1Thread = OptimizerUtils.estimateSizeExactSparsity(numRows, numCols, sparsity);
		// Clamp to at least one thread: if the budget is smaller than the estimate for a
		// single thread, the floor is 0, which is not a meaningful degree of parallelism.
		// NOTE(review): presumably a non-positive numThreads is interpreted downstream as
		// "no explicit limit", which would defeat the purpose of this method - confirm.
		int limitedDegreeOfParallelism = Math.max(1,
			(int) Math.floor(_intermediateMemoryBudget / memBudget1Thread));
		if(params.numThreads > limitedDegreeOfParallelism) {
			params.numThreads = limitedDegreeOfParallelism;
			// warn only once per JVM
			if(!warnedUnderUtilitization)
				LOG.warn("CPU Under-utilization to respect the intermediate memory budget. To avoid this, please try reducing the mini-batch or forcing gpu execution.");
			warnedUnderUtilitization = true;
		}
	}
}
}