blob: b2ac2f0cb615e4d2b57836291c5f93e3cec461f2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file rnn.cc
* \brief
* \author Sebastian Bodenstein
*/
#include <iterator>
#include "./rnn-inl.h"
namespace mxnet {
namespace op {
// Register the RNNParam structure with dmlc's parameter machinery; the operator
// registrations further down parse their attributes with ParamParser<RNNParam>.
DMLC_REGISTER_PARAMETER(RNNParam);
/*!
 * \brief Build the ordered list of input names for a given RNN configuration.
 * \param param_ parsed operator parameters
 * \return "data", "parameters" and "state", followed by "state_cell" when the
 *         mode is LSTM and by "sequence_length" when use_sequence_length is set
 */
static inline std::vector<std::string> ListArguments(const RNNParam& param_) {
  const bool is_lstm = (param_.mode == rnn_enum::kLstm);

  std::vector<std::string> names;
  // Inputs shared by every RNN mode.
  names.push_back("data");
  names.push_back("parameters");
  names.push_back("state");
  // Only LSTM carries a separate cell state.
  if (is_lstm) {
    names.push_back("state_cell");
  }
  // Optional per-batch sequence lengths always come last.
  if (param_.use_sequence_length) {
    names.push_back("sequence_length");
  }
  return names;
}
/*!
 * \brief Shape inference for the RNN operator.
 *
 * Validates the data input, assigns/checks the shapes of the state, cell state,
 * flattened parameter vector and (optionally) sequence-length inputs, and fills
 * in the output shapes.
 *
 * \param attrs node attributes holding the parsed RNNParam
 * \param in_shape input shapes; unknown entries are filled in where possible
 * \param out_shape cleared and refilled with the output shapes
 * \return false when the rank of the data input is not known yet, true otherwise
 */
static bool RNNShape(const nnvm::NodeAttrs& attrs,
                     std::vector<TShape> *in_shape,
                     std::vector<TShape> *out_shape) {
  using namespace mshadow;
  const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed);

  // The number of expected inputs depends on the mode and on use_sequence_length.
  const std::vector<std::string> expected_arguments = ListArguments(param_);
  CHECK_EQ(in_shape->size(), expected_arguments.size()) << "Input shape mismatch. Expected " <<
    expected_arguments.size() << " input parameters but got " << in_shape->size() << ".";

  const TShape &dshape = (*in_shape)[rnn_enum::kData];
  if (!mxnet::ndim_is_known(dshape)) {
    return false;
  }
  CHECK_EQ(dshape.ndim(), 3U)
    << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";

  const bool is_lstm = (param_.mode == rnn_enum::kLstm);
  const bool has_projection = param_.projection_size.has_value();

  // data: [sequence len, batch, input dimension]
  const int batch_size = dshape[1];
  const int input_size = dshape[2];
  const int numDirections = param_.bidirectional ? 2 : 1;
  // One set of layers per direction.
  const int total_layers = numDirections * param_.num_layers;
  // The hidden state uses the projection size when one is configured.
  const int layer_size = has_projection ?
    param_.projection_size.value() : param_.state_size;

  // Hidden state, then (LSTM only) cell state.
  SHAPE_ASSIGN_CHECK(*in_shape,
                     rnn_enum::kState,
                     Shape3(total_layers, batch_size, layer_size));
  if (is_lstm) {
    SHAPE_ASSIGN_CHECK(*in_shape,
                       rnn_enum::kStateCell,
                       Shape3(total_layers, batch_size, param_.state_size));
  }

