blob: 3ac1b5ad07ab213cda4e0212af6cb08bb1c92171 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.gpu;
import java.util.ArrayList;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.gpu.context.ExecutionConfig;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.matrix.data.LibMatrixCuDNN;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN.PoolingType;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.utils.GPUStatistics;
import jcuda.Pointer;
public class DnnGPUInstruction extends GPUInstruction {
// Generic input operand slots 1..8; the subset that is populated depends on
// the opcode (see parseInstruction) — e.g. conv2d fills _input1/_input2 while
// batch_norm2d fills _input1.._input8.
private CPOperand _input1;
private CPOperand _input2;
private CPOperand _input3;
private CPOperand _input4;
private CPOperand _input5;
private CPOperand _input6;
private CPOperand _input7;
private CPOperand _input8;
// Primary output plus up to four secondary outputs
// (multi-output opcodes such as lstm and batch_norm2d use _output2.._output5).
private CPOperand _output;
private CPOperand _output2;
private CPOperand _output3;
private CPOperand _output4;
private CPOperand _output5;
// Shape operand lists for convolution/pooling opcodes; each holds four entries
// (presumably [N, C, H, W] for input and [K, C, R, S] for filter — confirm against the DNN ops).
private ArrayList<CPOperand> _input_shape;
private ArrayList<CPOperand> _filter_shape;
// Stride and padding operand lists (two entries each: height, width — TODO confirm order).
private ArrayList<CPOperand> _stride = new ArrayList<>();
private ArrayList<CPOperand> _padding = new ArrayList<>();
// Memory budget available for intermediates of this operation (units not shown here — TODO confirm).
private double _intermediateMemoryBudget = 0;
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward"))) {
throw new DMLRuntimeException(
"Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found "
+ opcode);
}
_input1 = in1;
_input2 = in2;
_gputype = GPUINSTRUCTION_TYPE.Dnn;
_output = out;
_intermediateMemoryBudget = intermediateMemoryBudget;
}
/**
 * Constructs a six-input, two-output GPU instruction (used by the lstm opcode).
 *
 * @param in1..in6 input operands
 * @param out first output operand
 * @param out2 second output operand
 * @param opcode instruction opcode
 * @param istr serialized instruction string
 * @param intermediateMemoryBudget memory budget for intermediates
 */
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6,
	CPOperand out, CPOperand out2, String opcode, String istr,
	double intermediateMemoryBudget) throws DMLRuntimeException {
	super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
	_gputype = GPUINSTRUCTION_TYPE.Dnn;
	// outputs
	_output = out;
	_output2 = out2;
	// inputs
	_input1 = in1;
	_input2 = in2;
	_input3 = in3;
	_input4 = in4;
	_input5 = in5;
	_input6 = in6;
	_intermediateMemoryBudget = intermediateMemoryBudget;
}
/**
 * Constructs an eight-input, five-output GPU instruction
 * (used by batch_norm2d and related multi-output opcodes).
 *
 * @param in1..in8 input operands (some may be null for opcodes that need fewer)
 * @param out..out5 output operands (some may be null)
 * @param opcode instruction opcode
 * @param istr serialized instruction string
 * @param intermediateMemoryBudget memory budget for intermediates
 */
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5,
	CPOperand in6, CPOperand in7, CPOperand in8,
	CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr,
	double intermediateMemoryBudget) throws DMLRuntimeException {
	super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
	_gputype = GPUINSTRUCTION_TYPE.Dnn;
	// outputs
	_output = out;
	_output2 = out2;
	_output3 = out3;
	_output4 = out4;
	_output5 = out5;
	// inputs
	_input1 = in1;
	_input2 = in2;
	_input3 = in3;
	_input4 = in4;
	_input5 = in5;
	_input6 = in6;
	_input7 = in7;
	_input8 = in8;
	_intermediateMemoryBudget = intermediateMemoryBudget;
}
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr,
double intermediateMemoryBudget) throws DMLRuntimeException {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
if( !opcode.equals("channel_sums") ) {
throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
}
_input1 = in1;
_input2 = in2;
_input3 = in3;
_gputype = GPUINSTRUCTION_TYPE.Dnn;
_output = out;
_intermediateMemoryBudget = intermediateMemoryBudget;
}
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr,
double intermediateMemoryBudget) throws DMLRuntimeException {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
if( !opcode.equals("update_nesterov_x") ) {
throw new DMLRuntimeException("Incorrect opcode: " + opcode);
}
_input1 = in1;
_input2 = in2;
_input3 = in3;
_input4 = in4;
_gputype = GPUINSTRUCTION_TYPE.Dnn;
_output = out;
_intermediateMemoryBudget = intermediateMemoryBudget;
}
/**
 * Three-input variant of the convolution/pooling constructor: delegates to the
 * two-input constructor and then records the extra third operand
 * (e.g. the forward max-pooling output for maxpooling_backward — see parseInstruction).
 */
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode,
String istr, ArrayList<CPOperand> stride,
ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
ArrayList<CPOperand> filter_shape, double intermediateMemoryBudget)
{
this(in1, in2, out, opcode, istr, stride, padding, input_shape, filter_shape, intermediateMemoryBudget);
_input3 = in3;
}
/**
 * Base constructor for convolution/pooling style instructions that carry
 * stride, padding, input-shape and filter-shape operand lists.
 *
 * @param in1 first input operand
 * @param in2 second input operand (may be null, e.g. for maxpooling)
 * @param out output operand
 * @param opcode instruction opcode
 * @param istr serialized instruction string
 * @param stride stride operands (2 entries)
 * @param padding padding operands (2 entries)
 * @param input_shape input shape operands (4 entries)
 * @param filter_shape filter shape operands (4 entries)
 * @param intermediateMemoryBudget memory budget for intermediates
 */
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode,
	String istr, ArrayList<CPOperand> stride,
	ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
	ArrayList<CPOperand> filter_shape, double intermediateMemoryBudget)
{
	super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
	// operands
	_input1 = in1;
	_input2 = in2;
	_output = out;
	// shape metadata
	_stride = stride;
	_padding = padding;
	_input_shape = input_shape;
	_filter_shape = filter_shape;
	_intermediateMemoryBudget = intermediateMemoryBudget;
	_gputype = GPUINSTRUCTION_TYPE.Dnn;
}
public DnnGPUInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6,
CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
if( !opcode.equals("batch_norm2d_test") ) {
throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be batch_norm2d_test, but found " + opcode);
}
_input1 = in;
_input2 = in2;
_input3 = in3;
_input4 = in4;
_input5 = in5;
_input6 = in6;
_gputype = GPUINSTRUCTION_TYPE.Dnn;
_output = out;
_intermediateMemoryBudget = intermediateMemoryBudget;
}
/**
 * Parses a serialized DNN GPU instruction string into a {@link DnnGPUInstruction}.
 * The opcode (parts[0]) selects the operand layout; each branch below documents
 * the expected field positions within the split instruction string.
 *
 * @param str the serialized instruction string
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is unknown or the field count is wrong
 */
