| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #include "./cudnn_rnn.h" |
| #ifdef USE_CUDNN |
| #include <cudnn.h> |
| #if CUDNN_VERSION >= 5005 |
| #include <chrono> |
| #include "./cudnn_utils.h" |
| #include "singa/utils/logging.h" |
| |
| namespace singa { |
// Registers CudnnRNN under the identifier `cudnn_rnn`.
// NOTE(review): the macro is defined elsewhere; presumably it adds the class to
// the layer factory so it can be instantiated by that name — confirm.
RegisterLayerClass(cudnn_rnn, CudnnRNN);
| CudnnRNN::~CudnnRNN() { |
| if (weight_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc_)); |
| if (dropout_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_)); |
| if (rnn_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc_)); |
| if (hx_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc_)); |
| if (hy_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc_)); |
| if (cx_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc_)); |
| if (cy_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc_)); |
| if (dhx_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhx_desc_)); |
| if (dhy_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhy_desc_)); |
| if (dcx_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcx_desc_)); |
| if (dcy_desc_ != nullptr) |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcy_desc_)); |
| DestroyIODescriptors(); |
| } |
| |
| void CudnnRNN::ToDevice(std::shared_ptr<Device> device) { |
| RNN::ToDevice(device); |
| workspace_.ToDevice(device); |
| reserve_space_.ToDevice(device); |
| dropout_state_.ToDevice(device); |
| } |
| |
| void CudnnRNN::DestroyIODescriptors() { |
| if (x_descs_ != nullptr) { |
| for (size_t i = 0; i < max_length_; i++) { |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i])); |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i])); |
| } |
| delete [] x_descs_; |
| delete [] dx_descs_; |
| } |
| if (y_descs_ != nullptr) { |
| for (size_t i = 0; i < max_length_; i++) { |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i])); |
| CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i])); |
| } |
| delete [] y_descs_; |
| delete [] dy_descs_; |
| } |
| } |
| |
| void CudnnRNN::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) { |
| bool reset = false; |
| if (max_length_ < len) { |
| DestroyIODescriptors(); |
| max_length_ = len; |
| x_descs_ = new cudnnTensorDescriptor_t[len]; |
| dx_descs_ = new cudnnTensorDescriptor_t[len]; |
| y_descs_ = new cudnnTensorDescriptor_t[len]; |
| dy_descs_ = new cudnnTensorDescriptor_t[len]; |
| for (size_t i = 0; i < len; i++) { |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i])); |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&dx_descs_[i])); |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i])); |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&dy_descs_[i])); |
| } |
| reset = true; |
| } |
| |
| for (size_t i = 0; i < len; i++) { |
| CHECK_EQ(inputs[i].shape(1), input_size_); |
| if (inputs[i].shape(0) != batch_size_ || reset) { |
| int d[3] = {1, 1, 1}, s[3] = {1, 1, 1}; |
| d[0] = static_cast<int>(inputs[i].shape(0)); |
| CHECK_GT(d[0], 0); |
| d[1] = static_cast<int>(inputs[i].shape(1)); |
| s[0] = d[1] * d[2]; |
| s[1] = d[2]; |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], dtype_, 3, d, s)); |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s)); |
| |
| d[0] = static_cast<int>(inputs[i].shape(0)); |
| d[1] = static_cast<int>(hidden_size_ * num_directions_); |
| s[0] = d[1] * d[2]; |
| s[1] = d[2]; |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s)); |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s)); |
| } |
| } |
| } |
| |
| // must be called after setting IO descriptors |
| void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) { |
| auto ctx = dev->context(0); |
| CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc_)); |
| size_t state_size; |
| CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size)); |
| dropout_state_ = Tensor(Shape{state_size}, dev, kChar); |
| CUDNN_CHECK(cudnnSetDropoutDescriptor( |
| dropout_desc_, ctx->cudnn_handle, 1 - dropout_, // keep probability |
| dropout_state_.block()->mutable_data(), state_size, seed_)); |
| |
| CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_)); |
| cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; |
| if (input_mode_ == "skip") |
| input_mode = CUDNN_SKIP_INPUT; |
| |
| cudnnDirectionMode_t direction = CUDNN_UNIDIRECTIONAL; |
| if (direction_ == "bidirectional") |
| direction = CUDNN_BIDIRECTIONAL; |
| |
| cudnnRNNMode_t rnn_mode = CUDNN_LSTM; |
| if (rnn_mode_ == "relu") |
| rnn_mode = CUDNN_RNN_RELU; |
| else if (rnn_mode_ == "tanh") |
| rnn_mode = CUDNN_RNN_TANH; |
| else if (rnn_mode_ == "gru") |
| rnn_mode = CUDNN_GRU; |
| #if CUDNN_MAJOR <= 5 |
| CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_, |
| dropout_desc_, input_mode, direction, |
| rnn_mode, dtype_)); |
| #else |
| CUDNN_CHECK(cudnnSetRNNDescriptor(ctx->cudnn_handle, rnn_desc_, hidden_size_, num_stacks_, |
| dropout_desc_, input_mode, direction, |
| rnn_mode, CUDNN_RNN_ALGO_STANDARD, dtype_)); |
| #endif |
| size_t weight_size; |
| CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0], |
| &weight_size, dtype_)); |
| // check the size manually calculated |
| CHECK_EQ(weight_size, weight_.Size() * sizeof(float)); |
| int filter_dim[3] = {static_cast<int>(weight_size), 1, 1}; |
| CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_)); |
| CUDNN_CHECK(cudnnSetFilterNdDescriptor(weight_desc_, dtype_, |
| CUDNN_TENSOR_NCHW, 3, filter_dim)); |
| } |
| |
| void CudnnRNN::ResetHiddenAndCellDescriptors(size_t batch_size) { |
| if (batch_size_ == 0) { |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_)); |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_)); |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_)); |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcy_desc_)); |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_)); |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhx_desc_)); |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_)); |
| CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhy_desc_)); |
| } |
| |
| int dim[3] = {1, 1, 1}; |
| dim[0] = static_cast<int>(num_stacks_ * num_directions_); |
| dim[1] = static_cast<int>(batch_size); |
| dim[2] = static_cast<int>(hidden_size_); |
| int stride[3] = {1, 1, 1}; |
| stride[0] = dim[1] * dim[2]; |
| stride[1] = dim[2]; |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(hx_desc_, dtype_, 3, dim, stride)); |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(dhx_desc_, dtype_, 3, dim, stride)); |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(hy_desc_, dtype_, 3, dim, stride)); |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(dhy_desc_, dtype_, 3, dim, stride)); |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(cx_desc_, dtype_, 3, dim, stride)); |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcx_desc_, dtype_, 3, dim, stride)); |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(cy_desc_, dtype_, 3, dim, stride)); |
| CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcy_desc_, dtype_, 3, dim, stride)); |
| } |
| |
| void CudnnRNN::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) { |
| size_t count; |
| auto ctx = dev->context(0); |
| CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnn_desc_, |
| seq_length, x_descs_, &count)); |
| if (workspace_.Size() != count) { |
| workspace_ = Tensor(Shape{count}, dev, kChar); |
| // workspace_.SetValue(0); |
| } |
| |
| CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_, |
| seq_length, x_descs_, &count)); |
| if (reserve_space_.Size() != count) { |
| reserve_space_ = Tensor(Shape{count}, dev, kChar); |
| // reserve_space_.SetValue(0); |
| } |
| } |
| |
| void CudnnRNN::UpdateStates(size_t num_x, const vector<Tensor> &inputs) { |
| UpdateIODescriptors(num_x, inputs); |
| size_t new_batch_size = inputs.at(0).shape(0); |
| if (batch_size_ != new_batch_size) |
| ResetHiddenAndCellDescriptors(new_batch_size); |
| if (rnn_desc_ == nullptr) |
| SetRNNDescriptor(inputs.at(0).device()); |
| UpdateSpaces(num_x, inputs.at(0).device()); |
| batch_size_ = new_batch_size; |
| seq_length_ = num_x; |
| } |
| |
| Tensor CudnnRNN::MergeInputs(size_t num, const vector<Tensor> &in) { |
| if (num == 1) |
| return in.at(0); |
| size_t size = 0; |
| for (size_t i = 0; i < num; i++) size += in.at(i).Size(); |
| Tensor out(Shape{size}, in.at(0).device(), in.at(0).data_type()); |
| for (size_t i = 0, offset = 0; i < num; i++) { |
| CopyDataToFrom(&out, in.at(i), in.at(i).Size(), offset); |
| offset += in.at(i).Size(); |
| } |
| return out; |
| } |
| |
| vector<Tensor> CudnnRNN::SplitOutput(size_t num, size_t dim, |
| const vector<Tensor> &in, |
| const Tensor output) { |
| vector<Tensor> outputs; |
| if (num == 1) { |
| outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim})); |
| } else { |
| for (size_t i = 0, offset = 0; offset < output.Size(); i++) { |
| Shape s{in.at(i).shape(0), dim}; |
| Tensor out(s, output.device(), output.data_type()); |
| CopyDataToFrom(&out, output, out.Size(), 0, offset); |
| outputs.push_back(out); |
| offset += out.Size(); |
| } |
| CHECK_EQ(num, outputs.size()); |
| } |
| return outputs; |
| } |
| |
| const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) { |
| DataType dtype = inputs.at(0).data_type(); |
| auto dev = inputs.at(0).device(); |
| |
| // copy input data into a block of contiguous memory |
| // hx (and cx) is at the end of inputs |
| CHECK_GT(inputs.size(), 1u + has_cell_); |
| size_t num_x = inputs.size() - has_cell_ - 1; |
| Tensor input = MergeInputs(num_x, inputs); |
| // LOG(INFO) << "input size " << input.Size() << " value " << input.L1(); |
| |
| if (rnn_desc_ != nullptr) |
| CHECK_EQ(dtype_, GetCudnnDataType(dtype)) |
| << "Cannot change cudnn data type during training from " << dtype_ |
| << " to " << GetCudnnDataType(dtype); |
| else |
| dtype_ = GetCudnnDataType(dtype); |
| |
| UpdateStates(num_x, inputs); |
| // CheckFowardShapes(); |
| |
| Shape outshape{input.Size() * hidden_size_ / input_size_ * num_directions_}; |
| Tensor output(outshape, dev, dtype); |
| // LOG(INFO) << "output size " << output.Size(); |
| Tensor hx = inputs.at(num_x); |
| Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_size_}; |
| Tensor hy(state_shape, dev, dtype); |
| Tensor cy, cx; |
| if (has_cell_) { |
| cx = inputs.at(num_x + 1); |
| cy.ResetLike(hy); |
| } |
| |
| int did = input.device()->id(); |
| CHECK_EQ(did, output.device()->id()); |
| if (hx.Size()) { |
| CHECK_EQ(did, hx.device()->id()); |
| CHECK_EQ(hx.device()->lang(), kCuda); |
| } |
| if (cx.Size()) { |
| CHECK_EQ(did, cx.device()->id()); |
| CHECK_EQ(cx.device()->lang(), kCuda); |
| } |
| CHECK_EQ(did, weight_.device()->id()); |
| CHECK_EQ(did, workspace_.device()->id()); |
| CHECK_EQ(input.device()->lang(), kCuda); |
| CHECK_EQ(output.device()->lang(), kCuda); |
| CHECK_EQ(weight_.device()->lang(), kCuda); |
| CHECK_EQ(workspace_.device()->lang(), kCuda); |
| |
| // LOG(INFO) << "hidden size " << hy.Size(); |
| // LOG(INFO) << "weight size " << weight_.Size() << " value " << weight_.L1(); |
| Block *inb = input.block(), *outb = output.block(), |
| *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(), |
| *hyb = hy.block(), *cyb = cy.block(), |
| *wspace = this->workspace_.block(), |
| *rspace = this->reserve_space_.block(); |
| if (flag & kTrain) { |
| CHECK_EQ(reserve_space_.device()->lang(), kCuda); |
| CHECK_EQ(did, reserve_space_.device()->id()); |
| dev->Exec( |
| [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context * ctx) { |
| // clang-format off |
| cudnnRNNForwardTraining( |
| ctx->cudnn_handle, |
| this->rnn_desc_, |
| this->seq_length_, |
| this->x_descs_, inb->data(), |
| this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), |
| this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(), |
| this->weight_desc_, wb->data(), |
| this->y_descs_, outb->mutable_data(), |
| this->hy_desc_, hyb->mutable_data(), |
| this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(), |
| wspace->mutable_data(), |
| this->workspace_.Size(), rspace->mutable_data(), |
| this->reserve_space_.Size()); |
| // clang-format on |
| }, |
| {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace}); |
| buf_.push(input); |
| buf_.push(output); |
| buf_.push(hx); |
| buf_.push(cx); |
| } else { |
| dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context * ctx) { |
| // clang-format off |
| cudnnRNNForwardInference( |
| ctx->cudnn_handle, |
| this->rnn_desc_, |
| this->seq_length_, |
| this->x_descs_, inb->data(), |
| this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), |
| this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(), |
| this->weight_desc_, wb->data(), |
| this->y_descs_, outb->mutable_data(), |
| this->hy_desc_, hyb->mutable_data(), |
| this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(), |
| wspace->mutable_data(), this->workspace_.Size()); |
| // clang-format on |
| }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace}); |
| } |
| auto outputs = |
| SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output); |
| outputs.push_back(hy); |
| if (has_cell_) outputs.push_back(cy); |
| return outputs; |
| } |
| |
| // TODO(wangwei) check Tensor device to be on cuda? |
| const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward( |
| int flag, const vector<Tensor> &grads) { |
| // dhy (and dcy) is at last |
| const Tensor cx = buf_.top(); // cannot use const Tensor& due to pop() |
| buf_.pop(); |
| const Tensor hx = buf_.top(); |
| buf_.pop(); |
| const Tensor y = buf_.top(); |
| buf_.pop(); |
| const Tensor x = buf_.top(); |
| buf_.pop(); |
| |
| auto dev = y.device(); |
| auto dtype = y.data_type(); |
| |
| CHECK_GT(grads.size(), 1u + has_cell_); |
| size_t num_dy = grads.size() - has_cell_ - 1; |
| CHECK_EQ(num_dy, seq_length_); |
| const Tensor dy = MergeInputs(num_dy, grads); |
| CHECK_EQ(dy.Size(), y.Size()); |
| const Tensor dhy = grads.at(num_dy); |
| Tensor dcy; |
| if (has_cell_) |
| dcy = grads.at(num_dy + 1); |
| |
| Shape xshape{y.Size() * input_size_ / hidden_size_ / num_directions_}; |
| Tensor dx(xshape, dev, dtype); |
| Tensor dw(weight_.shape(), dev, dtype); |
| Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_size_}; |
| Tensor dhx(state_shape, dev, dtype); |
| Tensor dcx; |
| if (has_cell_) |
| dcx.ResetLike(dhx); |
| dw.SetValue(0.0f); |
| Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(), |
| *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(), |
| *wb = weight_.block(), *dwb = dw.block(), *hxb = hx.block(), |
| *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(), |
| *wspace = workspace_.block(), *rspace = reserve_space_.block(); |
| |
| y.device()->Exec( |
| [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace, |
| rspace, this](Context * ctx) { |
| // clang-format off |
| cudnnRNNBackwardData( |
| ctx->cudnn_handle, |
| this->rnn_desc_, |
| this->seq_length_, |
| this->y_descs_, yb->data(), |
| this->dy_descs_, dyb->data(), |
| this->dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(), |
| this->dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(), |
| this->weight_desc_, wb->data(), |
| this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), |
| this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(), |
| this->dx_descs_, dxb->mutable_data(), |
| this->dhx_desc_, dhxb->mutable_data(), |
| this->dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(), |
| wspace->mutable_data(), this->workspace_.Size(), |
| rspace->mutable_data(), this->reserve_space_.Size()); |
| cudnnRNNBackwardWeights( |
| ctx->cudnn_handle, |
| this->rnn_desc_, |
| this->seq_length_, |
| this->x_descs_, xb->data(), |
| this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(), |
| this->y_descs_, yb->data(), |
| wspace->data(), this->workspace_.Size(), |
| this->dweight_desc_, dwb->mutable_data(), |
| rspace->data(), this->reserve_space_.Size()); |
| // clang-format on |
| }, |
| {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace}, |
| {dxb, dwb, dhxb, dcxb, wspace, rspace}); |
| |
| vector <Tensor> param_grad{dw}; |
| auto data_grads = SplitOutput(num_dy, input_size_, grads, dx); |
| data_grads.push_back(dhx); |
| if (has_cell_) |
| data_grads.push_back(dcx); |
| return std::make_pair(data_grads, param_grad); |
| } |
| |
| } // namespace singa |
| #endif // CUDNN_VERSION >= 5005 |
| #endif // USE_CUDNN |