| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * Copyright (c) 2015 by Contributors |
| * \file rnn.cc |
| * \brief Implementation and registration of the RNN operator (vanilla RNN, LSTM, GRU) |
| * \author Sebastian Bodenstein |
| */ |
| #include "./rnn-inl.h" |
| |
| namespace mxnet { |
| namespace op { |
| |
| DMLC_REGISTER_PARAMETER(RNNParam); |
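| |
| // Input names in the order defined by rnn_enum (see rnn-inl.h); LSTM takes |
| // an extra initial cell state in addition to data, the flat parameter |
| // vector, and the initial hidden state. |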
| static inline std::vector<std::string> ListArguments(const RNNParam& param_) { |
| if (param_.mode == rnn_enum::kLstm) { |
| return {"data", "parameters", "state", "state_cell"}; |
| } else { |
| return {"data", "parameters", "state"}; |
| } |
| } |
| |
| static bool RNNShape(const nnvm::NodeAttrs& attrs, |
| std::vector<TShape> *in_shape, |
| std::vector<TShape> *out_shape) { |
| const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed); |
| using namespace mshadow; |
| if (param_.mode == rnn_enum::kLstm) { |
| CHECK_EQ(in_shape->size(), 4U) << "Needed input: [data, parameters, state, state_cell]," |
| << " got in_shape->size(): " << in_shape->size(); |
| } else { |
| CHECK_EQ(in_shape->size(), 3U) << "Needed input: [data, parameters, state]," |
| << " got in_shape->size(): " << in_shape->size(); |
| } |
| const TShape &dshape = (*in_shape)[rnn_enum::kData]; |
| if (!mxnet::ndim_is_known(dshape)) return false; |
| CHECK_EQ(dshape.ndim(), 3U) |
| << "Input data should be a rank-3 tensor of shape [sequence length, batch size, input size]"; |
| // data: [sequence len, batch, input dimension] |
| int batch_size = dshape[1]; |
| int input_size = dshape[2]; |
| int numDirections = param_.bidirectional ? 2 : 1; |
| int total_layers = numDirections * param_.num_layers; // double for bidirectional |
| int layer_size = (param_.projection_size.has_value()) ? |
| param_.projection_size.value() : param_.state_size; |
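| // With an LSTM projection layer, the recurrent hidden state is projected |
| // down to projection_size, so kState is checked against layer_size while |
| // the (unprojected) cell state below keeps state_size. |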
| SHAPE_ASSIGN_CHECK(*in_shape, |
| rnn_enum::kState, |
| Shape3(total_layers, batch_size, layer_size)); |
| if (param_.mode == rnn_enum::kLstm) { |
| SHAPE_ASSIGN_CHECK(*in_shape, |
| rnn_enum::kStateCell, |
| Shape3(total_layers, batch_size, param_.state_size)); |
| } |
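| // Worked example (illustrative numbers): a 2-layer bidirectional LSTM with |
| // batch_size = 32, state_size = 256 and no projection gives total_layers = |
| // 2 * 2 = 4, so kState and kStateCell are both checked against |
| // Shape3(4, 32, 256). |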
| |
| // calculate parameter vector length |
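| // GetRnnParamSize counts, per layer and direction, the input-to-hidden and |
| // hidden-to-hidden weights plus the two per-gate bias vectors, concatenated |
| // in a cuDNN-compatible flat layout (see the gate equations in the operator |
| // description below). |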
| int param_size = GetRnnParamSize(param_.num_layers, |
| input_size, |
| param_.state_size, |
| numDirections, |
| param_.mode, |
| param_.projection_size); |
| SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size)); |
| out_shape->clear(); |
| // output: [sequence len, batch, output size] |
| TShape oshape = dshape; |
| if (param_.projection_size.has_value()) { |
| oshape[2] = numDirections * param_.projection_size.value(); |
| } else { |
| oshape[2] = numDirections * param_.state_size; |
| } |
| out_shape->push_back(oshape); |
| if (param_.state_outputs) { |
| // outStateShape: [layer_num, batch, state size] |
| TShape outStateShape = dshape; |
| outStateShape[0] = total_layers; |
| outStateShape[1] = batch_size; |
| if (param_.projection_size.has_value()) { |
| outStateShape[2] = param_.projection_size.value(); |
| } else { |
| outStateShape[2] = param_.state_size; |
| } |
| out_shape->push_back(outStateShape); |
| // Deal with lstm cell state |
| if (param_.mode == rnn_enum::kLstm) { |
| TShape cellStateShape = dshape; |
| cellStateShape[0] = total_layers; |
| cellStateShape[1] = batch_size; |
| cellStateShape[2] = param_.state_size; |
| out_shape->push_back(cellStateShape); |
| } |
| } |
| return true; |
| } |
| |
| static bool RNNType(const nnvm::NodeAttrs& attrs, |
| std::vector<int> *in_type, |
| std::vector<int> *out_type) { |
| const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed); |
| if (param_.mode == rnn_enum::kLstm) { |
| CHECK_EQ(in_type->size(), 4U); |
| } else { |
| CHECK_EQ(in_type->size(), 3U); |
| } |
| int dtype = (*in_type)[0]; |
| CHECK_NE(dtype, -1) << "First input must have a specified type"; |
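| // Fill any unspecified input dtype from the first input; inputs that do |
| // specify a dtype must match it (mismatches are reported by argument name). |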
| for (size_t i = 0; i < in_type->size(); ++i) { |
| if ((*in_type)[i] == -1) { |
| TYPE_ASSIGN_CHECK(*in_type, i, dtype); |
| } else { |
| UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]); |
| } |
| } |
| out_type->clear(); |
| out_type->push_back(dtype); |
| if (param_.state_outputs) { |
| out_type->push_back(dtype); |
| // Deal with lstm cell state |
| if (param_.mode == rnn_enum::kLstm) |
| out_type->push_back(dtype); |
| } |
| return true; |
| } |
| |
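| // Assembles the inputs ("heads") of the backward node: the forward inputs |
| // (data, parameters, state and, for LSTM, the initial cell state), the |
| // forward outputs, and the corresponding output gradients. The ordering |
| // must match what the stateful backward compute expects. |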
| struct RNNGrad { |
| const char *op_name; |
| std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr &n, |
| const std::vector<nnvm::NodeEntry> &ograd) const { |
| const RNNParam& params = nnvm::get<RNNParam>(n->attrs.parsed); |
| std::vector<nnvm::NodeEntry> heads{ n->inputs[rnn_enum::kData], |
| n->inputs[rnn_enum::kParams], n->inputs[rnn_enum::kState] }; |
| heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kOut, 0}); |
| heads.push_back(ograd[rnn_enum::kOut]); |
| if (params.state_outputs) { |
| heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateOut, 0}); |
| heads.push_back(ograd[rnn_enum::kStateOut]); |
| } |
| if (params.mode == rnn_enum::kLstm) { |
| heads.push_back(n->inputs[rnn_enum::kStateCell]); |
| if (params.state_outputs) { |
| heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateCellOut, 0}); |
| heads.push_back(ograd[rnn_enum::kStateCellOut]); |
| } |
| } |
| return MakeGradNode(op_name, n, heads, n->attrs.dict); |
| } |
| }; |
| |
| NNVM_REGISTER_OP(RNN) |
| .describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are |
| implemented, with both multi-layer and bidirectional support. |
| |
| When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE |
| and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use |
| pseudo-float16 (float32 math with float16 I/O) precision in order to use |
| Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups. |
| |
| **Vanilla RNN** |
| |
| Applies a single-gate recurrent layer to input X. Two activation functions are supported: |
| ReLU and Tanh. |
| |
| With ReLU activation function: |
| |
| .. math:: |
| h_t = relu(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh}) |
| |
| With Tanh activation function: |
| |
| .. math:: |
| h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh}) |
| |
| Reference paper: Finding structure in time - Elman, 1988. |
| https://crl.ucsd.edu/~elman/Papers/fsit.pdf |
| |
| **LSTM** |
| |
| Long Short-Term Memory - Hochreiter and Schmidhuber, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf |
| |
| .. math:: |
| \begin{array}{ll} |
| i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\ |
| f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\ |
| g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\ |
| o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\ |
| c_t = f_t * c_{(t-1)} + i_t * g_t \\ |
| h_t = o_t * \tanh(c_t) |
| \end{array} |
| |
| **GRU** |
| |
| Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078 |
| |
| The definition of GRU here is slightly different from the paper, but compatible with cuDNN. |
| |
| .. math:: |
| \begin{array}{ll} |
| r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ |
| z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ |
| n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ |
| h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} |
| \end{array} |
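| |
| **Example** |
| |
| The following usage sketch is illustrative rather than taken from this file; |
| the parameter-size formula assumes the cuDNN-compatible flat weight layout |
| (input and hidden weights plus two bias vectors for each of the four LSTM |
| gates):: |
| |
| # assumes "import mxnet as mx" |
| seq_len, batch_size, input_size, state_size = 5, 32, 100, 256 |
| data = mx.nd.random.uniform(shape=(seq_len, batch_size, input_size)) |
| state = mx.nd.zeros((1, batch_size, state_size)) |
| state_cell = mx.nd.zeros((1, batch_size, state_size)) |
| num_params = 4 * state_size * (input_size + state_size + 2) |
| parameters = mx.nd.random.uniform(shape=(num_params,)) |
| out = mx.nd.RNN(data=data, parameters=parameters, state=state, |
| state_cell=state_cell, state_size=state_size, |
| num_layers=1, mode='lstm') |
| # out.shape == (5, 32, 256) |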
| )code" ADD_FILELINE) |
| .set_attr_parser(ParamParser<RNNParam>) |
| .set_num_inputs([](const NodeAttrs& attrs) { |
| const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed); |
| return params.mode == rnn_enum::kLstm ? 4 : 3; |
| }) |
| .set_num_outputs([](const NodeAttrs& attrs) { |
| const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed); |
| // kOut |
| int num_outputs = 1; |
| if (params.state_outputs) { |
| // kOut, kStateOut, kStateCellOut |
| num_outputs = (params.mode == rnn_enum::kLstm) ? 3 : 2; |
| } |
| |
| return num_outputs; |
| }) |
| .set_attr<nnvm::FListInputNames>("FListInputNames", |
| [](const NodeAttrs& attrs) { |
| const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed); |
| return ListArguments(params); |
| }) |
| .set_attr<mxnet::FInferShape>("FInferShape", RNNShape) |
| .set_attr<nnvm::FInferType>("FInferType", RNNType) |
| .set_attr<FCreateOpState>("FCreateOpState", CreateRNNState) |
| .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>) |
| .set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"}) |
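| // Request a cuDNN dropout descriptor only when dropout will actually run on |
| // GPU (0 < p < 1); with p == 0 the forward pass needs no extra resources. |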
| .set_attr<FResourceRequestEx>("FResourceRequestEx", |
| [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) { |
| std::vector<ResourceRequest> request; |
| const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed); |
| if (param.p == 0) return request; |
| if (dev_mask == kGPU) { |
| #if MXNET_USE_CUDNN_RNN |
| if (param.p < 1.0f) { |
| request.emplace_back(ResourceRequest::kCuDNNDropoutDesc); |
| return request; |
| } |
| #endif |
| } |
| return request; |
| }) |
| .add_argument("data", "NDArray-or-Symbol", "Input data to RNN") |
| .add_argument("parameters", "NDArray-or-Symbol", |
| "Vector of all RNN trainable parameters concatenated") |
| .add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN") |
| .add_argument("state_cell", "NDArray-or-Symbol", |
| "initial cell state for LSTM networks (only for LSTM)") |
| .add_arguments(RNNParam::__FIELDS__()); |
| |
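| // The backward op emits one gradient per forward input: data, parameters, |
| // state and, for LSTM, state_cell. |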
| NNVM_REGISTER_OP(_backward_RNN) |
| .set_num_outputs([](const NodeAttrs& attrs) { |
| const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed); |
| return params.mode == rnn_enum::kLstm ? 4 : 3; |
| }) |
| .set_attr_parser(ParamParser<RNNParam>) |
| .set_attr<bool>("TIsLayerOpBackward", true) |
| .set_attr<nnvm::TIsBackward>("TIsBackward", true) |
| .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>); |
| } // namespace op |
| } // namespace mxnet |