public static DnnGPUInstruction parseInstruction(String str) {
	String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
	String opcode = parts[0];
	// conv2d family: in1, in2, stride(2), padding(2), input_shape(4),
	// filter_shape(4), out, memory budget
	if( ( opcode.equalsIgnoreCase("conv2d")
		|| opcode.equalsIgnoreCase("conv2d_backward_filter")
		|| opcode.equalsIgnoreCase("conv2d_backward_data")) ) {
		InstructionUtils.checkNumFields(parts, 16);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand out = new CPOperand(parts[15]);
		ArrayList<CPOperand> stride = new ArrayList<>();
		ArrayList<CPOperand> padding = new ArrayList<>();
		ArrayList<CPOperand> input_shape = new ArrayList<>();
		ArrayList<CPOperand> filter_shape = new ArrayList<>();
		stride.add(new CPOperand(parts[3]));
		stride.add(new CPOperand(parts[4]));
		padding.add(new CPOperand(parts[5]));
		padding.add(new CPOperand(parts[6]));
		input_shape.add(new CPOperand(parts[7]));
		input_shape.add(new CPOperand(parts[8]));
		input_shape.add(new CPOperand(parts[9]));
		input_shape.add(new CPOperand(parts[10]));
		filter_shape.add(new CPOperand(parts[11]));
		filter_shape.add(new CPOperand(parts[12]));
		filter_shape.add(new CPOperand(parts[13]));
		filter_shape.add(new CPOperand(parts[14]));
		return new DnnGPUInstruction(in1, in2, out, opcode, str, stride,
			padding, input_shape, filter_shape, Double.parseDouble(parts[16]));
	}
	// pooling backward: optionally carries the forward pooling output as a
	// third input (18 parts) — otherwise 17 parts with the standard layout
	else if( opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward") ) {
		boolean withMaxPoolOut = false;
		if(parts.length == 18) {
			withMaxPoolOut = true;
		}
		else
			InstructionUtils.checkNumFields(parts, 16);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		// the operand positions for out and the memory budget shift by one
		// when the forward pooling output is present
		CPOperand in3 = withMaxPoolOut ? new CPOperand(parts[15]) : null;
		CPOperand out = withMaxPoolOut ? new CPOperand(parts[16]) : new CPOperand(parts[15]);
		double memBudget = withMaxPoolOut ? Double.parseDouble(parts[17]) : Double.parseDouble(parts[16]);
		ArrayList<CPOperand> stride = new ArrayList<>();
		ArrayList<CPOperand> padding = new ArrayList<>();
		ArrayList<CPOperand> input_shape = new ArrayList<>();
		ArrayList<CPOperand> filter_shape = new ArrayList<>();
		stride.add(new CPOperand(parts[3]));
		stride.add(new CPOperand(parts[4]));
		padding.add(new CPOperand(parts[5]));
		padding.add(new CPOperand(parts[6]));
		input_shape.add(new CPOperand(parts[7]));
		input_shape.add(new CPOperand(parts[8]));
		input_shape.add(new CPOperand(parts[9]));
		input_shape.add(new CPOperand(parts[10]));
		filter_shape.add(new CPOperand(parts[11]));
		filter_shape.add(new CPOperand(parts[12]));
		filter_shape.add(new CPOperand(parts[13]));
		filter_shape.add(new CPOperand(parts[14]));
		return new DnnGPUInstruction(in1, in2, in3, out, opcode, str, stride,
			padding, input_shape, filter_shape, memBudget);
	}
	// conv2d_bias_add: like conv2d but with a bias as third input, shifting
	// all shape operands by one position
	else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
		InstructionUtils.checkNumFields(parts, 17);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand out = new CPOperand(parts[16]);
		ArrayList<CPOperand> stride = new ArrayList<>();
		ArrayList<CPOperand> padding = new ArrayList<>();
		ArrayList<CPOperand> input_shape = new ArrayList<>();
		ArrayList<CPOperand> filter_shape = new ArrayList<>();
		stride.add(new CPOperand(parts[4]));
		stride.add(new CPOperand(parts[5]));
		padding.add(new CPOperand(parts[6]));
		padding.add(new CPOperand(parts[7]));
		input_shape.add(new CPOperand(parts[8]));
		input_shape.add(new CPOperand(parts[9]));
		input_shape.add(new CPOperand(parts[10]));
		input_shape.add(new CPOperand(parts[11]));
		filter_shape.add(new CPOperand(parts[12]));
		filter_shape.add(new CPOperand(parts[13]));
		filter_shape.add(new CPOperand(parts[14]));
		filter_shape.add(new CPOperand(parts[15]));
		return new DnnGPUInstruction(in1, in2, in3, out, opcode, str, stride,
			padding, input_shape, filter_shape, Double.parseDouble(parts[17]));
	}
	// pooling forward: single input, shapes shifted down by one position
	else if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling")) {
		InstructionUtils.checkNumFields(parts, 15);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand out = new CPOperand(parts[14]);
		ArrayList<CPOperand> stride = new ArrayList<>();
		ArrayList<CPOperand> padding = new ArrayList<>();
		ArrayList<CPOperand> input_shape = new ArrayList<>();
		ArrayList<CPOperand> filter_shape = new ArrayList<>();
		stride.add(new CPOperand(parts[2]));
		stride.add(new CPOperand(parts[3]));
		padding.add(new CPOperand(parts[4]));
		padding.add(new CPOperand(parts[5]));
		input_shape.add(new CPOperand(parts[6]));
		input_shape.add(new CPOperand(parts[7]));
		input_shape.add(new CPOperand(parts[8]));
		input_shape.add(new CPOperand(parts[9]));
		filter_shape.add(new CPOperand(parts[10]));
		filter_shape.add(new CPOperand(parts[11]));
		filter_shape.add(new CPOperand(parts[12]));
		filter_shape.add(new CPOperand(parts[13]));
		// no second input for pooling forward, hence the null
		return new DnnGPUInstruction(in1, null, out, opcode, str, stride,
			padding, input_shape, filter_shape, Double.parseDouble(parts[15]));
	}
	// simple two-input elementwise ops: in1, in2, out, memory budget
	else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") ) {
		InstructionUtils.checkNumFields(parts, 4);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand out = new CPOperand(parts[3]);
		return new DnnGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4]));
	}
	// channel_sums: input, C, HW, out
	else if (opcode.equalsIgnoreCase("channel_sums")) {
		InstructionUtils.checkNumFields(parts, 4);
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand out = new CPOperand(parts[4]);
		return new DnnGPUInstruction(in, in2, in3, out, opcode, str, 0);
	}
	// update_nesterov_x: input, v, v_prev, mu, out
	else if (opcode.equalsIgnoreCase("update_nesterov_x")) {
		InstructionUtils.checkNumFields(parts, 5);
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand in4 = new CPOperand(parts[4]);
		CPOperand out = new CPOperand(parts[5]);
		return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
	}
	// lstm: six inputs, two outputs
	else if (opcode.equalsIgnoreCase("lstm")) {
		InstructionUtils.checkNumFields(parts, 8);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand in4 = new CPOperand(parts[4]);
		CPOperand in5 = new CPOperand(parts[5]);
		CPOperand in6 = new CPOperand(parts[6]);
		CPOperand out = new CPOperand(parts[7]);
		CPOperand out2 = new CPOperand(parts[8]);
		return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0);
	}
	// batch_norm2d / lstm_backward share the 8-input, 5-output layout;
	// the operand comments below refer to the batch_norm2d naming
	else if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) {
		InstructionUtils.checkNumFields(parts, 13);
		CPOperand in1 = new CPOperand(parts[1]); // image
		CPOperand in2 = new CPOperand(parts[2]); // scale
		CPOperand in3 = new CPOperand(parts[3]); // bias
		CPOperand in4 = new CPOperand(parts[4]); // runningMean
		CPOperand in5 = new CPOperand(parts[5]); // runningVar
		CPOperand in6 = new CPOperand(parts[6]); // mode
		CPOperand in7 = new CPOperand(parts[7]); // epsilon
		CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor
		CPOperand out = new CPOperand(parts[9]); // ret
		CPOperand out2 = new CPOperand(parts[10]); // retRunningMean
		CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
		CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
		CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
		return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
	}
	// batch_norm2d_backward: unused operand slots are filled with null
	else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
		InstructionUtils.checkNumFields(parts, 9);
		CPOperand in1 = new CPOperand(parts[1]); // image
		CPOperand in2 = new CPOperand(parts[2]); // dout
		CPOperand in3 = new CPOperand(parts[3]); // scale
		CPOperand in4 = new CPOperand(parts[4]); // epsilon
		CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean
		CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance
		CPOperand out = new CPOperand(parts[7]); // dX
		CPOperand out2 = new CPOperand(parts[8]); // dScale
		CPOperand out3 = new CPOperand(parts[9]); // dBias
		return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
	}
	// batch_norm2d_test: six inputs, single output
	else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
		InstructionUtils.checkNumFields(parts, 7);
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand in4 = new CPOperand(parts[4]);
		CPOperand in5 = new CPOperand(parts[5]);
		CPOperand in6 = new CPOperand(parts[6]);
		CPOperand out = new CPOperand(parts[7]);
		return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0);
	}
	// batch_norm2d_train: seven inputs (slot 8 null), five outputs
	else if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
		InstructionUtils.checkNumFields(parts, 12);
		CPOperand in1 = new CPOperand(parts[1]); // image
		CPOperand in2 = new CPOperand(parts[2]); // gamma
		CPOperand in3 = new CPOperand(parts[3]); // beta
		CPOperand in4 = new CPOperand(parts[4]); // ema_mean
		CPOperand in5 = new CPOperand(parts[5]); // ema_var
		CPOperand in6 = new CPOperand(parts[6]); // eps
		CPOperand in7 = new CPOperand(parts[7]); // mu
		CPOperand out = new CPOperand(parts[8]); // out
		CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd
		CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd
		CPOperand out4 = new CPOperand(parts[11]); // cache_mean
		CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var
		return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0);
	}
	else {
		throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);
	}
}
/**
 * Executes bias_add / bias_multiply on the GPU: out = input (op) bias.
 *
 * @param instOpcode the opcode selecting between add and multiply
 * @param ec the execution context providing GPU access and variable lookup
 */