  // Length of the flattened parameter vector.
  const int param_size = GetRnnParamSize(param_.num_layers,
                                         input_size,
                                         param_.state_size,
                                         numDirections,
                                         param_.mode,
                                         param_.projection_size);
  SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));

  // Optional sequence_length input: one entry per batch element. Without the
  // LSTM cell state it sits one slot earlier in the input list.
  if (param_.use_sequence_length) {
    size_t seq_len_input_idx = rnn_enum::kSequenceLength;
    if (!is_lstm) {
      --seq_len_input_idx;
    }
    SHAPE_ASSIGN_CHECK(*in_shape, seq_len_input_idx, Shape1(batch_size));
  }

  out_shape->clear();

  // output: [sequence len, batch, output size]
  TShape oshape = dshape;
  if (has_projection) {
    oshape[2] = numDirections * param_.projection_size.value();
  } else {
    oshape[2] = numDirections * param_.state_size;
  }
  out_shape->push_back(oshape);

  if (!param_.state_outputs) {
    return true;
  }

  // outStateShape: [layer_num, batch, state size]
  TShape outStateShape = dshape;
  outStateShape[0] = total_layers;
  outStateShape[1] = batch_size;
  if (has_projection) {
    outStateShape[2] = param_.projection_size.value();
  } else {
    outStateShape[2] = param_.state_size;
  }
  out_shape->push_back(outStateShape);

  // The LSTM cell state output is never projected.
  if (is_lstm) {
    TShape cellStateShape = dshape;
    cellStateShape[0] = total_layers;
    cellStateShape[1] = batch_size;
    cellStateShape[2] = param_.state_size;
    out_shape->push_back(cellStateShape);
  }
  return true;
}
/*!
 * \brief Type inference for the RNN operator.
 *
 * The type of the first input (data) is the reference type: unknown input types
 * are set to it, known ones must match it (except the optional sequence_length
 * input, which may use its own type), and every output gets it.
 *
 * \param attrs node attributes holding the parsed RNNParam
 * \param in_type input type flags; -1 entries are filled in
 * \param out_type cleared and refilled with the output type flags
 * \return always true (failures are raised through the CHECK macros)
 */
static bool RNNType(const nnvm::NodeAttrs& attrs,
                    std::vector<int> *in_type,
                    std::vector<int> *out_type) {
  const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed);
  CHECK_EQ(in_type->size(), GetNumInputArguments(param_));

  const bool is_lstm = (param_.mode == rnn_enum::kLstm);
  // Without the LSTM cell state, sequence_length sits one slot earlier.
  size_t seq_len_input_idx = rnn_enum::kSequenceLength;
  if (!is_lstm) {
    --seq_len_input_idx;
  }

  int dtype = (*in_type)[0];
  CHECK_NE(dtype, -1) << "First input must have specified type";

  std::vector<std::string> arguments = ListArguments(param_);
  for (size_t idx = 0; idx < in_type->size(); ++idx) {
    if ((*in_type)[idx] == -1) {
      // Unknown so far: adopt the data type.
      TYPE_ASSIGN_CHECK(*in_type, idx, dtype);
      continue;
    }
    // The sequence length argument has its own indexing type.
    if (param_.use_sequence_length && idx == seq_len_input_idx) {
      continue;
    }
    // Every other input must match the main data type.
    UNIFORM_TYPE_CHECK((*in_type)[idx], dtype, arguments[idx]);
  }

  // One output, plus the state output, plus the cell state output for LSTM.
  size_t num_outputs = 1;
  if (param_.state_outputs) {
    num_outputs += is_lstm ? 2 : 1;
  }
  out_type->assign(num_outputs, dtype);
  return true;
}
/*!
 * \brief List the resources the RNN operator needs on the given device.
 * \param attrs node attributes holding the parsed RNNParam
 * \param dev_mask device type the operator will run on
 * \param dispatch_mode dispatch mode (not used)
 * \return empty unless running on GPU in a cuDNN build; there, temp space plus
 *         a cuDNN dropout descriptor when the dropout probability p is non-zero
 *         and below 1
 */
static std::vector<ResourceRequest> RNNResourceEx(const NodeAttrs& attrs, const int dev_mask,
                                                  const DispatchMode dispatch_mode) {
  std::vector<ResourceRequest> request;
  if (dev_mask != kGPU) {
    return request;
  }
#if MXNET_USE_CUDNN == 1
  request.emplace_back(ResourceRequest::kTempSpace);

  const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
  const bool dropout_enabled = (param.p != 0) && (1.0f - param.p > 0);
  if (dropout_enabled) {
    request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
  }
#endif
  return request;
}
/*!
 * \brief Storage type inference for the RNN operator.
 *
 * All outputs are assigned default (dense) storage. The requested dispatch mode
 * is kFComputeEx in MKL-DNN builds and kFCompute otherwise.
 *
 * \param attrs node attributes (not used)
 * \param dev_mask device type (not used)
 * \param dispatch_mode dispatch mode to assign
 * \param in_attrs input storage types (not used)
 * \param out_attrs output storage types to assign
 * \return the result of storage_type_assign
 */
inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
                                  const int dev_mask,
                                  DispatchMode* dispatch_mode,
                                  std::vector<int> *in_attrs,
                                  std::vector<int> *out_attrs) {
#if MXNET_USE_MKLDNN == 1
  const DispatchMode wanted_mode = DispatchMode::kFComputeEx;
#else
  const DispatchMode wanted_mode = DispatchMode::kFCompute;
#endif
  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                             dispatch_mode, wanted_mode);
}
/*!
 * \brief Gradient functor for the RNN operator.
 *
 * Collects the entries handed to the backward node, in this order: data,
 * parameters, state, output and its gradient, then (with state_outputs) the
 * state output and its gradient, then (LSTM only) the cell state input and,
 * with state_outputs, the cell state output and its gradient.
 */
struct RNNGrad {
  /*! \brief name of the backward operator to instantiate */
  const char *op_name;

  /*!
   * \brief Create the gradient node for forward node n.
   * \param n forward RNN node
   * \param ograd gradients flowing into the outputs of n
   * \return the entries produced by MakeGradNode
   */
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr &n,
                                          const std::vector<nnvm::NodeEntry> &ograd) const {
    const RNNParam& params = nnvm::get<RNNParam>(n->attrs.parsed);
    const bool is_lstm = (params.mode == rnn_enum::kLstm);
    const bool with_state_outputs = params.state_outputs;

    std::vector<nnvm::NodeEntry> heads;
    // Appends forward output `out_idx` followed by the gradient w.r.t. it.
    auto push_output_and_grad = [&heads, &n, &ograd](const int out_idx) {
      heads.emplace_back(n, out_idx, 0);
      heads.push_back(ograd[out_idx]);
    };

    heads.push_back(n->inputs[rnn_enum::kData]);
    heads.push_back(n->inputs[rnn_enum::kParams]);
    heads.push_back(n->inputs[rnn_enum::kState]);
    push_output_and_grad(rnn_enum::kOut);
    if (with_state_outputs) {
      push_output_and_grad(rnn_enum::kStateOut);
    }
    if (is_lstm) {
      heads.push_back(n->inputs[rnn_enum::kStateCell]);
      if (with_state_outputs) {
        push_output_and_grad(rnn_enum::kStateCellOut);
      }
    }
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};
#if MXNET_USE_MKLDNN == 1
/*!
 * \brief Forward pass of the stateful RNN operator on CPU for MKL-DNN builds.
 *
 * Converts the NDArray inputs/outputs into TBlobs, makes sure the operator
 * state owns a workspace (op.mem_space_) of at least the size reported by
 * GetMKLDNNRNNCacheMemorySize, and, the first time through (op.x_memory is
 * empty), creates the mkldnn::memory objects stored in the operator state.
 * Most of them are backed by consecutive slices of the workspace; `workptr`
 * is the running cursor into it. Finally it calls op.Forward.
 *
 * Shorthand used below: T = sequence length, N = batch size, I = input size,
 * H = state size, L = number of layers, D = number of directions.
 *
 * \param state_ptr operator state holding an RNNOp<cpu, DType, IType>
 * \param ctx operator context
 * \param inputs input arrays
 * \param req write requests for the outputs
 * \param outputs output arrays
 */
