blob: 7012a3c22f50318921151e13ff56da95d40f0026 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file rnn.cc
* \brief RNN operator: shape/type inference, gradient construction and operator registration (CPU compute).
* \author Sebastian Bodenstein
*/
#include "./rnn-inl.h"
namespace mxnet {
namespace op {
// Register RNNParam with the dmlc parameter registry; ParamParser<RNNParam>,
// used as the attribute parser by both operators registered below, relies on it.
DMLC_REGISTER_PARAMETER(RNNParam);
/*!
 * \brief Names of the RNN operator's inputs, in input-index order.
 *
 * Every mode takes "data", "parameters" and "state"; LSTM additionally takes
 * "state_cell" as its fourth input.
 */
static inline std::vector<std::string> ListArguments(const RNNParam& param_) {
  std::vector<std::string> names{"data", "parameters", "state"};
  if (param_.mode == rnn_enum::kLstm) {
    names.emplace_back("state_cell");
  }
  return names;
}
/*!
 * \brief Shape inference for the RNN operator.
 *
 * The data input must be a rank-3 tensor [sequence length, batch size, input size].
 * From it this fills in the shapes of the initial hidden state, the initial cell
 * state (LSTM only) and the flat parameter vector, then produces the outputs:
 *   - out:       [sequence length, batch size, num_directions * layer_size]
 *   - state out: [num_directions * num_layers, batch size, layer_size]
 *   - cell out:  [num_directions * num_layers, batch size, state_size] (LSTM only)
 * where layer_size is projection_size when set, otherwise state_size. The state
 * outputs are only produced when state_outputs is true.
 *
 * \return false while the data shape's rank is still unknown, true otherwise.
 */
static bool RNNShape(const nnvm::NodeAttrs& attrs,
                     std::vector<TShape> *in_shape,
                     std::vector<TShape> *out_shape) {
  const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed);
  using namespace mshadow;
  if (param_.mode == rnn_enum::kLstm) {
    // Name the inputs exactly as ListArguments / add_argument do.
    CHECK_EQ(in_shape->size(), 4U) << "Needed input:[data, parameters, state, state_cell],"
                                   << " got in_shape->size(): " << in_shape->size();
  } else {
    CHECK_EQ(in_shape->size(), 3U) <<
      "Needed input:[data, parameters, state], got in_shape->size(): " << in_shape->size();
  }
  const TShape &dshape = (*in_shape)[rnn_enum::kData];
  if (!mxnet::ndim_is_known(dshape)) return false;
  CHECK_EQ(dshape.ndim(), 3U)
    << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
  // data: [sequence len, batch, input dimension]
  const int batch_size = dshape[1];
  const int input_size = dshape[2];
  const int numDirections = param_.bidirectional ? 2 : 1;
  const int total_layers = numDirections * param_.num_layers;  // double for bidirectional
  // Width of the hidden state: the projected size when a projection is configured.
  const int layer_size = (param_.projection_size.has_value()) ?
      param_.projection_size.value() : param_.state_size;
  SHAPE_ASSIGN_CHECK(*in_shape,
                     rnn_enum::kState,
                     Shape3(total_layers, batch_size, layer_size));
  if (param_.mode == rnn_enum::kLstm) {
    // The cell state is never projected, so it always uses state_size.
    SHAPE_ASSIGN_CHECK(*in_shape,
                       rnn_enum::kStateCell,
                       Shape3(total_layers, batch_size, param_.state_size));
  }
  // calculate parameter vector length
  const int param_size = GetRnnParamSize(param_.num_layers,
                                         input_size,
                                         param_.state_size,
                                         numDirections,
                                         param_.mode,
                                         param_.projection_size);
  SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
  out_shape->clear();
  // output: [sequence len, batch, numDirections * layer_size]
  TShape oshape = dshape;
  oshape[2] = numDirections * layer_size;
  out_shape->push_back(oshape);
  if (param_.state_outputs) {
    // outStateShape: [layer_num, batch, layer_size]
    TShape outStateShape = dshape;
    outStateShape[0] = total_layers;
    outStateShape[1] = batch_size;
    outStateShape[2] = layer_size;
    out_shape->push_back(outStateShape);
    // Deal with lstm cell state
    if (param_.mode == rnn_enum::kLstm) {
      TShape cellStateShape = dshape;
      cellStateShape[0] = total_layers;
      cellStateShape[1] = batch_size;
      cellStateShape[2] = param_.state_size;
      out_shape->push_back(cellStateShape);
    }
  }
  return true;
}
/*!
 * \brief Type inference for the RNN operator.
 *
 * The first input (data) must already have a known dtype. Every other input
 * whose dtype is still unknown (-1) is assigned that dtype; inputs with a known
 * dtype must match it. All outputs receive the same dtype.
 */
static bool RNNType(const nnvm::NodeAttrs& attrs,
                    std::vector<int> *in_type,
                    std::vector<int> *out_type) {
  const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed);
  const bool is_lstm = (param_.mode == rnn_enum::kLstm);
  if (is_lstm) {
    CHECK_EQ(in_type->size(), 4U);
  } else {
    CHECK_EQ(in_type->size(), 3U);
  }
  const int dtype = in_type->front();
  CHECK_NE(dtype, -1) << "First input must have specified type";
  for (size_t i = 0; i < in_type->size(); ++i) {
    if ((*in_type)[i] != -1) {
      // ListArguments is only evaluated when the check fails (error message).
      UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
    } else {
      TYPE_ASSIGN_CHECK(*in_type, i, dtype);
    }
  }
  // out, then (with state_outputs) state out and, for LSTM, cell state out.
  size_t num_outputs = 1;
  if (param_.state_outputs) {
    num_outputs = is_lstm ? 3 : 2;
  }
  out_type->assign(num_outputs, dtype);
  return true;
}
// Gradient builder for the RNN operator: collects the forward node's inputs,
// outputs and output gradients into a single list and hands it to
// MakeGradNode to create the backward node named `op_name`.
struct RNNGrad {
// Name of the backward operator to create (e.g. "_backward_RNN").
const char *op_name;
// The heads are appended in this order:
//   data, parameters, state, out, d(out),
//   [state_out, d(state_out)]              -- only if state_outputs
//   [state_cell]                           -- only for LSTM
//   [state_cell_out, d(state_cell_out)]    -- only for LSTM with state_outputs
// NOTE(review): this order presumably has to match the input layout the
// backward compute function expects; do not reorder without checking it.
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr &n,
const std::vector<nnvm::NodeEntry> &ograd) const {
const RNNParam& params = nnvm::get<RNNParam>(n->attrs.parsed);
std::vector<nnvm::NodeEntry> heads{ n->inputs[rnn_enum::kData],
n->inputs[rnn_enum::kParams], n->inputs[rnn_enum::kState] };
// Forward output and its incoming gradient.
heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kOut, 0});
heads.push_back(ograd[rnn_enum::kOut]);
if (params.state_outputs) {
// Final hidden state output and its incoming gradient.
heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateOut, 0});
heads.push_back(ograd[rnn_enum::kStateOut]);
}
if (params.mode == rnn_enum::kLstm) {
// LSTM-only: the initial cell state input ...
heads.push_back(n->inputs[rnn_enum::kStateCell]);
if (params.state_outputs) {
// ... and, when exposed, the final cell state output with its gradient.
heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateCellOut, 0});
heads.push_back(ograd[rnn_enum::kStateCellOut]);
}
}
// The backward node reuses the forward node's attribute dictionary.
return MakeGradNode(op_name, n, heads, n->attrs.dict);
}
};
// Forward RNN operator. Its arity depends on the parsed RNNParam: LSTM adds a
// "state_cell" input, and state_outputs adds the final hidden state (plus, for
// LSTM, the final cell state) as extra outputs.
NNVM_REGISTER_OP(RNN)
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
implemented, with both multi-layer and bidirectional support.
When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
pseudo-float16 precision (float32 math with float16 I/O) in order to use
Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.
**Vanilla RNN**
Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
ReLU and Tanh.
With ReLU activation function:
.. math::
h_t = relu(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})
With Tanh activation function:
.. math::
h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})
Reference paper: Finding structure in time - Elman, 1988.
https://crl.ucsd.edu/~elman/Papers/fsit.pdf
**LSTM**
Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf
.. math::
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}
**GRU**
Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078
The definition of GRU here is slightly different from the paper but compatible with CUDNN.
.. math::
\begin{array}{ll}
r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
\end{array}
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<RNNParam>)
.set_num_inputs([](const NodeAttrs& attrs) {
    const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
    // data, parameters, state (+ state_cell for LSTM)
    return params.mode == rnn_enum::kLstm ? 4 : 3;
  })