private void processBiasInstruction(String instOpcode, ExecutionContext ec) {
	GPUStatistics.incrementNoOfExecutedGPUInst();
	MatrixObject mX = getMatrixInputForGPUInstruction(ec, _input1.getName());
	MatrixObject mBias = getMatrixInputForGPUInstruction(ec, _input2.getName());
	// output has the same dimensions as the first input
	MatrixObject mOut = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), mX.getNumRows(), mX.getNumColumns());
	if(instOpcode.equalsIgnoreCase("bias_add"))
		LibMatrixCUDA.biasAdd(ec.getGPUContext(0), getExtendedOpcode(), mX, mBias, mOut);
	else if(instOpcode.equalsIgnoreCase("bias_multiply"))
		LibMatrixCUDA.biasMultiply(ec.getGPUContext(0), getExtendedOpcode(), mX, mBias, mOut);
	// release GPU-side locks on inputs and output
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	ec.releaseMatrixInputForGPUInstruction(_input2.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
/**
 * Executes the combined batch_norm2d instruction, dispatching on the scalar
 * "mode" operand (_input6) to either forward training or forward inference.
 * In training mode all five outputs are materialized on the GPU; in test mode
 * the four secondary outputs are set to empty CPU matrix blocks since the
 * inference path only produces the normalized result.
 *
 * @param ec the execution context providing GPU access and variable lookup
 * @throws DMLRuntimeException if the mode is neither "train" nor "test"
 */
private void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
	GPUStatistics.incrementNoOfExecutedGPUInst();
	MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
	MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName());
	MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
	MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
	MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
	String phase = ec.getScalarInput(_input6).getStringValue();
	double epsilon = ec.getScalarInput(_input7).getDoubleValue();
	MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns());
	if(phase.equalsIgnoreCase("train")) {
		// _input8 carries the DML-side averaging factor; the cuDNN call takes
		// its complement (1 - factor) — presumably matching cuDNN's
		// exponentialAverageFactor convention (TODO confirm).
		double exponentialAverageFactor = 1-ec.getScalarInput(_input8.getName(), _input8.getValueType(), _input8.isLiteral()).getDoubleValue();
		MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
		MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
		MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns());
		MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns());
		LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(),
			image, scale, bias, runningMean, runningVar, ret,
			retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
		// release the four training-only outputs; _output is released below
		ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
		ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
		ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
		ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
	}
	else if(phase.equalsIgnoreCase("test")) {
		LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(),
			image, scale, bias, runningMean, runningVar, ret, epsilon);
		// inference produces no updated statistics: bind empty (sparse) blocks
		// so downstream consumers of the secondary outputs still find a value
		ec.setMatrixOutput(_output2.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true));
		ec.setMatrixOutput(_output3.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true));
		ec.setMatrixOutput(_output4.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true));
		ec.setMatrixOutput(_output5.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true));
	}
	else {
		throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase);
	}
	// release inputs/outputs
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	ec.releaseMatrixInputForGPUInstruction(_input2.getName());
	ec.releaseMatrixInputForGPUInstruction(_input3.getName());
	ec.releaseMatrixInputForGPUInstruction(_input4.getName());
	ec.releaseMatrixInputForGPUInstruction(_input5.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
/**
 * Executes batch_norm2d_train: forward batch normalization in training mode,
 * producing the normalized output, updated EMA statistics and the cached
 * per-batch mean / inverse variance for the backward pass.
 *
 * @param ec the execution context providing GPU access and variable lookup
 * @throws DMLRuntimeException on GPU errors
 */
private void processBatchNorm2dTrainInstruction(ExecutionContext ec) throws DMLRuntimeException {
	GPUStatistics.incrementNoOfExecutedGPUInst();
	// inputs: image, gamma (scale), beta (bias), EMA mean/var, eps, mu
	MatrixObject img = getMatrixInputForGPUInstruction(ec, _input1.getName());
	MatrixObject gamma = getMatrixInputForGPUInstruction(ec, _input2.getName());
	MatrixObject beta = getMatrixInputForGPUInstruction(ec, _input3.getName());
	MatrixObject emaMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
	MatrixObject emaVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
	double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
	// cuDNN expects 1 - mu as the exponential averaging factor
	double expAvgFactor = 1-ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue();
	// outputs: normalized result, updated EMA stats, cached batch statistics
	MatrixObject result = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), img.getNumRows(), img.getNumColumns());
	MatrixObject emaMeanUpd = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), emaMean.getNumRows(), emaMean.getNumColumns());
	MatrixObject emaVarUpd = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), emaVar.getNumRows(), emaVar.getNumColumns());
	MatrixObject cacheMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), emaMean.getNumRows(), emaMean.getNumColumns());
	MatrixObject cacheInvVar = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), emaVar.getNumRows(), emaVar.getNumColumns());
	LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(),
		img, gamma, beta, emaMean, emaVar, result,
		emaMeanUpd, emaVarUpd, epsilon, expAvgFactor, cacheMean, cacheInvVar);
	// release GPU-side locks on all inputs and outputs
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	ec.releaseMatrixInputForGPUInstruction(_input2.getName());
	ec.releaseMatrixInputForGPUInstruction(_input3.getName());
	ec.releaseMatrixInputForGPUInstruction(_input4.getName());
	ec.releaseMatrixInputForGPUInstruction(_input5.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
}
/**
 * Executes batch_norm2d_test: forward batch normalization in inference mode
 * using the provided running statistics.
 *
 * @param ec the execution context providing GPU access and variable lookup
 * @throws DMLRuntimeException on GPU errors
 */
