| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
/*!
 * Copyright (c) 2015 by Contributors
 * \file rnn.cc
 * \brief Shape/type inference, gradient wiring, and registration of the RNN operator
 * \author Sebastian Bodenstein
 */
| |
| #include <iterator> |
| |
| #include "./rnn-inl.h" |
| #if MXNET_USE_MKLDNN == 1 |
| #include "./nn/mkldnn/mkldnn_rnn-inl.h" |
| #endif // MXNET_USE_MKLDNN == 1 |
| |
| namespace mxnet { |
| namespace op { |
| |
| DMLC_REGISTER_PARAMETER(RNNParam); |
| static inline std::vector<std::string> ListArguments(const RNNParam& param_) { |
| // All RNNs start off with same 3 input arguments |
| std::vector<std::string> arguments{"data", "parameters", "state"}; |
| |
| // LSTMs also have an additional state_cell argument |
| if (param_.mode == rnn_enum::kLstm) { |
| arguments.emplace_back("state_cell"); |
| } |
| |
| // All RNNs have option of additional sequence_length argument |
| if (param_.use_sequence_length) { |
| arguments.emplace_back("sequence_length"); |
| } |
| |
| return arguments; |
| } |
| |
| static bool RNNShape(const nnvm::NodeAttrs& attrs, |
| std::vector<TShape> *in_shape, |
| std::vector<TShape> *out_shape) { |
| const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed); |
| using namespace mshadow; |
| |
| // Query param_ object to figure out what the expectd input arguments are |
| std::vector<std::string> expected_arguments = ListArguments(param_); |
| |
| CHECK_EQ(in_shape->size(), expected_arguments.size()) << "Input shape mismatch. Expected " << |
| expected_arguments.size() << " input parameters but got " << in_shape->size() << "."; |
| |
| const TShape &dshape = (*in_shape)[rnn_enum::kData]; |
| if (!mxnet::ndim_is_known(dshape)) return false; |
| CHECK_EQ(dshape.ndim(), 3U) \ |
| << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]"; |
| // data: [sequence len, batch, input dimension] |
| for (int i = 0; i < dshape.ndim(); i++) { |
| CHECK_LT(dshape[i], INT32_MAX) << "ValueError: RNN does not support large" |
| << "dimensions (>= 2^31)."; |
| } |
| int batch_size = dshape[1]; |
| int input_size = dshape[2]; |
| int numDirections = param_.bidirectional ? 2 : 1; |
| int total_layers = numDirections * param_.num_layers; // double for bidirectional |
| int layer_size = (param_.projection_size.has_value()) ? |
| param_.projection_size.value() : param_.state_size; |
| SHAPE_ASSIGN_CHECK(*in_shape, |
| rnn_enum::kState, |
| Shape3(total_layers, batch_size, layer_size)); |
| if (param_.mode == rnn_enum::kLstm) { |
| SHAPE_ASSIGN_CHECK(*in_shape, |
| rnn_enum::kStateCell, |
| Shape3(total_layers, batch_size, param_.state_size)); |
| } |
| |
| // calculate parameter vector length |
| int param_size = GetRnnParamSize(param_.num_layers, |
| input_size, |
| param_.state_size, |
| numDirections, |
| param_.mode, |
| param_.projection_size); |
| SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size)); |
| |
| // Check on sequence_length shape if using |
| if (param_.use_sequence_length) { |
| size_t seq_len_input_idx = rnn_enum::kSequenceLength; |
| if (param_.mode != rnn_enum::kLstm) --seq_len_input_idx; |
| |
| SHAPE_ASSIGN_CHECK(*in_shape, seq_len_input_idx, Shape1(batch_size)); |
| } |
| |
| out_shape->clear(); |
| // output: [sequence len, batch, output size] |
| TShape oshape = dshape; |
| if (param_.projection_size.has_value()) { |
| oshape[2] = numDirections * param_.projection_size.value(); |
| } else { |
| oshape[2] = numDirections * param_.state_size; |
| } |
| out_shape->push_back(oshape); |
| if (param_.state_outputs) { |
| // outStateShape: [layer_num, batch, state size] |
| TShape outStateShape = dshape; |
| outStateShape[0] = total_layers; |
| outStateShape[1] = batch_size; |
| if (param_.projection_size.has_value()) { |
| outStateShape[2] = param_.projection_size.value(); |
| } else { |
| outStateShape[2] = param_.state_size; |
| } |
| out_shape->push_back(outStateShape); |
| // Deal with lstm cell state |
| if (param_.mode == rnn_enum::kLstm) { |
| TShape cellStateShape = dshape; |
| cellStateShape[0] = total_layers; |
| cellStateShape[1] = batch_size; |
| cellStateShape[2] = param_.state_size; |
| out_shape->push_back(cellStateShape); |
| } |
| } |
| |
| return true; |
| } |
| |
| static bool RNNType(const nnvm::NodeAttrs& attrs, |
| std::vector<int> *in_type, |
| std::vector<int> *out_type) { |
| const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed); |
| |
| CHECK_EQ(in_type->size(), GetNumInputArguments(param_)); |
| |
| size_t seq_len_input_idx = rnn_enum::kSequenceLength; |
| if (param_.mode != rnn_enum::kLstm) --seq_len_input_idx; |
| |
| int dtype = (*in_type)[0]; |
| CHECK_NE(dtype, -1) << "First input must have specified type"; |
| std::vector<std::string> arguments = ListArguments(param_); |
| for (size_t i = 0; i < in_type->size(); ++i) { |
| if ((*in_type)[i] == -1) { |
| TYPE_ASSIGN_CHECK(*in_type, i, dtype); |
| } else { |
| // If using sequence length argument, it has its own indexing type |
| // All other input arguments must match the main data type |
| if (!(param_.use_sequence_length && i == seq_len_input_idx)) { |
| UNIFORM_TYPE_CHECK((*in_type)[i], dtype, arguments[i]); |
| } |
| } |
| } |
| out_type->clear(); |
| out_type->push_back(dtype); |
| if (param_.state_outputs) { |
| out_type->push_back(dtype); |
| // Deal with lstm cell state |
| if (param_.mode == rnn_enum::kLstm) { |
| out_type->push_back(dtype); |
| } |
| } |
| return true; |
| } |
| |
| static std::vector<ResourceRequest> RNNResourceEx(const NodeAttrs& attrs, const int dev_mask, |
| const DispatchMode dispatch_mode) { |
| std::vector<ResourceRequest> request; |
| if (dev_mask == kGPU) { |
| #if MXNET_USE_CUDNN == 1 |
| request.emplace_back(ResourceRequest::kTempSpace); |
| request.emplace_back(ResourceRequest::kCuDNNDropoutDesc); |
| #endif |
| } else { |
| request.emplace_back(ResourceRequest::kRandom); |
| #if MXNET_USE_MKLDNN == 1 |
| request.emplace_back(ResourceRequest::kTempSpace); |
| #endif |
| } |
| return request; |
| } |
| |
| #if MXNET_USE_MKLDNN == 1 |
| inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs, |
| const int dev_mask, |
| DispatchMode* dispatch_mode, |
| std::vector<int> *in_attrs, |
| std::vector<int> *out_attrs) { |
| const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed); |
| const bool support_mkldnn_rnn = |
| !param.use_sequence_length && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1); |
| return MKLDNNStorageType(attrs, dev_mask, support_mkldnn_rnn, |
| dispatch_mode, in_attrs, out_attrs); |
| } |
| #endif // MXNET_USE_MKLDNN == 1 |
| |
// Builds the input list (heads) for the backward node. The push order here is
// the contract that _backward_RNN's set_num_inputs arithmetic (5 base inputs,
// +2 per optional state output pair, +1 for the LSTM cell state) relies on.
struct RNNGrad {
  const char *op_name;
  std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr &n,
                                          const std::vector<nnvm::NodeEntry> &ograd) const {
    const RNNParam& params = nnvm::get<RNNParam>(n->attrs.parsed);
    // Base inputs: forward data, parameter vector and initial hidden state.
    std::vector<nnvm::NodeEntry> heads{ n->inputs[rnn_enum::kData],
        n->inputs[rnn_enum::kParams], n->inputs[rnn_enum::kState] };
    // Forward output plus its incoming gradient.
    heads.emplace_back(n, rnn_enum::kOut, 0);
    heads.push_back(ograd[rnn_enum::kOut]);
    if (params.state_outputs) {
      // Hidden-state output and its gradient.
      heads.emplace_back(n, rnn_enum::kStateOut, 0);
      heads.push_back(ograd[rnn_enum::kStateOut]);
    }
    if (params.mode == rnn_enum::kLstm) {
      // LSTM also feeds the initial cell state to the backward pass.
      heads.push_back(n->inputs[rnn_enum::kStateCell]);
      if (params.state_outputs) {
        // Cell-state output and its gradient.
        heads.emplace_back(n, rnn_enum::kStateCellOut, 0);
        heads.push_back(ograd[rnn_enum::kStateCellOut]);
      }
    }
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};
| |
// Creates the per-operator state object, choosing between the oneDNN
// implementation (CPU only, when SupportMKLDNNRnn accepts the config/dtype)
// and the generic RNNOp specialized on data dtype and sequence-length dtype.
static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
                                 const Context ctx,
                                 const mxnet::ShapeVector &in_shapes,
                                 const std::vector<int> &in_types) {
  const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
  OpStatePtr state = OpStatePtr();
  int dtype = in_types[rnn_enum::kData];
  // itype: dtype of the optional sequence_length input; falls back to the
  // data dtype when that input is not used.
  int itype = dtype;
  if (param.use_sequence_length) {
    size_t seq_len_input_idx = rnn_enum::kSequenceLength;
    // Without the LSTM-only state_cell input, sequence_length sits one slot earlier.
    if (param.mode != rnn_enum::kLstm) {
      seq_len_input_idx -= 1;
    }
    itype = in_types[seq_len_input_idx];
  }

#if MXNET_USE_MKLDNN == 1
  if (ctx.dev_type == kCPU && SupportMKLDNNRnn(param, in_types[rnn_enum::kData])) {
    // data layout is [sequence length, batch size, input size] (see RNNShape).
    const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData];
    state = OpStatePtr::Create<MKLDNNRnnOp>(param, data_shape[0],
                                            data_shape[1], data_shape[2]);
    return state;
  }