.set_num_outputs([](const NodeAttrs& attrs) {
    const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
    // kOut
    int num_outputs = 1;
    if (params.state_outputs) {
      // kOut, kStateOut, kStateCellOut
      num_outputs = (params.mode == rnn_enum::kLstm) ? 3 : 2;
    }
    return num_outputs;
  })
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
    return ListArguments(params);
  })
.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
.set_attr<nnvm::FInferType>("FInferType", RNNType)
.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
// Extra resources: only a cuDNN dropout descriptor, and only on GPU builds with
// MXNET_USE_CUDNN_RNN when 0 != param.p < 1 (param.p presumably being the
// dropout probability). Otherwise no resources are requested.
.set_attr<FResourceRequestEx>("FResourceRequestEx",
  [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
    std::vector<ResourceRequest> request;
    const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
    if (param.p == 0) return request;
    if (dev_mask == kGPU) {
#if MXNET_USE_CUDNN_RNN
      if (1.0f - param.p > 0) {
        request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
        return request;
      }
#endif
    }
    return request;
  })
.add_argument("data", "NDArray-or-Symbol", "Input data to RNN")
.add_argument("parameters", "NDArray-or-Symbol",
              "Vector of all RNN trainable parameters concatenated")
.add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN")
.add_argument("state_cell", "NDArray-or-Symbol",
              "initial cell state for LSTM networks (only for LSTM)")
.add_arguments(RNNParam::__FIELDS__());
// Backward operator for RNN, created by RNNGrad{"_backward_RNN"} via
// MakeGradNode. Its output count mirrors the forward op's input count
// (4 for LSTM, 3 otherwise) -- presumably one gradient per forward input.
NNVM_REGISTER_OP(_backward_RNN)
.set_num_outputs([](const NodeAttrs& attrs) {
    const auto& rnn_param = nnvm::get<RNNParam>(attrs.parsed);
    if (rnn_param.mode == rnn_enum::kLstm) {
      return 4;
    }
    return 3;
  })
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>);
} // namespace op
} // namespace mxnet