private void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
	GPUStatistics.incrementNoOfExecutedGPUInst();
	MatrixObject img = getMatrixInputForGPUInstruction(ec, _input1.getName());
	MatrixObject gamma = getMatrixInputForGPUInstruction(ec, _input2.getName());
	MatrixObject beta = getMatrixInputForGPUInstruction(ec, _input3.getName());
	MatrixObject emaMean = getMatrixInputForGPUInstruction(ec, _input4.getName());
	MatrixObject emaVar = getMatrixInputForGPUInstruction(ec, _input5.getName());
	double epsilon = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
	// single output with the same dimensions as the input image
	MatrixObject result = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), img.getNumRows(), img.getNumColumns());
	LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(),
		img, gamma, beta, emaMean, emaVar, result, epsilon);
	// release GPU-side locks on inputs and output
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	ec.releaseMatrixInputForGPUInstruction(_input2.getName());
	ec.releaseMatrixInputForGPUInstruction(_input3.getName());
	ec.releaseMatrixInputForGPUInstruction(_input4.getName());
	ec.releaseMatrixInputForGPUInstruction(_input5.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
/**
 * Executes batch_norm2d_backward: computes the gradients w.r.t. the input
 * image, scale and bias, reusing the mean / inverse variance cached during
 * the forward training pass.
 *
 * @param ec the execution context providing GPU access and variable lookup
 * @throws DMLRuntimeException on GPU errors
 */
public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
	GPUStatistics.incrementNoOfExecutedGPUInst();
	MatrixObject img = getMatrixInputForGPUInstruction(ec, _input1.getName());
	MatrixObject grad = getMatrixInputForGPUInstruction(ec, _input2.getName());
	MatrixObject gamma = getMatrixInputForGPUInstruction(ec, _input3.getName());
	double epsilon = ec.getScalarInput(_input4).getDoubleValue();
	MatrixObject cachedMean = getMatrixInputForGPUInstruction(ec, _input5.getName());
	MatrixObject cachedInvVar = getMatrixInputForGPUInstruction(ec, _input6.getName());
	// gradient outputs: dX matches the image, dScale/dBias match the scale
	MatrixObject gradX = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), img.getNumRows(), img.getNumColumns());
	MatrixObject gradScale = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), gamma.getNumRows(), gamma.getNumColumns());
	MatrixObject gradBias = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), gamma.getNumRows(), gamma.getNumColumns());
	LibMatrixCuDNN.batchNormalizationBackward(ec.getGPUContext(0), getExtendedOpcode(), img,
		grad, gamma, gradX, gradScale, gradBias,
		epsilon, cachedMean, cachedInvVar);
	// release GPU-side locks on inputs and outputs (_input4 is a scalar, nothing to release)
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	ec.releaseMatrixInputForGPUInstruction(_input2.getName());
	ec.releaseMatrixInputForGPUInstruction(_input3.getName());
	ec.releaseMatrixInputForGPUInstruction(_input5.getName());
	ec.releaseMatrixInputForGPUInstruction(_input6.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
}
/**
 * Executes relu_backward on the GPU: out = (X > 0) * dout.
 *
 * @param ec the execution context providing GPU access and variable lookup
 */
