/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rnn-inl.h
* \brief Implementation of the RNN operator (vanilla RNN, LSTM and GRU) for CPU and cuDNN.
* \author Sebastian Bodenstein, Shu Zhang
*/
#ifndef MXNET_OPERATOR_RNN_INL_H_
#define MXNET_OPERATOR_RNN_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/storage.h>
#include <algorithm>
#include <random>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include <cstdint>
#include <sstream>
#include <unordered_map>
#include "./math.h"
#include "./math_functions-inl.h"
#include "./operator_common.h"
#include "./rnn_impl.h"
#include "../profiler/storage_profiler.h"
#if MXNET_USE_CUDNN == 1
STATIC_ASSERT_CUDNN_VERSION_GE(7000);
#endif
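// Convenience guard: true only when building with cuDNN enabled and CUDNN_VERSION >= 7200,
// i.e. cuDNN 7.2.0 or later.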
#define MXNET_USE_CUDNN_GE_7200 MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200
namespace mxnet {
namespace op {
namespace rnn_enum {
enum RNNOpInputs { kData, kParams, kState, kStateCell, kSequenceLength };
enum RNNOpOutputs { kOut, kStateOut, kStateCellOut };
enum RNNModeType { kRnnRelu, kRnnTanh, kLstm, kGru };
enum RNNOpResource { kTempSpace, kCuDNNDropoutDescSpace };
} // namespace rnn_enum
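// Note on input layout: for LSTM the inputs are {data, params, state, state_cell}
// (plus sequence_length when use_sequence_length is set); for the other modes
// kStateCell is absent, so the sequence_length input effectively sits at index
// kSequenceLength - 1 (see the seq_len_input_idx adjustments below).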
struct RNNParam : public dmlc::Parameter<RNNParam> {
uint32_t state_size;
uint32_t num_layers;
bool bidirectional, state_outputs;
int mode;
float p;
#pragma GCC diagnostic push
#if __GNUC__ >= 6
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
index_t seq_length_, batch_size_, input_size_;
#pragma GCC diagnostic pop
bool use_sequence_length;
dmlc::optional<int> projection_size;
dmlc::optional<double> lstm_state_clip_min, lstm_state_clip_max;
bool lstm_state_clip_nan;
DMLC_DECLARE_PARAMETER(RNNParam) {
DMLC_DECLARE_FIELD(state_size).describe("size of the state for each layer");
DMLC_DECLARE_FIELD(num_layers).describe("number of stacked layers");
DMLC_DECLARE_FIELD(bidirectional)
.set_default(false)
.describe("whether to use bidirectional recurrent layers");
DMLC_DECLARE_FIELD(mode)
.add_enum("rnn_relu", rnn_enum::kRnnRelu)
.add_enum("rnn_tanh", rnn_enum::kRnnTanh)
.add_enum("lstm", rnn_enum::kLstm)
.add_enum("gru", rnn_enum::kGru)
.describe("the type of RNN to compute");
DMLC_DECLARE_FIELD(p).set_default(0.).set_range(0, 1).describe(
"drop rate of the dropout on the outputs of each RNN layer, except the last layer.");
DMLC_DECLARE_FIELD(state_outputs)
.set_default(false)
.describe("Whether to have the states as symbol outputs.");
DMLC_DECLARE_FIELD(projection_size)
.set_default(dmlc::optional<int>())
.describe("size of project size");
DMLC_DECLARE_FIELD(lstm_state_clip_min)
.set_default(dmlc::optional<double>())
.describe(
"Minimum clip value of LSTM states. This option must be used together with "
"lstm_state_clip_max.");
DMLC_DECLARE_FIELD(lstm_state_clip_max)
.set_default(dmlc::optional<double>())
.describe(
"Maximum clip value of LSTM states. This option must be used together with "
"lstm_state_clip_min.");
DMLC_DECLARE_FIELD(lstm_state_clip_nan)
.set_default(false)
.describe(
"Whether to stop NaN from propagating in state by clipping it to min/max. "
"If clipping range is not specified, this option is ignored.");
DMLC_DECLARE_FIELD(use_sequence_length)
.set_default(false)
.describe(
"If set to true, this layer takes in an extra input parameter "
"`sequence_length` "
"to specify variable length sequence");
}
std::string ComputeMode2String(int mode) {
switch (mode) {
case rnn_enum::kRnnRelu:
return "rnn_relu";
case rnn_enum::kRnnTanh:
return "rnn_tanh";
case rnn_enum::kLstm:
return "lstm";
case rnn_enum::kGru:
return "gru";
default:
LOG(FATAL) << "Unknown mode enum " << mode;
}
LOG(FATAL) << "should not reach here ";
return "";
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream state_size_s, num_layers_s, bidirectional_s, state_outputs_s, mode_s, p_s,
use_sequence_length_s, projection_size_s, lstm_state_clip_min_s, lstm_state_clip_max_s,
lstm_state_clip_nan_s;
state_size_s << state_size;
num_layers_s << num_layers;
bidirectional_s << bidirectional;
state_outputs_s << state_outputs;
mode_s << mode;
p_s << p;
use_sequence_length_s << use_sequence_length;
projection_size_s << projection_size;
lstm_state_clip_min_s << lstm_state_clip_min;
lstm_state_clip_max_s << lstm_state_clip_max;
lstm_state_clip_nan_s << lstm_state_clip_nan;
(*dict)["state_size"] = state_size_s.str();
(*dict)["num_layers"] = num_layers_s.str();
(*dict)["bidirectional"] = bidirectional_s.str();
(*dict)["state_outputs"] = state_outputs_s.str();
(*dict)["mode"] = ComputeMode2String(mode);
(*dict)["p"] = p_s.str();
(*dict)["use_sequence_length"] = use_sequence_length_s.str();
(*dict)["projection_size"] = projection_size_s.str();
(*dict)["lstm_state_clip_min"] = lstm_state_clip_min_s.str();
(*dict)["lstm_state_clip_max"] = lstm_state_clip_max_s.str();
(*dict)["lstm_state_clip_nan"] = lstm_state_clip_nan_s.str();
}
};
inline index_t GetRnnParamSize(int num_layer,
index_t input_size,
int state_size,
int direction,
int mode,
const dmlc::optional<int>& projection_size) {
int size = state_size * direction;
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
break;
case rnn_enum::kLstm:
size *= 4;
break;
case rnn_enum::kGru:
size *= 3;
break;
}
index_t size1 = (input_size + state_size + 2) * size; // first layer size
index_t size2 = (state_size * direction + state_size + 2) * size; // other layers size
if (projection_size.has_value()) {
index_t proj_size = projection_size.value();
size1 = (input_size + proj_size + 2) * size;
size2 = (proj_size * direction + proj_size + 2) * size;
}
index_t param_size = size1 + (num_layer - 1) * size2;
if (projection_size.has_value()) {
param_size += projection_size.value() * state_size * num_layer * direction;
}
return param_size;
}
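// Illustrative check of the formula above (not used by the code): a single-layer,
// unidirectional LSTM with input_size = 10, state_size = 20 and no projection gives
// size = 20 * 1 * 4 = 80 and size1 = (10 + 20 + 2) * 80 = 2560, i.e. the 800 + 1600
// weight elements of wx/wh plus the two 80-element bias vectors.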
inline int GetRnnBiasSize(int num_layer, int state_size, int direction, int mode) {
int size = 2 * state_size * direction * num_layer;
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
break;
case rnn_enum::kLstm:
size *= 4;
break;
case rnn_enum::kGru:
size *= 3;
break;
}
return size;
}
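// For the same illustrative single-layer unidirectional LSTM (state_size = 20), this
// yields 2 * 20 * 1 * 1 * 4 = 160 bias elements: one bx and one bh vector of length
// 4 * state_size per layer, matching the tail of the parameter blob.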
/*
* Calculate the space size of the intermediate results for RNN inference.
* The inference procedure of a fused RNN operator calculates the outputs
* layer by layer. Within one layer, the steps are:
* - wx[1...Ngates] * x[1...T] across all time steps (sz: TxNxHxNgates)
* - wh[1...Ngates] * h[t] step by step (sz: NxHxNgates)
* - output -> h[t] (and c[t] additionally for LSTM) step by step (sz: NxH(x2))
* - intermediate y[1...T] as the next layer's inputs (sz: TxNxHxD)
*/
inline size_t GetRNNWorkspaceSize(index_t seq_length,
index_t batch_size,
int hidden_size,
int projection_size,
int direction,
int mode) {
size_t size = 0;
switch (mode) {
case rnn_enum::kLstm:
size = seq_length * batch_size * hidden_size * (4 + direction) + // wx*x + inter-y
batch_size * hidden_size * 6 + // wh*h + h + c
seq_length * hidden_size * 8 + // Used in Backward, Δbx, Δbh
// temporary dy in backward computation for bidirectional layers
seq_length * batch_size * hidden_size * (direction - 1 ? direction : 0);
break;
case rnn_enum::kGru:
// Unlike LSTM, the outputs of the three gates are also held in memory
size = seq_length * batch_size * hidden_size * direction * (3 + 1) + // wx*x + inter-y
batch_size * hidden_size * (6 + direction); // wh*h + h + Ngates
break;
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
size = seq_length * batch_size * hidden_size * direction * 2 + // wx*x + inter-y
batch_size * hidden_size * (1 + direction); // h + Ngates
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
break;
}
return size;
}
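// Worked example of the LSTM branch (illustrative only): seq_length = 5, batch_size = 2,
// hidden_size = 20, direction = 1 gives
// 5*2*20*(4+1) + 2*20*6 + 5*20*8 + 0 = 1000 + 240 + 800 = 2040 workspace elements.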
inline size_t GetRNNReserveSpaceSize(int num_layer,
int direction,
index_t seq_length,
index_t batch_size,
int hidden_size,
int mode) {
size_t size = 0;
switch (mode) {
case rnn_enum::kLstm:
size = direction * seq_length * batch_size * hidden_size * (num_layer * 7 - 1);
break;
case rnn_enum::kGru:
size = seq_length * batch_size * hidden_size * direction * (num_layer * 9 - 1) +
batch_size * hidden_size * direction * 9 + hidden_size * seq_length * 6 +
seq_length * batch_size * 7 * hidden_size * direction;
break;
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
size = seq_length * batch_size * hidden_size * direction * (num_layer * 6 - 1) +
batch_size * hidden_size * direction * 3 + hidden_size * seq_length * 2 +
seq_length * batch_size * 2 * hidden_size * direction;
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
break;
}
return size;
}
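// For the same illustrative LSTM configuration (num_layer = 1, direction = 1,
// seq_length = 5, batch_size = 2, hidden_size = 20) the reserve space is
// 1 * 5 * 2 * 20 * (1 * 7 - 1) = 1200 elements of forward intermediates.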
inline size_t GetRnnNumInputs(RNNParam param) {
size_t num_inputs = (param.mode == rnn_enum::kLstm) ? 4U : 3U;
if (param.use_sequence_length)
num_inputs += 1U;
return num_inputs;
}
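// Examples: a GRU without use_sequence_length takes 3 inputs (data, params, state);
// an LSTM with use_sequence_length takes 5 (data, params, state, state_cell, sequence_length).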
/**
* @params: ws: Temp workspace for gemm's output storage.
* rs: Reserve space of forward intermediate data used for training.
* num_layers: The number of recurrent layers.
* direction: 2 when using bidirectional recurrent layers, otherwise 1.
* seq_length: The number of iterations to unroll over.
* batch_size: The size of each batch.
* input_size: The number of expected input features.
* state_size: The number of hidden state features.
* x_ptr: Pointer to tensor x containing the features of the input sequence.
* x's shape is [seq_length, batch_size, input_size]
* hx_ptr: Pointer to tensor hx containing the initial hidden state.
* hx's shape is [num_layers, batch_size, state_size]
* cx_ptr: Only used in lstm mode. Pointer to tensor cx containing the initial cell state.
* cx's shape is [num_layers, batch_size, state_size]
* w_ptr: Pointer to tensor w containing the weights.
* b_ptr: Pointer to tensor b containing the bias.
* y_ptr: Pointer to tensor y containing the output features from the last layer of the RNN.
* y's shape is [seq_length, batch_size, state_size]
* hy_ptr: Pointer to tensor hy containing the hidden state for t=seq_length.
* hy's shape is [num_layers, batch_size, state_size]
* cy_ptr: Only used in lstm mode. Pointer to tensor cy containing the cell state
* for t=seq_length. cy's shape is [num_layers, batch_size, state_size]
* dropout: Dropout rate; should satisfy 0 <= dropout < 1.
* mode: Specifies the type of RNN to compute.
*/
template <typename DType>
void RNNForwardTraining(DType* ws,
DType* rs,
bool state_outputs,
const int num_layers,
const int direction,
const index_t seq_length,
const index_t batch_size,
const index_t input_size,
const int state_size,
DType* x_ptr,
DType* hx_ptr,
DType* cx_ptr,
DType* w_ptr,
DType* b_ptr,
DType* y_ptr,
DType* hy_ptr,
DType* cy_ptr,
const float dropout,
int mode,
std::mt19937& rnd_engine) { // NOLINT(runtime/references)
switch (mode) {
case rnn_enum::kLstm:
LstmForwardTraining<DType>(ws,
rs,
state_outputs,
num_layers,
direction,
seq_length,
batch_size,
input_size,
state_size,
x_ptr,
hx_ptr,
cx_ptr,
w_ptr,
b_ptr,
y_ptr,
hy_ptr,
cy_ptr,
dropout,
rnd_engine);
break;
case rnn_enum::kGru:
GruForwardTraining<DType>(ws,
rs,
state_outputs,
num_layers,
direction,
seq_length,
batch_size,
input_size,
state_size,
x_ptr,
hx_ptr,
w_ptr,
y_ptr,
hy_ptr,
dropout,
rnd_engine);
break;
case rnn_enum::kRnnTanh:
case rnn_enum::kRnnRelu:
VanillaRNNForwardTraining<DType>(ws,
rs,
state_outputs,
num_layers,
direction,
seq_length,
batch_size,
input_size,
state_size,
x_ptr,
hx_ptr,
w_ptr,
y_ptr,
hy_ptr,
dropout,
mode,
rnd_engine);
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
break;
}
}
template <typename DType>
void RNNForwardInference(DType* ws,
bool state_outputs,
const int num_layers,
const int direction,
const index_t seq_length,
const index_t batch_size,
const index_t input_size,
const int state_size,
const int projection_size,
DType* x_ptr,
DType* hx_ptr,
DType* cx_ptr,
DType* w_ptr,
DType* b_ptr,
DType* y_ptr,
DType* hy_ptr,
DType* cy_ptr,
int mode) {
switch (mode) {
case rnn_enum::kLstm:
LstmForwardInference<DType>(ws,
state_outputs,
num_layers,
direction,
seq_length,
batch_size,
input_size,
state_size,
projection_size,
x_ptr,
hx_ptr,
cx_ptr,
w_ptr,
b_ptr,
y_ptr,
hy_ptr,
cy_ptr);
break;
case rnn_enum::kGru:
GruForwardInference<DType>(ws,
state_outputs,
num_layers,
direction,
seq_length,
batch_size,
input_size,
state_size,
x_ptr,
hx_ptr,
w_ptr,
y_ptr,
hy_ptr);
break;
case rnn_enum::kRnnTanh:
case rnn_enum::kRnnRelu:
VanillaRNNForwardInference<DType>(ws,
state_outputs,
num_layers,
direction,
seq_length,
batch_size,
input_size,
state_size,
x_ptr,
hx_ptr,
w_ptr,
y_ptr,
hy_ptr,
mode);
break;
default:
LOG(FATAL) << "unknown RNN mode" << mode;
break;
}
}
template <typename DType>
void RNNBackward(DType* ws,
DType* rs,
const int num_layers,
const int direction,
const index_t seq_length,
const index_t batch_size,
const index_t input_size,
const int state_size,
DType* x_ptr,
DType* hx_ptr,
DType* cx_ptr,
DType* w_ptr,
DType* y_ptr,
DType* dy_ptr,
DType* dhy_ptr,
DType* dcy_ptr,
DType* dx_ptr,
DType* dhx_ptr,
DType* dcx_ptr,
DType* dw_ptr,
DType* db_ptr,
int req_data,
int req_params,
int req_state,
int req_statecell,
const float dropout,
int mode) {
switch (mode) {
case rnn_enum::kLstm:
LstmBackward<DType>(ws,
rs,
num_layers,
direction,
seq_length,
batch_size,
input_size,
state_size,
x_ptr,
hx_ptr,
cx_ptr,
w_ptr,
y_ptr,
dy_ptr,
dhy_ptr,
dcy_ptr,
dx_ptr,
dhx_ptr,
dcx_ptr,
dw_ptr,
db_ptr,
req_data,
req_params,
req_state,
req_statecell,
dropout);
break;
case rnn_enum::kGru:
GruBackward<DType>(ws,
rs,
num_layers,
direction,
seq_length,
batch_size,
input_size,
state_size,
x_ptr,
hx_ptr,
w_ptr,
dy_ptr,
dhy_ptr,
dx_ptr,
dhx_ptr,
dw_ptr,
req_data,
req_params,
req_state,
dropout);
break;
case rnn_enum::kRnnTanh:
case rnn_enum::kRnnRelu:
VanillaRNNBackward<DType>(ws,
rs,
num_layers,
direction,
seq_length,
batch_size,
input_size,
state_size,
x_ptr,
hx_ptr,
w_ptr,
dy_ptr,
dhy_ptr,
dx_ptr,
dhx_ptr,
dw_ptr,
req_data,
req_params,
req_state,
dropout,
mode);
break;
default:
LOG(FATAL) << "unknown RNN mode" << mode;
break;
}
}
template <typename xpu, typename DType, typename IType>
class RNNOp {
public:
RNNParam param_;
Context ctx_;
explicit RNNOp(RNNParam param, Context ctx) {
this->param_ = param;
this->ctx_ = ctx;
if (ctx_.dev_type == kGPU) {
#if MXNET_USE_CUDNN == 1
init_cudnn_ = false;
dtype_ = mshadow::DataType<DType>::kCudnnFlag;
// TensorCore algos are only allowed on fp16-I/O RNNs if permitted by the global policy.
// No tests in place for fp16 RNNs, so leave TensorCore disabled for now.
cudnn_tensor_core_ = false;
// When fp16 RNN tests are introduced, we can enable TensorCore as follows:
// cudnn_tensor_core =
// mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
// Defaults
input_mode_ = CUDNN_LINEAR_INPUT; // Don't support this yet
// RNN Mode
switch (param_.mode) {
case rnn_enum::kRnnRelu:
mode_ = CUDNN_RNN_RELU;
break;
case rnn_enum::kRnnTanh:
mode_ = CUDNN_RNN_TANH;
break;
case rnn_enum::kLstm:
mode_ = CUDNN_LSTM;
break;
case rnn_enum::kGru:
mode_ = CUDNN_GRU;
break;
default:
LOG(FATAL) << "Not implmented";
}
#if MXNET_USE_CUDNN_GE_7200
if (param_.projection_size.has_value()) {
CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Projection is only supported for LSTM.";
CHECK_GE(param_.state_size, param_.projection_size.value())
<< "State size must be larger than projection size.";
}
#else
CHECK(!param_.projection_size.has_value())
<< "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
#endif // MXNET_USE_CUDNN_GE_7200
#if MXNET_USE_CUDNN_GE_7200
if (param_.lstm_state_clip_min.has_value() || param_.lstm_state_clip_max.has_value()) {
CHECK_EQ(param_.mode, rnn_enum::kLstm) << "State clipping is only supported for LSTM.";
CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
<< "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
<< "lstm_state_clip_max must be greater or equal to lstm_state_clip_min";
}
#else
CHECK(!param_.lstm_state_clip_min.has_value() && !param_.lstm_state_clip_max.has_value())
<< "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
#endif // MXNET_USE_CUDNN_GE_7200
// RNN Direction
direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
// Create descriptors
CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));
CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));
CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_));
#endif // MXNET_USE_CUDNN_GE_7200
#else
if (ctx_.dev_type == kGPU) {
LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
}
#endif // MXNET_USE_CUDNN == 1
}
if (ctx_.dev_type == kCPU) {
this->init_space_ = false;
this->temp_init_space_ = false;
this->reserve_cpu_space_size_ = 0;
this->temp_cpu_space_size_ = 0;
if (param_.lstm_state_clip_min.has_value() || param_.lstm_state_clip_max.has_value()) {
LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
}
}
}
~RNNOp() {
if (ctx_.dev_type == kGPU) {
#if MXNET_USE_CUDNN == 1
CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(cy_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dhx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dcx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dhy_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dcy_desc_));
CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc_));
CUDNN_CALL(cudnnDestroyFilterDescriptor(dw_desc_));
CUDNN_CALL(cudnnDestroyRNNDescriptor(rnn_desc_));
CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
if (dgrad_sync_event_created_)
CUDA_CALL(cudaEventDestroy(dgrad_sync_event_));
if (init_cudnn_) {
for (size_t i = 0; i < x_desc_vec_.size(); ++i) {
CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_vec_[i]));
CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_vec_[i]));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_vec_[i]));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_vec_[i]));
}
init_cudnn_ = false;
Storage::Get()->Free(reserve_space_);
}
#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_));
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_));
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_));
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_));
#endif // MXNET_USE_CUDNN_GE_7200
#endif // MXNET_USE_CUDNN
}
}
void Forward(const OpContext& ctx,
const std::vector<TBlob>& in_data,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& out_data) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
<< "unsupported dropout value, should be 0 <= dropout < 1";
size_t num_inputs = GetRnnNumInputs(param_);
// kOut
size_t num_outputs = 1;
if (param_.state_outputs) {
// kOut, kStateOut, kStateCellOut
num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
}
CHECK_EQ(in_data.size(), num_inputs);
CHECK_EQ(out_data.size(), num_outputs);
Stream<xpu>* s = ctx.get_stream<xpu>();
// get input + output tensors
Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
param_.seq_length_ = x.shape_[0];
param_.batch_size_ = x.shape_[1];
param_.input_size_ = x.shape_[2];
const int direction = param_.bidirectional ? 2 : 1;
const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
DType* b_ptr = w.dptr_ + w.shape_[0] - bsize;
DType* hy_ptr = nullptr;
if (param_.state_outputs) {
hy_ptr = out_data[rnn_enum::kStateOut].dptr<DType>();
}
#if MXNET_USE_CUDNN_GE_7200
Tensor<cpu, 1, char> host_workspace;
int* sequence_length_cpu_int = nullptr;
IType* sequence_length_cpu_itype = nullptr;
if (ctx_.dev_type == kGPU) {
int host_workspace_bytes =
param_.batch_size_ * sizeof(IType) + param_.batch_size_ * sizeof(int);
host_workspace = ctx.requested[rnn_enum::kTempSpace].get_host_space_typed<1, char>(
Shape1(host_workspace_bytes));
sequence_length_cpu_int = reinterpret_cast<int*>(host_workspace.dptr_);
sequence_length_cpu_itype =
reinterpret_cast<IType*>(host_workspace.dptr_ + sizeof(int) * param_.batch_size_);
(void)sequence_length_cpu_int;
(void)sequence_length_cpu_itype;
}
#endif
if (param_.use_sequence_length) {
#if MXNET_USE_CUDNN_GE_7200
if (ctx_.dev_type == kCPU) {
LOG(FATAL) << "RNN use_sequence_length option is only available for cuDNN at the moment."
<< " Not supported on CPU";
}
// We can assume we are on GPU for now
size_t seq_len_input_idx = rnn_enum::kSequenceLength;
if (param_.mode != rnn_enum::kLstm) {
seq_len_input_idx -= 1;
}
IType* sequence_length_ptr_gpu = (in_data[seq_len_input_idx].get<xpu, 1, IType>(s)).dptr_;
// Need to copy from GPU -> CPU, because the cuDNN API requires this array in CPU memory.
// TODO(stephenrawls): In the future, allow users to pass this array on the CPU so we don't
// have to do this copy. For now, however, it is required, as several places in the backend
// assume that all data arrays share the same context.
CUDA_CALL(cudaMemcpy(sequence_length_cpu_itype,
sequence_length_ptr_gpu,
sizeof(IType) * param_.batch_size_,
cudaMemcpyDeviceToHost));
#else
LOG(FATAL) << "RNN use_sequence_length option is only available for cuDNN version >= 7.2";
#endif
}
DType* cx_ptr = nullptr;
DType* cy_ptr = nullptr;
if (param_.mode == rnn_enum::kLstm) {
cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
}
if (param_.mode == rnn_enum::kLstm && param_.state_outputs) {
cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
}
CHECK_EQ(x.CheckContiguous(), true);
CHECK_EQ(w.CheckContiguous(), true);
CHECK_EQ(hx.CheckContiguous(), true);
CHECK_EQ(y.CheckContiguous(), true);
#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
if (!init_cudnn_) {
Init(ctx, s, in_data, out_data);
}
// Get temp space
int temp_size = workspace_size_;
Tensor<gpu, 1, DType> temp_space =
ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
mshadow::Shape1(temp_size), s);
#if MXNET_USE_CUDNN_GE_7200
cudnnRNNDataLayout_t layout_t;
if (param_.use_sequence_length) {
// Note: Can't memcpy; sequence_length_cpu_itype is of type IType, not necessarily int
for (int i = 0; i < param_.batch_size_; ++i) {
sequence_length_cpu_int[i] = sequence_length_cpu_itype[i];
}
layout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
} else {
for (int i = 0; i < param_.batch_size_; ++i) {
sequence_length_cpu_int[i] = param_.seq_length_;
}
layout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED;
}
CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
dtype_,
layout_t,
param_.seq_length_,
param_.batch_size_,
param_.input_size_,
sequence_length_cpu_int,
reinterpret_cast<void*>(&padding_fill_)));
int out_size =
(param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
dtype_,
layout_t,
param_.seq_length_,
param_.batch_size_,
out_size,
sequence_length_cpu_int,
reinterpret_cast<void*>(&padding_fill_)));
if (ctx.is_train) {
CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
dtype_,
layout_t,
param_.seq_length_,
param_.batch_size_,
param_.input_size_,
sequence_length_cpu_int,
reinterpret_cast<void*>(&padding_fill_)));
CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
dtype_,
layout_t,
param_.seq_length_,
param_.batch_size_,
out_size,
sequence_length_cpu_int,
reinterpret_cast<void*>(&padding_fill_)));
}
bool clip_state = param_.lstm_state_clip_min.has_value();
bool clip_nan = param_.lstm_state_clip_nan;
CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
rnn_desc_,
clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
clip_state ? param_.lstm_state_clip_min.value() : 0.0,
clip_state ? param_.lstm_state_clip_max.value() : 0.0));
#endif // MXNET_USE_CUDNN_GE_7200
if (ctx.is_train) {
#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
rnn_desc_,
x_data_desc_,
x.dptr_,
hx_desc_,
hx.dptr_,
cx_desc_,
cx_ptr,
w_desc_,
w.dptr_,
y_data_desc_,
y.dptr_,
hy_desc_,
hy_ptr,
cy_desc_,
cy_ptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
temp_space.dptr_,
workspace_byte_,
reserve_space_.dptr,
reserve_space_byte_));
#else
CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
rnn_desc_,
param_.seq_length_,
x_desc_vec_.data(),
x.dptr_,
hx_desc_,
hx.dptr_,
cx_desc_,
cx_ptr,
w_desc_,
w.dptr_,
y_desc_vec_.data(),
y.dptr_,
hy_desc_,
hy_ptr,
cy_desc_,
cy_ptr,
temp_space.dptr_,
workspace_byte_,
reserve_space_.dptr,
reserve_space_byte_));
#endif // MXNET_USE_CUDNN_GE_7200
} else {
#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
rnn_desc_,
x_data_desc_,
x.dptr_,
hx_desc_,
hx.dptr_,
cx_desc_,
cx_ptr,
w_desc_,
w.dptr_,
y_data_desc_,
y.dptr_,
hy_desc_,
hy_ptr,
cy_desc_,
cy_ptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
temp_space.dptr_,
workspace_byte_));
#else
CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
rnn_desc_,
param_.seq_length_,
x_desc_vec_.data(),
x.dptr_,
hx_desc_,
hx.dptr_,
cx_desc_,
cx_ptr,
w_desc_,
w.dptr_,
y_desc_vec_.data(),
y.dptr_,
hy_desc_,
hy_ptr,
cy_desc_,
cy_ptr,
temp_space.dptr_,
workspace_byte_));
#endif // MXNET_USE_CUDNN_GE_7200
}
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
#if !defined(__CUDACC__) // cuda doesn't support C++17
if constexpr (std::is_same<xpu, cpu>::value) {
int projection_size = 0;
if (param_.projection_size.has_value()) {
projection_size = param_.projection_size.value();
}
// allocate temp space
const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_,
param_.batch_size_,
param_.state_size,
projection_size,
direction,
param_.mode);
if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
temp_cpu_space_size_ = work_cpu_space_size;
temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}),
ctx_,
false,
in_data[rnn_enum::kData].type_flag_);
temp_init_space_ = true;
}
DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
if (ctx.is_train || ctx.need_grad) {
mshadow::Random<cpu, unsigned>* prnd = ctx.requested[0].get_random<xpu, unsigned int>(s);
std::mt19937& rnd_engine = prnd->GetRndEngine();
// allocate reserve space
if (param_.projection_size.has_value()) {
LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
}
const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers,
direction,
param_.seq_length_,
param_.batch_size_,
param_.state_size,
param_.mode);
if (!init_space_ || reserve_cpu_space_size_ < r_size) {
reserve_cpu_space_size_ = r_size;
reserve_cpu_space_ = NDArray(TShape({static_cast<dim_t>(reserve_cpu_space_size_)}),
ctx_,
false,
in_data[rnn_enum::kData].type_flag_);
init_space_ = true;
}
DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
RNNForwardTraining<DType>(work_cpu_space,
reserve_space_ptr,
param_.state_outputs,
param_.num_layers,
direction,
param_.seq_length_,
param_.batch_size_,
param_.input_size_,
param_.state_size,
x.dptr_,
hx.dptr_,
cx_ptr,
w.dptr_,
b_ptr,
y.dptr_,
hy_ptr,
cy_ptr,
param_.p,
param_.mode,
rnd_engine);
} else {
RNNForwardInference<DType>(work_cpu_space,
param_.state_outputs,
param_.num_layers,
direction,
param_.seq_length_,
param_.batch_size_,
param_.input_size_,
param_.state_size,
projection_size,
x.dptr_,
hx.dptr_,
cx_ptr,
w.dptr_,
b_ptr,
y.dptr_,
hy_ptr,
cy_ptr,
param_.mode);
}
}
#endif
}
void Backward(const OpContext& ctx,
const std::vector<TBlob>& out_grad,
const std::vector<TBlob>& in_data,
const std::vector<TBlob>& out_data,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& in_grad) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
<< "unsupported dropout value, should be 0 <= dropout < 1";
size_t num_inputs = GetRnnNumInputs(param_);
// kOut
size_t num_outputs = 1;
if (param_.state_outputs) {
// kOut, kStateOut, kStateCellOut
num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
}
CHECK_EQ(in_data.size(), num_inputs);
CHECK_EQ(out_data.size(), num_outputs);
CHECK_EQ(in_grad.size(), num_inputs);
CHECK_EQ(out_grad.size(), num_outputs);
CHECK_EQ(req.size(), num_inputs);
CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
Stream<xpu>* s = ctx.get_stream<xpu>();
// get input + output tensors
Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
CHECK_EQ(x.CheckContiguous(), true);
CHECK_EQ(w.CheckContiguous(), true);
CHECK_EQ(dw.CheckContiguous(), true);
CHECK_EQ(hx.CheckContiguous(), true);
CHECK_EQ(dhx.CheckContiguous(), true);
CHECK_EQ(y.CheckContiguous(), true);
CHECK_EQ(dy.CheckContiguous(), true);
CHECK_EQ(dx.CheckContiguous(), true);
if (req[rnn_enum::kParams] != kAddTo) {
dw = mshadow::expr::ScalarExp<DType>(0.0f);
}
param_.seq_length_ = x.shape_[0];
param_.batch_size_ = x.shape_[1];
param_.input_size_ = x.shape_[2];
const int direction = param_.bidirectional ? 2 : 1;
const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;
DType* dhy_ptr = nullptr;
if (param_.state_outputs) {
dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
}
DType* dcx_ptr = nullptr;
DType* dcy_ptr = nullptr;
DType* cx_ptr = nullptr;
if (param_.mode == rnn_enum::kLstm) {
CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
}
if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs) {
dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
}
#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
if (!init_cudnn_) {
Init(ctx, s, in_data, out_data);
}
// Get temp space
int temp_size = workspace_size_;
Tensor<gpu, 1, DType> temp_space =
ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
mshadow::Shape1(temp_size), s);
#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
rnn_desc_,
y_data_desc_,
y.dptr_,
dy_data_desc_,
dy.dptr_,
nullptr,
nullptr,
dhy_desc_,
dhy_ptr,
dcy_desc_,
dcy_ptr,
w_desc_,
w.dptr_,
hx_desc_,
hx.dptr_,
cx_desc_,
cx_ptr,
dx_data_desc_,
dx.dptr_,
dhx_desc_,
dhx.dptr_,
dcx_desc_,
dcx_ptr,
nullptr,
nullptr,
temp_space.dptr_,
workspace_byte_,
reserve_space_.dptr,
reserve_space_byte_));
SyncDgrad();
if (req[rnn_enum::kParams] != kNullOp) {
CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
rnn_desc_,
x_data_desc_,
x.dptr_,
hx_desc_,
hx.dptr_,
y_data_desc_,
y.dptr_,
temp_space.dptr_,
workspace_byte_,
dw_desc_,
dw.dptr_,
reserve_space_.dptr,
reserve_space_byte_));
}
#else
CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
rnn_desc_,
param_.seq_length_,
y_desc_vec_.data(),
y.dptr_,
dy_desc_vec_.data(),
dy.dptr_,
dhy_desc_,
dhy_ptr,
dcy_desc_,
dcy_ptr,
w_desc_,
w.dptr_,
hx_desc_,
hx.dptr_,
cx_desc_,
cx_ptr,
dx_desc_vec_.data(),
dx.dptr_,
dhx_desc_,
dhx.dptr_,
dcx_desc_,
dcx_ptr,
temp_space.dptr_,
workspace_byte_,
reserve_space_.dptr,
reserve_space_byte_));
SyncDgrad();
if (req[rnn_enum::kParams] != kNullOp) {
CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
rnn_desc_,
param_.seq_length_,
x_desc_vec_.data(),
x.dptr_,
hx_desc_,
hx.dptr_,
y_desc_vec_.data(),
y.dptr_,
temp_space.dptr_,
workspace_byte_,
dw_desc_,
dw.dptr_,
reserve_space_.dptr,
reserve_space_byte_));
}
#endif // MXNET_USE_CUDNN_GE_7200
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
if (ctx_.dev_type == kCPU) {
int projection_size = 0;
if (param_.projection_size.has_value()) {
// TODO(zixuanweeei): Add training support for LSTM with projection on CPU.
// projection_size = param_.projection_size.value();
LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
}
// allocate temp space
const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_,
param_.batch_size_,
param_.state_size,
projection_size,
direction,
param_.mode);
if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) {
LOG(FATAL) << "Check temp init error";
}
DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
size_t r_size = GetRNNReserveSpaceSize(param_.num_layers,
direction,
param_.seq_length_,
param_.batch_size_,
param_.state_size,
param_.mode);
if (!init_space_ || reserve_cpu_space_size_ != r_size) {
LOG(FATAL) << "Check forward init error";
}
DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
RNNBackward<DType>(work_cpu_space,
reserve_space_ptr,
param_.num_layers,
direction,
param_.seq_length_,
param_.batch_size_,
param_.input_size_,
param_.state_size,
x.dptr_,
hx.dptr_,
cx_ptr,
w.dptr_,
y.dptr_,
dy.dptr_,
dhy_ptr,
dcy_ptr,
dx.dptr_,
dhx.dptr_,
dcx_ptr,
dw.dptr_,
db_ptr,
req[rnn_enum::kData],
req[rnn_enum::kParams],
req[rnn_enum::kState],
// State cell should be present for LSTMs, but is absent for other RNNs.
param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
param_.p,
param_.mode);
}
}
private:
inline void Init(const OpContext& ctx,
mshadow::Stream<xpu>* s,
const std::vector<TBlob>& in_data,
const std::vector<TBlob>& out_data) {
using namespace mshadow;
size_t num_inputs = GetRnnNumInputs(param_);
// kOut
size_t num_outputs = 1;
if (param_.state_outputs) {
// kOut, kStateOut, kStateCellOut
num_outputs = (param_.mode == rnn_enum::kLstm) ? 3U : 2U;
}
CHECK_EQ(in_data.size(), num_inputs);
CHECK_EQ(out_data.size(), num_outputs);
#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
format_ = CUDNN_TENSOR_NCHW;
if (!init_cudnn_) {
init_cudnn_ = true;
// get input + output tensors
Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
param_.seq_length_ = x.shape_[0];
param_.batch_size_ = x.shape_[1];
param_.input_size_ = x.shape_[2];
// Tensor Descriptors
std::vector<cudnnTensorDescriptor_t> x_vec(param_.seq_length_);
std::vector<cudnnTensorDescriptor_t> y_vec(param_.seq_length_);
std::vector<cudnnTensorDescriptor_t> dx_vec(param_.seq_length_);
std::vector<cudnnTensorDescriptor_t> dy_vec(param_.seq_length_);
int dimA[3];
int strideA[3];
for (int i = 0; i < param_.seq_length_; i++) {
CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i]));
CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i]));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i]));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i]));
dimA[0] = param_.batch_size_;
dimA[1] = param_.input_size_;
dimA[2] = 1;
strideA[0] = dimA[2] * dimA[1];
strideA[1] = dimA[2];
strideA[2] = 1;
CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i], dtype_, 3, dimA, strideA));
CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i], dtype_, 3, dimA, strideA));
dimA[0] = param_.batch_size_;
dimA[1] = param_.bidirectional ? param_.state_size * 2 : param_.state_size;
dimA[2] = 1;
strideA[0] = dimA[2] * dimA[1];
strideA[1] = dimA[2];
strideA[2] = 1;
CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i], dtype_, 3, dimA, strideA));
CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i], dtype_, 3, dimA, strideA));
}
x_desc_vec_ = x_vec;
y_desc_vec_ = y_vec;
dx_desc_vec_ = dx_vec;
dy_desc_vec_ = dy_vec;
// set the state tensors
dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
dimA[1] = param_.batch_size_;
dimA[2] = param_.state_size;
strideA[0] = dimA[2] * dimA[1];
strideA[1] = dimA[2];
strideA[2] = 1;
#if MXNET_USE_CUDNN_GE_7200
int dimB[3];
int strideB[3];
dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
dimB[1] = param_.batch_size_;
dimB[2] =
param_.projection_size.has_value() ? param_.projection_size.value() : param_.state_size;
strideB[0] = dimB[2] * dimB[1];
strideB[1] = dimB[2];
strideB[2] = 1;
#endif // MXNET_USE_CUDNN_GE_7200
#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_, dtype_, 3, dimB, strideB));
#else
CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_, dtype_, 3, dimA, strideA));
#endif // MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_, dtype_, 3, dimA, strideA));
#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_, dtype_, 3, dimB, strideB));
#else
CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_, dtype_, 3, dimA, strideA));
#endif // MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_, dtype_, 3, dimA, strideA));
#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_, dtype_, 3, dimB, strideB));
#else
CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_, dtype_, 3, dimA, strideA));
#endif // MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_, dtype_, 3, dimA, strideA));
#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_, dtype_, 3, dimB, strideB));
#else
CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_, dtype_, 3, dimA, strideA));
#endif // MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_, dtype_, 3, dimA, strideA));
// Create Dropout descriptors
ctx.requested[rnn_enum::kCuDNNDropoutDescSpace].get_cudnn_dropout_desc(
&dropout_desc_, s, param_.p);
// RNN descriptors
// adopt pseudo-fp16 for all architectures
cudnnDataType_t dtype_with_fallback_ =
(cudnnGetVersion() >= 7500 && dtype_ == CUDNN_DATA_HALF) ? CUDNN_DATA_FLOAT : dtype_;
cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
dgrad_sync_needed_ = (rnn_algo == CUDNN_RNN_ALGO_STANDARD) && param_.bidirectional;
CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
rnn_desc_,
param_.state_size,
param_.num_layers,
dropout_desc_,
input_mode_,
direction_,
mode_,
rnn_algo,
dtype_with_fallback_));
cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
if (cudnn_tensor_core_ && rnn_algo == CUDNN_RNN_ALGO_STANDARD) {
math_type = CUDNN_TENSOR_OP_MATH;
}
#if CUDNN_VERSION >= 7200
if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() &&
(DataType<DType>::kFlag != kFloat16)) {
math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
}
#endif
CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
#if MXNET_USE_CUDNN_GE_7200
if (param_.projection_size.has_value()) {
CUDNN_CALL(cudnnSetRNNProjectionLayers(
s->dnn_handle_, rnn_desc_, param_.projection_size.value(), 0));
}
if (param_.use_sequence_length) {
CUDNN_CALL(cudnnSetRNNPaddingMode(rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED));
}
#endif // MXNET_USE_CUDNN_GE_7200
// Get temp space sizes
CUDNN_CALL(cudnnGetRNNWorkspaceSize(
s->dnn_handle_, rnn_desc_, param_.seq_length_, x_desc_vec_.data(), &workspace_byte_));
CUDNN_CALL(cudnnGetRNNTrainingReserveSize(
s->dnn_handle_, rnn_desc_, param_.seq_length_, x_desc_vec_.data(), &reserve_space_byte_));
workspace_size_ = workspace_byte_ / sizeof(DType);
// Allocate the reserve space
reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));
reserve_space_.profiler_scope = "cudnn_rnn:";
reserve_space_.name = "reserve_space";
profiler::GpuDeviceStorageProfiler::Get()->UpdateStorageInfo(reserve_space_);
// Check that number of params are correct
size_t cudnn_param_size;
CUDNN_CALL(cudnnGetRNNParamsSize(
s->dnn_handle_, rnn_desc_, x_desc_vec_[0], &cudnn_param_size, dtype_));
CHECK_EQ(w.shape_[0] * sizeof(DType), cudnn_param_size);
// Set param descriptors
int dim_w[3] = {1, 1, 1};
dim_w[0] = w.shape_[0];
CUDNN_CALL(cudnnSetFilterNdDescriptor(w_desc_, dtype_, format_, 3, dim_w));
CUDNN_CALL(cudnnSetFilterNdDescriptor(dw_desc_, dtype_, format_, 3, dim_w));
// Query weight layout
// cudnnFilterDescriptor_t m_desc;
// CHECK_EQ(cudnnCreateFilterDescriptor(&m_desc), CUDNN_STATUS_SUCCESS);
// DType *p;
// int n = 2;
// int64_t last = 0;
// if (param_.mode == rnn_enum::kLstm) n = 8;
// else if (param_.mode == rnn_enum::kGru) n = 6;
// for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) {
// for (int j = 0; j < n; ++j) {
// CHECK_EQ(cudnnGetRNNLinLayerMatrixParams(s->dnn_handle_, rnn_desc_,
// i, x_desc_vec_[0], w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS);
// LOG(INFO) << ((int64_t)(p - nullptr))/sizeof(DType) - last;
// last = ((int64_t)(p - nullptr))/sizeof(DType);
// cudnnDataType_t t;
// cudnnTensorFormat_t f;
// int ndim = 5;
// int dims[5] = {0, 0, 0, 0, 0};
// CHECK_EQ(cudnnGetFilterNdDescriptor(m_desc, ndim, &t, &f, &ndim, &dims[0]),
// CUDNN_STATUS_SUCCESS);
// LOG(INFO) << "w: " << i << " " << j << " " << ((int64_t)(p - nullptr))/sizeof(DType);
// for (int i = 0; i < ndim; ++i) LOG(INFO) << dims[i];
// }
// }
// for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) {
// for (int j = 0; j < n; ++j) {
// CHECK_EQ(cudnnGetRNNLinLayerBiasParams(s->dnn_handle_, rnn_desc_, i, x_desc_vec_[0],
// w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS);
// LOG(INFO) << ((int64_t)(p - nullptr))/sizeof(DType) - last;
// last = ((int64_t)(p - nullptr))/sizeof(DType);
// LOG(INFO) << "b: " << i << " " << j << " " << ((int64_t)(p - nullptr))/sizeof(DType);
// }
// }
}
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
}
// naive private variables used in CPU Context
bool init_space_, temp_init_space_;
size_t reserve_cpu_space_size_, temp_cpu_space_size_;
NDArray reserve_cpu_space_, temp_cpu_space_;
#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
// cuDNN versions up to and including v7.6.4 did not sync the last dgrad kernel back to the main
// cudnn handle's stream (non-persistent algo, bidirectional only). This could result in silent
// non-deterministic failures with very low probability, seen more often when wgrad is bypassed.
inline void SyncDgrad() {
if (CUDNN_VERSION <= 7604 && dgrad_sync_needed_) {
// Without blocking the CPU, create a synchronization point for all current GPU activity. No
// need to call cudaStreamWaitEvent; cudaEventRecord on the legacy default stream suffices.
if (!dgrad_sync_event_created_) {
CUDA_CALL(cudaEventCreateWithFlags(&dgrad_sync_event_, cudaEventDisableTiming));
dgrad_sync_event_created_ = true;
}
CUDA_CALL(cudaEventRecord(dgrad_sync_event_, cudaStreamLegacy));
}
}
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
#if MXNET_USE_CUDNN == 1
cudnnDataType_t dtype_;
bool init_cudnn_;
cudnnRNNDescriptor_t rnn_desc_;
cudnnRNNMode_t mode_;
cudnnDirectionMode_t direction_;
cudnnRNNInputMode_t input_mode_;
cudnnDropoutDescriptor_t dropout_desc_;
Storage::Handle reserve_space_;
size_t workspace_byte_, reserve_space_byte_;
int workspace_size_;
std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
#if MXNET_USE_CUDNN_GE_7200
cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_;
DType padding_fill_ = 0;
#endif // MXNET_USE_CUDNN_GE_7200
cudnnTensorDescriptor_t hx_desc_, cx_desc_;
cudnnTensorDescriptor_t hy_desc_, cy_desc_;
cudnnTensorDescriptor_t dhx_desc_, dcx_desc_;
cudnnTensorDescriptor_t dhy_desc_, dcy_desc_;
cudnnFilterDescriptor_t w_desc_, dw_desc_;
// Allow TensorCore algo policy
bool cudnn_tensor_core_;
cudnnTensorFormat_t format_;
cudaEvent_t dgrad_sync_event_;
bool dgrad_sync_event_created_ = false;
bool dgrad_sync_needed_ = false;
#endif // MXNET_USE_CUDNN
}; // class RNNOp
template <typename xpu>
void RNNStatefulCompute(const OpStatePtr& state,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
int dtype = inputs[rnn_enum::kData].type_flag_;
// Hacky. This relies on the fact that the seq-len type is either that of the last input,
// or, when the seq-len input is not used, the same as dtype.
// Direct access to the RNNParam object would be preferable, but it is not available here.
int itype = inputs[inputs.size() - 1].type_flag_;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
MSHADOW_TYPE_SWITCH(itype, IType, {
RNNOp<xpu, DType, IType>& op = state.get_state<RNNOp<xpu, DType, IType>>();
op.Forward(ctx, inputs, req, outputs);
});
});
}
/*
index description
0: x
1: w
2: hx
3: y
4: dy
5: hy
6: dhy
7: cx
8: cy
9: dcy
*/
template <typename xpu>
void RNNStatefulGradCompute(const OpStatePtr& state,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
std::vector<TBlob> in_data(inputs.begin(), inputs.begin() + 3);
std::vector<TBlob> out_data{inputs[3]};
std::vector<TBlob> out_grad{inputs[4]};
const std::vector<TBlob>& in_grad = outputs;
int dtype = inputs[rnn_enum::kData].type_flag_;
// Hacky. This relies on the fact that the seq-len type is either that of the last input,
// or, when the seq-len input is not used, the same as dtype.
// Direct access to the RNNParam object would be preferable, but it is not available here.
int itype = outputs[outputs.size() - 1].type_flag_;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
MSHADOW_TYPE_SWITCH(itype, IType, {
RNNOp<xpu, DType, IType>& op = state.get_state<RNNOp<xpu, DType, IType>>();
const RNNParam& param = op.param_;
int index = 5;
if (param.state_outputs) {
out_data.push_back(inputs[index++]);
out_grad.push_back(inputs[index++]);
}
if (param.mode == rnn_enum::kLstm) {
in_data.push_back(inputs[index++]);
if (param.state_outputs) {
out_data.push_back(inputs[index++]);
out_grad.push_back(inputs[index]);
}
}
if (param.use_sequence_length) {
size_t seq_len_input_idx = rnn_enum::kSequenceLength;
if (param.mode != rnn_enum::kLstm) {
seq_len_input_idx -= 1;
}
in_data.push_back(outputs[seq_len_input_idx]);
}
op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
});
});
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_RNN_INL_H_