#endif  // MXNET_USE_MKLDNN == 1

  // Generic path: specialize RNNOp on device, data dtype and index dtype.
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    MSHADOW_TYPE_SWITCH(itype, IType, {
      if (ctx.dev_type == kGPU) {
        state = OpStatePtr::Create<RNNOp<gpu, DType, IType>>(param, ctx);
      } else {
        state = OpStatePtr::Create<RNNOp<cpu, DType, IType>>(param, ctx);
      }
    });
  });
  return state;
}
| |
| #if MXNET_USE_MKLDNN == 1 |
| static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, |
| const OpContext& ctx, |
| const std::vector<NDArray>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<NDArray>& outputs) { |
| if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) { |
| MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>(); |
| op.Forward(ctx, inputs, req, outputs); |
| } else { |
| FallBackCompute(RNNStatefulCompute<cpu>, state_ptr, ctx, inputs, req, outputs); |
| } |
| } |
| |
| static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr, |
| const OpContext& ctx, |
| const std::vector<NDArray>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<NDArray>& outputs) { |
| if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) { |
| MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>(); |
| op.Backward(ctx, inputs, req, outputs); |
| } else { |
| FallBackCompute(RNNStatefulGradCompute<cpu>, state_ptr, ctx, inputs, req, outputs); |
| } |
| } |
| #endif // MXNET_USE_MKLDNN == 1 |
| |
| NNVM_REGISTER_OP(RNN) |
| .add_alias("_npx_rnn") |
| .describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are |
| implemented, with both multi-layer and bidirectional support. |
| |
| When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE |
| and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use |
| pseudo-float16 precision (float32 math with float16 I/O) precision in order to use |
| Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups. |
| |
| **Vanilla RNN** |
| |
| Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported: |
| ReLU and Tanh. |
| |
| With ReLU activation function: |
| |
| .. math:: |
| h_t = relu(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh}) |
| |
| With Tanh activtion function: |
| |
| .. math:: |
| h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh}) |
| |
| Reference paper: Finding structure in time - Elman, 1988. |
| https://crl.ucsd.edu/~elman/Papers/fsit.pdf |
| |
| **LSTM** |
| |
| Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf |
| |
| .. math:: |
| \begin{array}{ll} |
| i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\ |
| f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\ |
| g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\ |
| o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\ |
| c_t = f_t * c_{(t-1)} + i_t * g_t \\ |
| h_t = o_t * \tanh(c_t) |
| \end{array} |
| |
| With the projection size being set, LSTM could use the projection feature to reduce the parameters |
| size and give some speedups without significant damage to the accuracy. |
| |
| Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech |
| Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128 |
| |
| .. math:: |
| \begin{array}{ll} |
| i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\ |
| f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\ |
| g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rc} r_{(t-1)} + b_{rg}) \\ |
| o_t = \mathrm{sigmoid}(W_{io} x_t + b_{o} + W_{ro} r_{(t-1)} + b_{ro}) \\ |
| c_t = f_t * c_{(t-1)} + i_t * g_t \\ |
| h_t = o_t * \tanh(c_t) |
| r_t = W_{hr} h_t |
| \end{array} |
| |
| **GRU** |
| |
| Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078 |
| |
| The definition of GRU here is slightly different from paper but compatible with CUDNN. |
| |
| .. math:: |
| \begin{array}{ll} |
| r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ |
| z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ |
| n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ |
| h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\ |
| \end{array} |
| )code" ADD_FILELINE) |
| .set_attr_parser(ParamParser<RNNParam>) |
| .set_num_inputs([](const NodeAttrs& attrs) { |
| const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed); |
| return GetNumInputArguments(params); |
| }) |
| .set_num_outputs([](const NodeAttrs& attrs) { |
| const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed); |
| // kOut |
| int num_outputs = 1; |
| if (params.state_outputs) { |
| // kOut, kStateOut, kStateCellOut |
| num_outputs = (params.mode == rnn_enum::kLstm) ? 3 : 2; |
| } |
| |
| return num_outputs; |
| }) |
| .set_attr<nnvm::FListInputNames>("FListInputNames", |
| [](const NodeAttrs& attrs) { |
| const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed); |
| return ListArguments(params); |
| }) |
| .set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) { |
| const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed); |
| std::vector<std::string> names{"output"}; |
| if (params.state_outputs) { |
| names.emplace_back("state_output"); |
| if (params.mode == rnn_enum::kLstm) |
| names.emplace_back("statecell_output"); |
| } |
| return names; |
| }) |
| .set_attr<mxnet::FInferShape>("FInferShape", RNNShape) |
| .set_attr<nnvm::FInferType>("FInferType", RNNType) |
| .set_attr<FCreateOpState>("FCreateOpState", CreateRNNState) |
| .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>) |
| #if MXNET_USE_MKLDNN == 1 |
| .set_attr<FInferStorageType>("FInferStorageType", RNNStorageType) |
| .set_attr<bool>("TIsMKLDNN", true) |
| .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeExCPU) |
| #endif |
| .set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"}) |
| .set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx) |
| .add_argument("data", "NDArray-or-Symbol", "Input data to RNN") |
| .add_argument("parameters", "NDArray-or-Symbol", |
| "Vector of all RNN trainable parameters concatenated") |
| .add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN") |
| .add_argument("state_cell", "NDArray-or-Symbol", |
| "initial cell state for LSTM networks (only for LSTM)") |
| .add_argument("sequence_length", "NDArray-or-Symbol", |
| "Vector of valid sequence lengths for each element in batch. (Only used if" |
| " use_sequence_length kwarg is True)") |
| .add_arguments(RNNParam::__FIELDS__()); |
| |
NNVM_REGISTER_OP(_backward_RNN)
.set_num_inputs([](const NodeAttrs& attrs) {
  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
  // Must mirror the heads vector built by RNNGrad:
  // data, parameters, state, forward output, output gradient -> 5 base inputs.
  int ret = 5;
  if (params.state_outputs) {
    // plus the hidden-state output and its gradient
    ret += 2;
  }
  if (params.mode == rnn_enum::kLstm) {
    // plus the initial cell-state input
    ++ret;
    if (params.state_outputs) {
      // plus the cell-state output and its gradient
      ret += 2;
    }
  }
  return ret;
})
.set_num_outputs([](const NodeAttrs& attrs) {
  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
  // One gradient is produced per forward input argument.
  return GetNumInputArguments(params);
})
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulGradComputeExCPU)
#endif
.set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx);
| } // namespace op |
| } // namespace mxnet |