public void processReLUBackwardInstruction(ExecutionContext ec) {
	GPUStatistics.incrementNoOfExecutedGPUInst();
	MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName());
	MatrixObject grad = getMatrixInputForGPUInstruction(ec, _input2.getName());
	// output has the same dimensions as the forward input X
	MatrixObject result = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), X.getNumRows(), X.getNumColumns());
	LibMatrixCUDA.reluBackward(ec.getGPUContext(0), getExtendedOpcode(), X, grad, result);
	// release GPU-side locks on inputs and output
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	ec.releaseMatrixInputForGPUInstruction(_input2.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
/**
 * Executes channel_sums on the GPU: sums the input over the spatial
 * dimensions, producing a C x 1 vector of per-channel sums.
 * The input is expected to have C*HW columns.
 *
 * @param ec the execution context providing GPU access and variable lookup
 * @throws DMLRuntimeException if C*HW does not match the input's column count
 */
private void processChannelSumsInstruction(ExecutionContext ec) {
	GPUStatistics.incrementNoOfExecutedGPUInst();
	MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
	int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue();
	int HW = (int) ec.getScalarInput(_input3.getName(), _input3.getValueType(), _input3.isLiteral()).getLongValue();
	// BUGFIX 1: compute C*HW in long arithmetic so the product cannot overflow
	// int before the comparison.
	// BUGFIX 2: the original error message concatenated "Expected rows*cols"
	// directly with C (e.g. "rows*cols5*10"); separate the values properly.
	if(((long) C) * HW != input.getNumColumns()) {
		throw new DMLRuntimeException("Expected C*HW = " + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns());
	}
	// per-channel sums: one value per channel
	MatrixObject outputBlock = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), C, 1);
	LibMatrixCUDA.channelSums(ec.getGPUContext(0), getExtendedOpcode(), input, outputBlock, C, HW);
	// release GPU-side locks on input and output
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
/**
 * Executes update_nesterov_x on the GPU via the "update_nesterov_x" CUDA
 * kernel, combining the input with the current and previous velocity scaled
 * by the momentum mu.
 *
 * @param ec the execution context providing GPU access and variable lookup
 */