static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
                                  const OpContext& ctx,
                                  const std::vector<NDArray>& inputs,
                                  const std::vector<OpReqType>& req,
                                  const std::vector<NDArray>& outputs) {
  std::vector<TBlob> in_blobs;
  std::vector<TBlob> out_blobs;
  // Keep the reordered arrays alive for as long as the blobs that view them.
  std::vector<NDArray> temp_ndarrays_i;
  std::vector<NDArray> temp_ndarrays_o;
  for (const NDArray& in : inputs) {
    if (in.storage_type() == kDefaultStorage) {
      temp_ndarrays_i.push_back(in.Reorder2Default());
      in_blobs.emplace_back(temp_ndarrays_i.back().data());
    } else {
      in_blobs.emplace_back(in.data());
    }
  }
  for (const NDArray& out : outputs) {
    if (out.storage_type() == kDefaultStorage) {
      temp_ndarrays_o.push_back(out.Reorder2Default());
      out_blobs.emplace_back(temp_ndarrays_o.back().data());
    } else {
      out_blobs.emplace_back(out.data());
    }
  }
  // DType comes from the data input, IType from the last input.
  int dtype = in_blobs[rnn_enum::kData].type_flag_;
  int itype = in_blobs[inputs.size()-1].type_flag_;
  mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype);
  Stream<cpu> *s = ctx.get_stream<cpu>();
  auto cpu_engine = CpuEngine::Get()->get_engine();
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    MSHADOW_TYPE_SWITCH(itype, IType, {
      RNNOp<cpu, DType, IType>& op = state_ptr.get_state<RNNOp<cpu, DType, IType>>();
      const RNNParam& param = op.param_;
      int ngates = 0, nstates = 0;
      GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates);
      int D = param.bidirectional ? 2 : 1;
      Tensor<cpu, 3, DType> x = in_blobs[rnn_enum::kData].get<cpu, 3, DType>(s);
      int T = x.shape_[0];
      int N = x.shape_[1];
      int I = x.shape_[2];
      int H = param.state_size;
      int L = param.num_layers;
      const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode);
      // An existing workspace that is too small has to be re-allocated.
      if (op.init_mem_ && op.reserve_mem_size_ < r_size) {
        op.init_mem_ = false;
      }
      const size_t weights_version = inputs[rnn_enum::kParams].version();
      if (!op.init_mem_) {
        op.mem_space_ = NDArray(TShape({static_cast<dim_t>(r_size)}), op.ctx_, false, dtype);
        op.reserve_mem_size_ = r_size;
        op.init_mem_ = true;
        op.has_cache = false;
        // Assign weights_version
        op.weights_version = weights_version;
      }
      // Check if NDArray was changed.
      if (op.weights_version != weights_version) {
        op.has_cache = false;
        op.weights_version = weights_version;
      }
      // NOTE(review): the memory objects below are only created while
      // op.x_memory is empty; they are not rebuilt here when mem_space_ is
      // re-allocated above - confirm that op.Forward refreshes the handles.
      DType* workptr = static_cast<DType*>(op.mem_space_.data().dptr_);
      mkldnn::memory::dims src_layer_tz_0 = {T, N, I};
      mkldnn::memory::dims src_layer_tz = {T, N, D * H};
      mkldnn::memory::dims dst_layer_tz = {T, N, D * H};
      auto dst_layer_md = mkldnn::memory::desc(
          { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
      if (op.x_memory.size() == 0) {
        if (D == 1 && I == H) {
          // Unidirectional with input size == state size: all L layers share
          // one set of memory objects.
          auto user_src_layer_md = mkldnn::memory::desc(
              { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
          auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine });
          op.x_memory.push_back(user_src_layer_memory_n);
          mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H};  //  ldigo
          mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H};  //  ldigo
          mkldnn::memory::dims bias_tz = {L, 1, ngates, H};
          auto user_weight_layer_md = mkldnn::memory::desc(
              { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
          auto user_weight_iter_md = mkldnn::memory::desc(
              { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
          auto user_bias_md = mkldnn::memory::desc({ bias_tz },
              mkldnn_dtype, mkldnn::memory::format::ldgo);
          DType* weight_layer_n = workptr;  //  L * I * ngates * H
          auto user_weight_layer_memory_n
              = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
          op.wx_memory.push_back(user_weight_layer_memory_n);
          DType* weight_iter_n = weight_layer_n + L * I * ngates * H;  //  L * H * ngates * H
          auto user_weight_iter_memory_n
              = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
          op.wh_memory.push_back(user_weight_iter_memory_n);
          DType* bias_n = weight_iter_n + L * H * ngates * H;  //  L * ngates * H
          auto user_bias_memory_n =
              mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
          op.bias_memory.push_back(user_bias_memory_n);
          auto wx_md_n = mkldnn::memory::desc(
              { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
          DType* wx_n = bias_n + L * ngates * H;  //   L * ngates * I * H
          auto wx_memory_n =
              mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
          DType* wh_n = wx_n + L * ngates * I * H;  //  L * ngates * H * H
          auto wh_md_n = mkldnn::memory::desc(
              { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
          auto wh_memory_n =
              mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
          op.concat_weight_memory.push_back(wx_memory_n);
          op.concat_weight_memory.push_back(wh_memory_n);
          workptr = wh_n + L * ngates * H * H;
          mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H};  //  ldsnc
          auto src_iter_md_n1 = mkldnn::memory::desc(
              { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
          // One single-layer state slice per layer, followed by the slice
          // that holds all L layers together.
          for (int l = 0; l < L; l++) {
            DType* src_iter_n1 = workptr;  //  nstates * N * H
            auto src_iter_memory_n1 =
                mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1);
            op.concat_iter_memory.push_back(src_iter_memory_n1);
            workptr = src_iter_n1 + nstates * N * H;
          }
          mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H};  //  ldsnc
          auto src_iter_md_n = mkldnn::memory::desc(
              { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
          DType* src_iter_n = workptr;  //  L * nstates * N * H
          auto src_iter_memory_n =
              mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n);
          op.concat_iter_memory.push_back(src_iter_memory_n);
          op.hcx_memory.push_back(src_iter_memory_n);
          DType* dst_layer_n = src_iter_n + L * nstates * N * H;  //  T * N * D * H
          auto dst_layer_memory_n
              = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
          op.y_memory.push_back(dst_layer_memory_n);
          mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H};  //  ldsnc
          auto dst_iter_md_n = mkldnn::memory::desc(
              { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
          DType* dst_iter_n = dst_layer_n + T * N * D * H;  //  L * nstates * N * H
          auto dst_iter_memory_n =
              mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
          op.hcy_memory.push_back(dst_iter_memory_n);
          workptr = dst_iter_n + L * nstates * N * H;
        } else {
          // General case: the first layer (input size I) gets its own memory
          // objects; the remaining layers are handled further down.
          auto user_src_layer_md_0 = mkldnn::memory::desc(
              { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc);
          auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine });
          op.x_memory.push_back(user_src_layer_memory_0);
          mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H};  //  ldigo
          mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H};  //  ldigo
          mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H};
          auto user_weight_layer_md_0 = mkldnn::memory::desc(
              { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo);
          auto user_weight_iter_md_0 = mkldnn::memory::desc(
              { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo);
          auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 },
              mkldnn_dtype, mkldnn::memory::format::ldgo);
          DType* weight_layer_0 = workptr;  //  D * I * ngates * H
          auto user_weight_layer_memory_0
              = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0);
          op.wx_memory.push_back(user_weight_layer_memory_0);
          DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H;  //  D * H * ngates * H
          auto user_weight_iter_memory_0
              = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0);
          op.wh_memory.push_back(user_weight_iter_memory_0);
          DType* bias_0 = weight_iter_0 + D * H * ngates * H;  //  D * ngates * H
          auto user_bias_memory_0 =
              mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0);
          op.bias_memory.push_back(user_bias_memory_0);
          workptr = bias_0 + D * ngates * H;
          auto wx_md_0 = mkldnn::memory::desc(
              { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
          auto wx_memory_0 =
              mkldnn::memory({ wx_md_0, cpu_engine });
          auto wh_md_0 = mkldnn::memory::desc(
              { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
          auto wh_memory_0 =
              mkldnn::memory({ wh_md_0, cpu_engine });
          // Only the bidirectional case points the ldgoi weight memories at
          // workspace slices here.
          if (D == 2) {
            DType* wx_0 = workptr;  //  D * ngates * I * H
            wx_memory_0.set_data_handle(wx_0);
            DType* wh_0 = wx_0 + D * ngates * I * H;  //  D * ngates * H * H
            wh_memory_0.set_data_handle(wh_0);
            workptr = wh_0 + D * ngates * H * H;
          }
          op.concat_weight_memory.push_back(wx_memory_0);
          op.concat_weight_memory.push_back(wh_memory_0);
          mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H};  //  ldsnc
          auto src_iter_undi_md_0 = mkldnn::memory::desc(
              { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
          DType* src_iter_undi_0 = workptr;  //  nstates * N * H
          auto src_iter_undi_memory_0 =
              mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0);
          op.concat_iter_memory.push_back(src_iter_undi_memory_0);
          workptr = src_iter_undi_0 + nstates * N * H;
          if (D == 1) {
            op.hcx_memory.push_back(src_iter_undi_memory_0);
          } else {
            // Bidirectional: a second per-direction slice plus the slice that
            // holds both directions together.
            DType* src_iter_undi2_0 = workptr;  //  nstates * N * H
            auto src_iter_undi2_memory_0 =
                mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0);
            op.concat_iter_memory.push_back(src_iter_undi2_memory_0);
            mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H};  //  ldsnc
            auto src_iter_md_0 = mkldnn::memory::desc(
                { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
            DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H;  //  D * nstates * N * H
            auto src_iter_memory_0 =
                mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0);
            op.concat_iter_memory.push_back(src_iter_memory_0);
            op.hcx_memory.push_back(src_iter_memory_0);
            workptr = src_iter_0 + D * nstates * N * H;
          }
          DType* dst_layer_0 = workptr;  //  T * N * D * H
          auto dst_layer_memory_0
              = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0);
          op.y_memory.push_back(dst_layer_memory_0);
          mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H};  //  ldsnc
          auto dst_iter_md_0 = mkldnn::memory::desc(
              { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
          DType* dst_iter_0 = dst_layer_0 + T * N * D * H;  //  D * nstates * N * H
          auto dst_iter_memory_0 =
              mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0);
          op.hcy_memory.push_back(dst_iter_memory_0);
          workptr = dst_iter_0 + D * nstates * N * H;
          //  next L - 1 layers
          if (L > 1 && D == 1) {
            // Unidirectional: the remaining L - 1 layers share one set of
            // memory objects.
            auto user_src_layer_md = mkldnn::memory::desc(
                { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
            auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine });
            op.x_memory.push_back(user_src_layer_memory);
            mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H};  //  ldigo
            mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H};  //  ldigo
            mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H};
            auto user_weight_layer_md = mkldnn::memory::desc(
                { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
            auto user_weight_iter_md = mkldnn::memory::desc(
                { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
            auto user_bias_md = mkldnn::memory::desc({ bias_tz },
                mkldnn_dtype, mkldnn::memory::format::ldgo);
            DType* weight_layer_n = workptr;  //  (L - 1) * H * ngates * H
            auto user_weight_layer_memory_n
                = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
            op.wx_memory.push_back(user_weight_layer_memory_n);
            DType* weight_iter_n = weight_layer_n +
                (L - 1) * H * ngates * H;  //  (L - 1) * H * ngates * H
            auto user_weight_iter_memory_n
                = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
            op.wh_memory.push_back(user_weight_iter_memory_n);
            DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H;  //  (L - 1) * ngates * H
            auto user_bias_memory_n =
                mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
            op.bias_memory.push_back(user_bias_memory_n);
            auto wx_md_n = mkldnn::memory::desc(
                { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
            DType* wx_n = bias_n + (L - 1) * ngates * H;  //  (L - 1) * ngates * H * H
            auto wx_memory_n =
                mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
            DType* wh_n = wx_n + (L - 1) * ngates * H * H;  //  (L - 1) * ngates * H * H
            auto wh_md_n = mkldnn::memory::desc(
                { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
            auto wh_memory_n =
                mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
            op.concat_weight_memory.push_back(wx_memory_n);
            op.concat_weight_memory.push_back(wh_memory_n);
            workptr = wh_n + (L - 1) * ngates * H * H;
            mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H};  //  ldsnc
            auto src_iter_md_n1 = mkldnn::memory::desc(
                { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
            for (int l = 0; l < L - 1; l++) {
              DType* src_iter_n1 = workptr;  //  nstates * N * H
              auto src_iter_memory_n1 =
                  mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1);
              op.concat_iter_memory.push_back(src_iter_memory_n1);
              workptr = src_iter_n1 + nstates * N * H;
            }
            mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H};  //  ldsnc
            auto src_iter_md_n = mkldnn::memory::desc(
                { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
            DType* src_iter_n = workptr;  //  (L - 1) * nstates * N * H
            auto src_iter_memory_n =
                mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n);
            op.concat_iter_memory.push_back(src_iter_memory_n);
            op.hcx_memory.push_back(src_iter_memory_n);
            DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H;  //  T * N * D * H
            auto dst_layer_memory_n
                = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
            op.y_memory.push_back(dst_layer_memory_n);
            mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H};  //  ldsnc
            auto dst_iter_md_n = mkldnn::memory::desc(
                { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
            DType* dst_iter_n = dst_layer_n + T * N * D * H;  //  (L - 1) * nstates * N * H
            auto dst_iter_memory_n =
                mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
            op.hcy_memory.push_back(dst_iter_memory_n);
          }
          if (L > 1 && D == 2) {
            // Bidirectional: per-layer user weight/bias memories, but a single
            // set of source/state/destination memories for the later layers.
            mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H};  //  ldigo
            mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H};  //  ldigo
            mkldnn::memory::dims bias_tz = {1, D, ngates, H};
            auto user_weight_layer_md = mkldnn::memory::desc(
                { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
            auto user_weight_iter_md = mkldnn::memory::desc(
                { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
            auto user_bias_md = mkldnn::memory::desc({ bias_tz },
                mkldnn_dtype, mkldnn::memory::format::ldgo);
            auto user_src_layer_md = mkldnn::memory::desc(
                { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
            auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine });
            op.x_memory.push_back(user_src_layer_memory);
            auto wx_md_n = mkldnn::memory::desc(
                { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
            auto wh_md_n = mkldnn::memory::desc(
                { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
            for (int l = 0; l < L; l++) {
              DType* weight_layer_n = workptr;  //  D * (H * D) * ngates * H
              auto user_weight_layer_memory_n
                  = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
              op.wx_memory.push_back(user_weight_layer_memory_n);
              DType* weight_iter_n = weight_layer_n +
                  D * (H * D) * ngates * H;  //  D * H * ngates * H
              auto user_weight_iter_memory_n
                  = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
              op.wh_memory.push_back(user_weight_iter_memory_n);
              DType* bias_n = weight_iter_n + D * H * ngates * H;  //  D * ngates * H
              auto user_bias_memory_n =
                  mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
              op.bias_memory.push_back(user_bias_memory_n);
              workptr = bias_n + D * ngates * H;
            }
            DType* wx_n = workptr;  //  D * ngates * (D * H) * H
            DType* wh_n = wx_n + D * ngates * (D * H) * H;  //  D * ngates * H * H
            auto wx_memory_n =
                mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
            auto wh_memory_n =
                mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
            op.concat_weight_memory.push_back(wx_memory_n);
            op.concat_weight_memory.push_back(wh_memory_n);
            mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H};  //  ldsnc
            auto src_iter_undi_md = mkldnn::memory::desc(
                { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
            DType* src_iter_undi = wh_n + D * ngates * H * H;  //  nstates * N * H
            auto src_iter_undi_memory =
                mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi);
            // Push the memory created just above, which is backed by its own
            // workspace slice. Pushing layer 0's src_iter_undi_memory_0 here
            // would leave this slice unused and make two entries share one
            // buffer.
            op.concat_iter_memory.push_back(src_iter_undi_memory);
            DType* src_iter_undi2 = src_iter_undi + nstates * N * H;  //  nstates * N * H
            auto src_iter_undi2_memory =
                mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2);
            op.concat_iter_memory.push_back(src_iter_undi2_memory);
            mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H};  //  ldsnc
            auto src_iter_md = mkldnn::memory::desc(
                { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
            DType* src_iter = src_iter_undi2 + nstates * N * H;  //  D * nstates * N * H
            auto src_iter_memory =
                mkldnn::memory({ src_iter_md, cpu_engine }, src_iter);
            op.concat_iter_memory.push_back(src_iter_memory);
            op.hcx_memory.push_back(src_iter_memory);
            DType* dst_layer_n = src_iter + D * nstates * N * H;  //  T * N * D * H
            auto dst_layer_memory_n
                = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
            op.y_memory.push_back(dst_layer_memory_n);
            mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H};  //  ldsnc
            auto dst_iter_md_n = mkldnn::memory::desc(
                { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
            DType* dst_iter_n = dst_layer_n + T * N * D * H;  //  D * nstates * N * H
            auto dst_iter_memory_n =
                mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
            op.hcy_memory.push_back(dst_iter_memory_n);
          }
        }
      }
      op.Forward(ctx, in_blobs, req, out_blobs);
    });
  });
}
/*!
 * \brief FStatefulComputeEx entry point for RNN on CPU (MKL-DNN builds).
 *
 * Uses the MKL-DNN path when SupportMKLDNNRNN accepts the data input.
 * Otherwise it runs the regular stateful compute through FallBackCompute with
 * the MXNET_USE_MKLDNN_RNN environment variable temporarily set to 0.
 *
 * \param state_ptr operator state
 * \param ctx operator context
 * \param inputs input arrays
 * \param req write requests for the outputs
 * \param outputs output arrays
 */
static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr, const OpContext& ctx,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
  if (SupportMKLDNNRNN(inputs[0])) {
    RNNStatefulComputeCPU(state_ptr, ctx, inputs, req, outputs);
    return;
  }
  // Scope guard: forces MXNET_USE_MKLDNN_RNN to 0 and puts the previous value
  // (default 1 when unset) back on every exit path, including when the
  // fallback throws. Without it an exception would leave the variable at 0
  // for the rest of the process.
  struct MKLDNNRNNEnvGuard {
    int saved_value;
    MKLDNNRNNEnvGuard() : saved_value(dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
      dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", 0);
    }
    ~MKLDNNRNNEnvGuard() {
      dmlc::SetEnv("MXNET_USE_MKLDNN_RNN", saved_value);
    }
  } env_guard;
  FallBackCompute(RNNStatefulCompute<cpu>, state_ptr, ctx, inputs, req, outputs);
}
#endif
// Registration of the forward RNN operator: user-facing documentation,
// input/output arity, inference functions, CPU kernels and gradient.
NNVM_REGISTER_OP(RNN)
.add_alias("_npx_rnn")
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
implemented, with both multi-layer and bidirectional support.

When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
pseudo-float16 precision (float32 math with float16 I/O) in order to use
Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.

**Vanilla RNN**

Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
ReLU and Tanh.

With ReLU activation function:

.. math::
    h_t = relu(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})

With Tanh activation function:

.. math::
    h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})

Reference paper: Finding structure in time - Elman, 1988.
https://crl.ucsd.edu/~elman/Papers/fsit.pdf

**LSTM**

Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf

.. math::
    \begin{array}{ll}
    i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
    f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
    g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
    o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
    c_t = f_t * c_{(t-1)} + i_t * g_t \\
    h_t = o_t * \tanh(c_t)
    \end{array}

**GRU**

Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078

The definition of GRU here is slightly different from paper but compatible with CUDNN.

.. math::
    \begin{array}{ll}
    r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
    z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
    n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
    h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
    \end{array}
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<RNNParam>)
// The number of inputs depends on the parsed parameters.
.set_num_inputs([](const NodeAttrs& attrs) {
    const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
    return GetNumInputArguments(params);
})
// One output by default; state_outputs adds the state (and, for LSTM, the
// cell state) output.
.set_num_outputs([](const NodeAttrs& attrs) {
    const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
    //  kOut
    int num_outputs = 1;
    if (params.state_outputs) {
      // kOut, kStateOut, kStateCellOut
      num_outputs = (params.mode == rnn_enum::kLstm) ? 3 : 2;
    }
    return num_outputs;
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
    return ListArguments(params);
})
.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
.set_attr<nnvm::FInferType>("FInferType", RNNType)
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
.set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx)
.add_argument("data", "NDArray-or-Symbol", "Input data to RNN")
.add_argument("parameters", "NDArray-or-Symbol",
              "Vector of all RNN trainable parameters concatenated")
.add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN")
.add_argument("state_cell", "NDArray-or-Symbol",
              "initial cell state for LSTM networks (only for LSTM)")
.add_argument("sequence_length", "NDArray-or-Symbol",
              "Vector of valid sequence lengths for each element in batch. (Only used if"
              " use_sequence_length kwarg is True)")
.add_arguments(RNNParam::__FIELDS__());
// Registration of the backward RNN operator created by RNNGrad.
NNVM_REGISTER_OP(_backward_RNN)
// The backward node has as many outputs as the forward operator has inputs
// for the same parameters.
.set_num_outputs([](const NodeAttrs& node_attrs) {
    return GetNumInputArguments(nnvm::get<RNNParam>(node_attrs.parsed));
})
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>)
.set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx);
} // namespace op
} // namespace mxnet