private void processNesterovUpdateInstruction(ExecutionContext ec) {
	GPUStatistics.incrementNoOfExecutedGPUInst();
	MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
	MatrixObject v = getMatrixInputForGPUInstruction(ec, _input2.getName());
	MatrixObject v_prev = getMatrixInputForGPUInstruction(ec, _input3.getName());
	// BUGFIX: the momentum scalar was previously truncated via an (int) cast
	// (double mu = (int) ...), turning typical fractional values such as 0.9
	// into 0 and silently disabling the momentum term.
	double mu = ec.getScalarInput(_input4).getDoubleValue();
	int rows = LibMatrixCUDA.toInt(input.getNumRows());
	int cols = LibMatrixCUDA.toInt(input.getNumColumns());
	MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), rows, cols);
	GPUContext gCtx = ec.getGPUContext(0);
	String instName = getExtendedOpcode();
	// BUGFIX: compute the element count in long arithmetic so rows*cols cannot
	// overflow int before the range check in toInt.
	int numElems = LibMatrixCUDA.toInt(((long) rows) * cols);
	LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x",
		ExecutionConfig.getConfigForSimpleVectorOperations(numElems),
		LibMatrixCUDA.getDensePointer(gCtx, input, instName),
		LibMatrixCUDA.getDensePointer(gCtx, v, instName),
		LibMatrixCUDA.getDensePointer(gCtx, v_prev, instName),
		mu,
		LibMatrixCUDA.getDensePointer(gCtx, out, instName),
		numElems);
	// release GPU-side locks on inputs and output
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	ec.releaseMatrixInputForGPUInstruction(_input2.getName());
	ec.releaseMatrixInputForGPUInstruction(_input3.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
/**
 * Narrows a long to an int for GPU size computations.
 * Note: values equal to Integer.MAX_VALUE / MIN_VALUE are rejected as well,
 * matching the original (exclusive) bounds.
 *
 * @param num the value to narrow
 * @return the value as an int
 * @throws DMLRuntimeException if the value does not fit strictly inside the int range
 */
private static int toInt(long num) throws DMLRuntimeException {
	// De Morgan form of the original check: reject num >= MAX or num <= MIN
	boolean fits = num < Integer.MAX_VALUE && num > Integer.MIN_VALUE;
	if( !fits )
		throw new DMLRuntimeException("GPU : Exceeded supported size " + num);
	return (int)num;
}
// private Pointer transpose(ExecutionContext ec, MatrixObject X) throws DMLRuntimeException {
// GPUContext gCtx = ec.getGPUContext(0);
// String instructionName = getExtendedOpcode();
// long numRowsX = X.getNumRows(); long numColsX = X.getNumColumns();
// Pointer tX = gCtx.allocate(instructionName, numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType);
// jcuda.runtime.JCuda.cudaMemcpy(tX, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType, jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice);
// // LibMatrixCUDA.denseTranspose(ec, gCtx, instructionName, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), tX, numRowsX, numColsX);
// return tX;
// }
/**
 * Executes lstm_backward on the GPU via cuDNN. Reshapes the SystemDS weight,
 * bias and input layouts into the cuDNN-expected layouts using dedicated
 * CUDA kernels, then delegates to {@link LibMatrixCuDNN#lstmBackward}.
 *
 * Notation (also see processLstmInstruction): N = batch size, T = sequence
 * length, D = number of features, M = hidden size.
 *
 * @param ec the execution context providing GPU access and variable lookup
 * @throws DMLRuntimeException on GPU errors
 */
private void processLstmBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
	GPUStatistics.incrementNoOfExecutedGPUInst();
	GPUContext gCtx = ec.getGPUContext(0);
	String instructionName = getExtendedOpcode();
	MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName());
	int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M)
	Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName);
	MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName());
	MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
	long numRowsW = W.getNumRows();
	int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures
	Pointer sysdsWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
	Pointer sysdsBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
	// cuDNN packs weights and both bias rows into one contiguous buffer of
	// (D+M+2) x 4M values; the kernel performs the layout conversion
	Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
	LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
		ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
		sysdsWPointer, sysdsBiasPointer, cudnnWPointer, D, M);
	// W and bias are no longer needed once repacked
	ec.releaseMatrixInputForGPUInstruction(_input2.getName());
	ec.releaseMatrixInputForGPUInstruction(_input3.getName());
	MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName());
	Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName);
	int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
	long numColsX = X.getNumColumns();
	int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
	// repack X from (N, T*D) into the cuDNN input layout
	Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
	LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
		ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
		xPointer, cudnnInput, N, D, T*D, N*T*D);
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName);
	boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue();
	// LibMatrixCuDNN.lstm(ec, gCtx, instructionName,
	// cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
	// String xName, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName, // input
	// String dxName, String dwName, String dbName, String dhxName, String dcxName, // output
	// gradient outputs (dX, dW, db, dh0, dc0) and incoming gradients (dout, dcy)
	String dxName = _output.getName();
	String dwName = _output2.getName();
	String dbName = _output3.getName();
	String dhxName = _output4.getName();
	String dcxName = _output5.getName();
	String doutName = _input7.getName();
	String dcyName = _input8.getName();
	LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName,
		cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName, // input
		dxName, dwName, dbName, dhxName, dcxName, // output
		return_sequences, N, M, D, T);
	// free the temporary cuDNN-layout buffers
	gCtx.cudaFreeHelper(instructionName, cudnnWPointer, DMLScript.EAGER_CUDA_FREE);
	gCtx.cudaFreeHelper(instructionName, cudnnInput, DMLScript.EAGER_CUDA_FREE);
	// release inputs/outputs
	ec.releaseMatrixInputForGPUInstruction(_input4.getName());
	ec.releaseMatrixInputForGPUInstruction(_input5.getName());
}
private void processLstmInstruction(ExecutionContext ec) throws DMLRuntimeException {
// Executes the LSTM forward pass on the GPU: reshuffles SystemDS-layout weights
// and input into the cuDNN-expected layout via custom kernels, then delegates to
// LibMatrixCuDNN.lstm. Statement order matters: each GPU input is released only
// after its device pointer has been consumed by a kernel or library call.
// Dimension naming: batchSize=N, seqLength=T, numFeatures=D and hiddenSize=M
// input X:(N, T*D), ==> (T, D, N)
// weight W:(D+M+2, 4M)
// previous output out0 (also represented by hx) and cell state c0 (also represented by cx): (N, M) ==> (1, M, N)
// out: (N, T*M) or (N, M) ==> (T, M, N)
GPUStatistics.incrementNoOfExecutedGPUInst();
GPUContext gCtx = ec.getGPUContext(0);
String instructionName = getExtendedOpcode();
// out0 (previous hidden state) determines M since out0:(N, M)
MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName());
int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M)
Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName);
MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName());
MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName());
long numRowsW = W.getNumRows();
int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures
Pointer sysdsWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
Pointer sysdsBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
// Fuse W and bias into a single cuDNN-layout weight buffer of shape (D+M+2, 4M)
Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
sysdsWPointer, sysdsBiasPointer, cudnnWPointer, D, M);
// W and bias have been copied into cudnnWPointer; safe to release here
ec.releaseMatrixInputForGPUInstruction(_input2.getName());
ec.releaseMatrixInputForGPUInstruction(_input3.getName());
boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue();
// Because the matrices are released immediately, the output for transpose need not be taken into account
MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName());
Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName);
int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
long numColsX = X.getNumColumns();
int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
// Repack X from SystemDS row-major (N, T*D) into the cuDNN time-major input layout
Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
xPointer, cudnnInput, N, D, T*D, N*T*D);
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName);
// Forward pass; writes _output (hidden states) and _output2 (final cell state)
LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
// Free the temporary cuDNN-layout buffers (eagerly if configured)
gCtx.cudaFreeHelper(instructionName, cudnnWPointer, DMLScript.EAGER_CUDA_FREE);
gCtx.cudaFreeHelper(instructionName, cudnnInput, DMLScript.EAGER_CUDA_FREE);
// release inputs/outputs
ec.releaseMatrixInputForGPUInstruction(_input4.getName());
ec.releaseMatrixInputForGPUInstruction(_input5.getName());
ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
/**
 * Dispatches this GPU DNN instruction based on its opcode.
 *
 * Specialized opcodes (bias ops, relu_backward, channel_sums, nesterov update,
 * lstm/lstm_backward, batch-norm variants) are routed to dedicated handlers that
 * manage their own input/output release. The remaining convolution/pooling opcodes
 * share the common shape-parameter extraction below and the common release logic
 * at the end of this method.
 *
 * @param ec execution context providing GPU context, inputs and outputs
 * @throws DMLRuntimeException if input dimensions are inconsistent with the
 *         declared shape parameters, or the opcode is unsupported on GPU
 */
@Override
public void processInstruction(ExecutionContext ec) {
	if (instOpcode.equalsIgnoreCase("bias_add") || instOpcode.equalsIgnoreCase("bias_multiply")) {
		processBiasInstruction(instOpcode, ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("relu_backward")) {
		processReLUBackwardInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("channel_sums")) {
		processChannelSumsInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("update_nesterov_x")) {
		processNesterovUpdateInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("lstm")) {
		processLstmInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("lstm_backward")) {
		processLstmBackwardInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
		processBatchNorm2dInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
		processBatchNorm2dBackwardInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("batch_norm2d_test")) {
		processBatchNorm2dTestInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("batch_norm2d_train")) {
		processBatchNorm2dTrainInstruction(ec);
		return;
	}
	GPUStatistics.incrementNoOfExecutedGPUInst();
	// Common conv/pool shape parameters:
	// N=batch, C=channels, H/W=input height/width, K=filters, R/S=filter height/width,
	// P/Q=output height/width derived from stride and padding.
	int pad_h = getScalarInput(ec, _padding, 0);
	int pad_w = getScalarInput(ec, _padding, 1);
	int stride_h = getScalarInput(ec, _stride, 0);
	int stride_w = getScalarInput(ec, _stride, 1);
	int N = getScalarInput(ec, _input_shape, 0);
	int C = getScalarInput(ec, _input_shape, 1);
	int H = getScalarInput(ec, _input_shape, 2);
	int W = getScalarInput(ec, _input_shape, 3);
	int K = getScalarInput(ec, _filter_shape, 0);
	int R = getScalarInput(ec, _filter_shape, 2);
	int S = getScalarInput(ec, _filter_shape, 3);
	int P = (int) DnnUtils.getP(H, R, stride_h, pad_h);
	int Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w);
	if (instOpcode.equalsIgnoreCase("conv2d")) {
		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
		MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input2.getName());
		if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
			throw new DMLRuntimeException("Incorrect dimensions for image in conv2d");
		if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S)
			throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d");
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, K * P * Q);
		LibMatrixCuDNN.conv2d(ec.getGPUContext(0), getExtendedOpcode(), image, filter, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
	}
	else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName());
		MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input3.getName());
		// Fixed diagnostics: previously claimed "conv2d" for the conv2d_bias_add opcode
		if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
			throw new DMLRuntimeException("Incorrect dimensions for image in conv2d_bias_add");
		if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S)
			throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d_bias_add");
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, K * P * Q);
		LibMatrixCuDNN.conv2dBiasAdd(ec.getGPUContext(0), getExtendedOpcode(), image, bias, filter, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
	}
	else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
		MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
		if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
			throw new DMLRuntimeException("Incorrect dimensions for image in conv2d_backward_filter");
		if(dout.getNumRows() != N || dout.getNumColumns() != K*P*Q)
			throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_filter: " +
				dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + K*P*Q);
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), K, C * R * S);
		LibMatrixCuDNN.conv2dBackwardFilter(ec.getGPUContext(0), getExtendedOpcode(), image, dout, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
		// TODO: For now always copy the device data to host
		// ec.gpuCtx.copyDeviceToHost(outputBlock);
	}
	else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
		MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input1.getName());
		MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
		if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S)
			throw new DMLRuntimeException("Incorrect dimensions for filter in convolution_backward_data");
		if(dout.getNumRows() != N || dout.getNumColumns() != K*P*Q)
			throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_data: " +
				dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + K*P*Q);
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W);
		LibMatrixCuDNN.conv2dBackwardData(ec.getGPUContext(0), getExtendedOpcode(), filter, dout, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
	}
	else if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("avgpooling")) {
		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
		if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
			throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling: " +
				image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + C*H*W);
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * P * Q);
		PoolingType poolType = instOpcode.equalsIgnoreCase("maxpooling") ? PoolingType.MAX : PoolingType.AVG;
		LibMatrixCuDNN.pooling(ec.getGPUContext(0), getExtendedOpcode(), image, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, _intermediateMemoryBudget);
	}
	else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("avgpooling_backward")) {
		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
		MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
		// _input3 (the forward-pass pooling output) is optional; only max-pooling backward uses it
		MatrixObject maxPoolOutput = _input3 != null ? getMatrixInputForGPUInstruction(ec, _input3.getName()) : null;
		if(dout.getNumRows() != N || dout.getNumColumns() != C*P*Q)
			throw new DMLRuntimeException("Incorrect dimensions for dout in maxpooling_backward");
		// Fixed diagnostic: the expected column count is C*H*W (was incorrectly printed as K*P*Q)
		if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
			throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling_backward: " +
				image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + C*H*W);
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W);
		PoolingType poolType = instOpcode.equalsIgnoreCase("maxpooling_backward") ? PoolingType.MAX : PoolingType.AVG;
		LibMatrixCuDNN.poolingBackward(ec.getGPUContext(0), getExtendedOpcode(), image, dout, maxPoolOutput, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, _intermediateMemoryBudget);
	}
	else {
		throw new DMLRuntimeException("Unsupported GPU context for " + instOpcode);
	}
	// release inputs/outputs common to the conv/pool opcodes handled above
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	boolean isPool = instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("avgpooling");
	boolean isPoolBackward = instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("avgpooling_backward");
	if ( !isPool )
		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
	if (instOpcode.equalsIgnoreCase("conv2d_bias_add") ||
		(isPoolBackward && _input3 != null))
		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
/** Reads the scalar operand at {@code index} from the given operand list and returns its value as an int. */
private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) {
	CPOperand operand = aL.get(index);
	long value = ec.getScalarInput(operand).getLongValue();
	return (int) value;
}
}