| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file rnn_impl.h |
| * \brief CPU implementations of the LSTM, GRU and vanilla RNN forward and backward passes. |
| * \author Shu Zhang |
| */ |
| #ifndef MXNET_OPERATOR_RNN_IMPL_H_ |
| #define MXNET_OPERATOR_RNN_IMPL_H_ |
| |
| #include <dmlc/logging.h> |
| #include <dmlc/parameter.h> |
| #include <mxnet/operator.h> |
| #include <algorithm> |
| #include <random> |
| #include <map> |
| #include <vector> |
| #include <string> |
| #include <utility> |
| #include "./math.h" |
| #include "./math_functions-inl.h" |
| #include "./operator_common.h" |
| #include "./mshadow_op.h" |
| #include "./linalg.h" |
| |
| namespace mxnet { |
| namespace op { |
| |
| template <typename DType> |
| inline DType sigmoid(DType x) { |
| return 1.0f / (1.0f + exp(-x)); |
| } |
| |
| template <typename DType> |
| inline DType relu(DType x) { |
| return x > static_cast<DType>(0) ? x : static_cast<DType>(0); // keep DType; the old cast to float lost precision for double |
| } |
| |
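| /*! |
| * \brief One direction of a single LSTM layer, forward pass in training mode. |
| * |
| * For each step t (gate order i, f, g, o): |
| * i_t = sigmoid(x_t * Wxi^T + h_{t-1} * Whi^T + bxi + bhi) |
| * f_t = sigmoid(x_t * Wxf^T + h_{t-1} * Whf^T + bxf + bhf) |
| * g_t = tanh(x_t * Wxg^T + h_{t-1} * Whg^T + bxg + bhg) |
| * o_t = sigmoid(x_t * Wxo^T + h_{t-1} * Who^T + bxo + bho) |
| * c_t = f_t * c_{t-1} + i_t * g_t |
| * h_t = o_t * tanh(c_t) |
| * |
| * T: sequence length, N: batch size, I: input size, H: hidden size. |
| * ws is scratch workspace; h_t, c_t and the four gates are saved to the |
| * reserved space rs for the backward pass. bid selects the reverse |
| * direction of a bidirectional layer. |
| */ |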
| template <typename DType> |
| void LstmForwardTrainingSingleLayer(DType* ws, |
| DType* rs, |
| bool state_outputs, |
| bool bid, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| const Tensor<cpu, 2, DType>& x, |
| const Tensor<cpu, 2, DType>& hx, |
| const Tensor<cpu, 2, DType>& cx, |
| const Tensor<cpu, 3, DType>& y, |
| DType* w_ptr, |
| DType* b_ptr, |
| DType* hy_ptr, |
| DType* cy_ptr) { |
| using namespace mshadow; |
| const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I)); |
| const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H)); |
| const Tensor<cpu, 2, DType> bx(b_ptr, Shape2(4, H)); |
| const Tensor<cpu, 2, DType> bh(b_ptr + H * 4, Shape2(4, H)); |
| const Tensor<cpu, 2, DType> yx_flat(ws, Shape2(T * N, 4 * H)); |
| const Tensor<cpu, 2, DType> yh_flat(ws + T * N * H * 4, Shape2(N, 4 * H)); |
| const Tensor<cpu, 4, DType> yx(yx_flat.dptr_, Shape4(T, N, 4, H)); |
| const Tensor<cpu, 3, DType> yh(yh_flat.dptr_, Shape3(N, 4, H)); |
| Tensor<cpu, 2, DType> h(yh_flat.dptr_ + N * H * 4, Shape2(N, H)); |
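| // Layout of the reserved space rs for one layer: forward c [T, N, H], |
| // forward ifgo [T, N, H, 4], the layer output y at offset T * N * H * 5, |
| // and, when bidirectional, the reverse direction's c and ifgo starting at |
| // offset T * N * H * 7. |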
| DType* c_ptr = bid ? rs + T * N * H * 7 : rs; |
| Tensor<cpu, 3, DType> c(c_ptr, Shape3(T, N, H)); |
| Tensor<cpu, 4, DType> ifgo(c_ptr + T * N * H, Shape4(T, N, H, 4)); |
| |
| const int offset = bid ? H : 0; |
| const DType alpha = 1.0; |
| const DType beta = 0.0; |
| const index_t cell_size = N * H; |
| linalg_gemm(x, wx, yx_flat, alpha, beta, false, true); |
| |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| for (index_t i = 0; i < T; ++i) { |
| index_t t = bid ? T - 1 - i : i; |
| linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true); |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t jk = 0; jk < cell_size; ++jk) { |
| index_t j = jk / H; |
| index_t k = jk % H; |
| DType it = sigmoid<DType>(yx[t][j][0][k] + yh[j][0][k] + bx[0][k] + bh[0][k]); |
| DType ft = sigmoid<DType>(yx[t][j][1][k] + yh[j][1][k] + bx[1][k] + bh[1][k]); |
| DType gt = tanh(yx[t][j][2][k] + yh[j][2][k] + bx[2][k] + bh[2][k]); |
| DType ot = sigmoid<DType>(yx[t][j][3][k] + yh[j][3][k] + bx[3][k] + bh[3][k]); |
| DType ct = (i ? c[i - 1][j][k] : cx[j][k]) * ft + it * gt; |
| DType ht = ot * tanh(ct); |
| h[j][k] = ht; |
| // save this step's output, cell state and gates to the reserved space for the backward pass |
| y[t][j][k + offset] = ht; |
| c[i][j][k] = ct; |
| ifgo[i][j][k][0] = it; |
| ifgo[i][j][k][1] = ft; |
| ifgo[i][j][k][2] = gt; |
| ifgo[i][j][k][3] = ot; |
| if (i == T - 1 && state_outputs) { |
| hy_ptr[jk] = ht; |
| cy_ptr[jk] = ct; |
| } |
| } |
| } |
| } |
| |
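| /*! |
| * \brief Forward pass of an L-layer (optionally bidirectional, D == 2) LSTM |
| * in training mode. |
| * |
| * Calls LstmForwardTrainingSingleLayer once per direction of each layer. |
| * Between layers, inverted dropout is applied when dropout > 0; the mask is |
| * saved in the reserved space so the backward pass can replay it. The last |
| * layer's output is copied to y_ptr. |
| */ |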
| template <typename DType> |
| void LstmForwardTraining(DType* ws, |
| DType* rs, |
| bool state_outputs, |
| const int L, |
| const int D, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| DType* x_ptr, |
| DType* hx_ptr, |
| DType* cx_ptr, |
| DType* w_ptr, |
| DType* b_ptr, |
| DType* y_ptr, |
| DType* hy_ptr, |
| DType* cy_ptr, |
| const float dropout, |
| std::mt19937& rnd_engine) { // NOLINT(runtime/references) |
| DType* dropout_random = rs; |
| DType* rs2 = dropout_random + (L - 1) * D * T * N * H; |
| const int total_layers = D * L; |
| Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H)); |
| Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H)); |
| const index_t b_size = 2 * H * 4; |
| const index_t r_size = D * T * N * H * 6; |
| const index_t y_offset = T * N * H * 5; |
| const index_t cell_size = N * H; |
| int idx = 0; // index into the hidden-state / cell-state tensors |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| for (int i = 0; i < L; ++i) { |
| const index_t input_size = i ? H * D : I; |
| const index_t w_size = (input_size + H) * H * 4; |
| Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, input_size)); |
| Tensor<cpu, 3, DType> y(rs2 + y_offset, Shape3(T, N, H * D)); |
| LstmForwardTrainingSingleLayer<DType>(ws, |
| rs2, |
| state_outputs, |
| false, |
| T, |
| N, |
| input_size, |
| H, |
| x, |
| hx[idx], |
| cx[idx], |
| y, |
| w_ptr, |
| b_ptr, |
| hy_ptr, |
| cy_ptr); |
| if (D == 2) { |
| w_ptr += w_size; |
| b_ptr += b_size; |
| ++idx; |
| if (state_outputs) { |
| hy_ptr += cell_size; |
| cy_ptr += cell_size; |
| } |
| LstmForwardTrainingSingleLayer<DType>(ws, |
| rs2, |
| state_outputs, |
| true, |
| T, |
| N, |
| input_size, |
| H, |
| x, |
| hx[idx], |
| cx[idx], |
| y, |
| w_ptr, |
| b_ptr, |
| hy_ptr, |
| cy_ptr); |
| } |
| if (i != L - 1) { |
| w_ptr += w_size; |
| b_ptr += b_size; |
| if (dropout > 0.0f) { |
| std::uniform_real_distribution<float> distribution(0, 1); |
| for (index_t j = 0; j < T * N * H * D; j++) { |
| if (distribution(rnd_engine) < dropout) { |
| dropout_random[i * T * N * H * D + j] = 0; |
| y.dptr_[j] = 0; |
| } else { |
| dropout_random[i * T * N * H * D + j] = 1.0f - dropout; |
| y.dptr_[j] = y.dptr_[j] / (1.0f - dropout); |
| } |
| } |
| } |
| x_ptr = y.dptr_; |
| rs2 += r_size; |
| ++idx; |
| if (state_outputs) { |
| hy_ptr += cell_size; |
| cy_ptr += cell_size; |
| } |
| } |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < T * N * H * D; ++i) { |
| y_ptr[i] = (rs2 + y_offset)[i]; |
| } |
| } |
| |
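| /*! |
| * \brief One direction of a single LSTM layer, forward pass in inference |
| * mode. |
| * |
| * Same cell computation as the training variant, but only the running h and |
| * c states are kept in workspace instead of recording every step. When the |
| * projection size P > 0 (projected LSTM), the hidden state is multiplied by |
| * the projection weight whr before being written to y. |
| */ |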
| template <typename DType> |
| void LstmForwardInferenceSingleLayer(DType* ws, |
| bool state_outputs, |
| bool bid, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| const int P, |
| const Tensor<cpu, 2, DType>& x, |
| const Tensor<cpu, 2, DType>& hx, |
| const Tensor<cpu, 2, DType>& cx, |
| const Tensor<cpu, 3, DType>& y, |
| DType* w_ptr, |
| DType* b_ptr, |
| DType* hy_ptr, |
| DType* cy_ptr) { |
| using namespace mshadow; |
| const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I)); |
| const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, (P ? P : H))); |
| Tensor<cpu, 2, DType> whr(w_ptr, Shape2(1, 1)); |
| if (P > 0) |
| whr = Tensor<cpu, 2, DType>(wh.dptr_ + P * 4 * H, Shape2(P, H)); |
| const Tensor<cpu, 2, DType> bx(b_ptr, Shape2(4, H)); |
| const Tensor<cpu, 2, DType> bh(b_ptr + H * 4, Shape2(4, H)); |
| Tensor<cpu, 2, DType> yx_flat(ws, Shape2(T * N, H * 4)); |
| Tensor<cpu, 2, DType> yh_flat(ws + T * N * H * 4, Shape2(N, H * 4)); |
| const Tensor<cpu, 4, DType> yx(yx_flat.dptr_, Shape4(T, N, 4, H)); |
| const Tensor<cpu, 3, DType> yh(yh_flat.dptr_, Shape3(N, 4, H)); |
| Tensor<cpu, 2, DType> h(yh_flat.dptr_ + N * H * 4, Shape2(N, H)); |
| Tensor<cpu, 2, DType> c(h.dptr_ + N * H, Shape2(N, H)); |
| Tensor<cpu, 2, DType> r(hy_ptr, Shape2(1, 1)); |
| if (P > 0) |
| r = Tensor<cpu, 2, DType>(hy_ptr, Shape2(N, P)); |
| const int offset = bid ? H : 0; |
| const int proj_offset = bid ? P : 0; |
| const DType alpha = 1.0; |
| const DType beta = 0.0; |
| const index_t cell_size = N * H; |
| linalg_gemm(x, wx, yx_flat, alpha, beta, false, true); |
| |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| for (index_t i = 0; i < T; ++i) { |
| index_t t = bid ? T - 1 - i : i; |
| if (P > 0) { |
| linalg_gemm(i ? r : hx, wh, yh_flat, alpha, beta, false, true); |
| } else { |
| linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true); |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t jk = 0; jk < cell_size; ++jk) { |
| int j = jk / H; |
| int k = jk % H; |
| DType it = sigmoid<DType>(yx[t][j][0][k] + yh[j][0][k] + bx[0][k] + bh[0][k]); |
| DType ft = sigmoid<DType>(yx[t][j][1][k] + yh[j][1][k] + bx[1][k] + bh[1][k]); |
| DType gt = tanh(yx[t][j][2][k] + yh[j][2][k] + bx[2][k] + bh[2][k]); |
| DType ot = sigmoid<DType>(yx[t][j][3][k] + yh[j][3][k] + bx[3][k] + bh[3][k]); |
| DType ct = (i ? c[j][k] : cx[j][k]) * ft + it * gt; |
| DType ht = ot * tanh(ct); |
| if (P == 0) |
| y[t][j][k + offset] = ht; |
| if (i == T - 1 && state_outputs) { |
| if (P == 0) |
| hy_ptr[jk] = ht; |
| cy_ptr[jk] = ct; |
| } else { |
| c[j][k] = ct; |
| } |
| h[j][k] = ht; |
| } |
| if (P > 0) { |
| linalg_gemm(h, whr, r, alpha, beta, false, true); |
| #pragma GCC diagnostic push |
| #if __GNUC__ >= 8 |
| #pragma GCC diagnostic ignored "-Wclass-memaccess" |
| #endif |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int j = 0; j < N; ++j) { |
| std::memcpy(y[t][j].dptr_ + proj_offset, r[j].dptr_, P * sizeof(DType)); |
| } |
| #pragma GCC diagnostic pop |
| } |
| } |
| } |
| |
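| /*! |
| * \brief Forward pass of an L-layer (optionally bidirectional) LSTM in |
| * inference mode, with optional projection size P. When bidirectional, |
| * layer outputs alternate between y_ptr and a temporary buffer so that the |
| * last layer always writes to y_ptr. |
| */ |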
| template <typename DType> |
| void LstmForwardInference(DType* ws, |
| bool state_outputs, |
| const int L, |
| const int D, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| const int P, |
| DType* x_ptr, |
| DType* hx_ptr, |
| DType* cx_ptr, |
| DType* w_ptr, |
| DType* b_ptr, |
| DType* y_ptr, |
| DType* hy_ptr, |
| DType* cy_ptr) { |
| const int total_layers = D * L; |
| Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, P ? P : H)); |
| Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H)); |
| const index_t b_size = 2 * H * 4; |
| const index_t cell_size = N * H; |
| const index_t projection_size = (P ? P : H) * N; |
| DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 2; |
| DType* y_cur_ptr = y_ptr; |
| int idx = 0; // index into the hidden-state / cell-state tensors |
| bool flag = (L % 2 == 0); // true: write this layer's output to the tmp buffer |
| for (int i = 0; i < L; ++i) { |
| const index_t input_size = i ? (P ? P : H) * D : I; |
| index_t w_size = (input_size + (P ? P : H)) * H * 4; |
| if (P > 0) { |
| w_size += P * H; |
| } |
| // If bidirectional, need space to save current layer output y. |
| if (D == 2) { |
| y_cur_ptr = flag ? y_tmp_ptr : y_ptr; |
| flag = !flag; |
| } |
| Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, input_size)); |
| Tensor<cpu, 3, DType> y(y_cur_ptr, Shape3(T, N, (P ? P : H) * D)); |
| LstmForwardInferenceSingleLayer<DType>(ws, |
| state_outputs, |
| false, |
| T, |
| N, |
| input_size, |
| H, |
| P, |
| x, |
| hx[idx], |
| cx[idx], |
| y, |
| w_ptr, |
| b_ptr, |
| hy_ptr, |
| cy_ptr); |
| // If bidirectional, then calculate the reverse direction's forward result. |
| if (D == 2) { |
| w_ptr += w_size; |
| b_ptr += b_size; |
| ++idx; |
| if (state_outputs) { |
| hy_ptr += projection_size; |
| cy_ptr += cell_size; |
| } |
| LstmForwardInferenceSingleLayer<DType>(ws, |
| state_outputs, |
| true, |
| T, |
| N, |
| input_size, |
| H, |
| P, |
| x, |
| hx[idx], |
| cx[idx], |
| y, |
| w_ptr, |
| b_ptr, |
| hy_ptr, |
| cy_ptr); |
| } |
| // No need to advance the pointers after the last layer. |
| if (i != L - 1) { |
| w_ptr += w_size; |
| b_ptr += b_size; |
| x_ptr = y_cur_ptr; |
| ++idx; |
| if (state_outputs) { |
| hy_ptr += projection_size; |
| cy_ptr += cell_size; |
| } |
| } |
| } |
| } |
| |
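| /*! |
| * \brief One direction of a single LSTM layer, backward pass. |
| * |
| * Consumes the states and gates recorded in the reserved space rs by |
| * LstmForwardTrainingSingleLayer and produces dx, dhx, dcx, dw and db, |
| * honouring the gradient request types (kNullOp / kWriteTo / kAddTo) per |
| * output. |
| */ |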
| template <typename DType> |
| void LstmBackwardSingleLayer(DType* ws, |
| DType* rs, |
| DType* tmp_buf, |
| bool bid, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| const Tensor<cpu, 2, DType>& x, |
| const Tensor<cpu, 2, DType>& hx, |
| const Tensor<cpu, 2, DType>& cx, |
| const Tensor<cpu, 3, DType>& y, |
| const Tensor<cpu, 3, DType>& dy, |
| const Tensor<cpu, 2, DType>& dx, |
| const Tensor<cpu, 2, DType>& dhx, |
| const Tensor<cpu, 2, DType>& dcx, |
| DType* dhy_ptr, |
| DType* dcy_ptr, |
| DType* w_ptr, |
| DType* dw_ptr, |
| DType* db_ptr, |
| int req_data, |
| int req_params, |
| int req_state, |
| int req_statecell) { |
| using namespace mshadow; |
| const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I)); |
| const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H)); |
| Tensor<cpu, 2, DType> dwx(dw_ptr, Shape2(H * 4, I)); |
| Tensor<cpu, 2, DType> dwh(dw_ptr + I * H * 4, Shape2(H * 4, H)); |
| Tensor<cpu, 1, DType> dbx(db_ptr, Shape1(H * 4)); |
| Tensor<cpu, 1, DType> dbh(dbx.dptr_ + H * 4, Shape1(H * 4)); |
| DType* c_ptr = bid ? rs + T * N * H * 7 : rs; |
| const Tensor<cpu, 3, DType> c(c_ptr, Shape3(T, N, H)); |
| const Tensor<cpu, 4, DType> ifgo(c_ptr + T * N * H, Shape4(T, N, H, 4)); |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| if (req_params != kNullOp && req_params != kAddTo) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < H * 4 * H; ++i) { |
| dwh.dptr_[i] = 0; |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < 4 * H; ++i) { |
| dbx.dptr_[i] = 0; |
| dbh.dptr_[i] = 0; |
| } |
| } |
| Tensor<cpu, 4, DType> difgo(ws, Shape4(T, N, 4, H)); |
| Tensor<cpu, 2, DType> dh(ws + T * N * H * 4, Shape2(N, H)); |
| Tensor<cpu, 2, DType> dc(dh.dptr_ + N * H, Shape2(N, H)); |
| Tensor<cpu, 2, DType> htmp(dc.dptr_ + N * H, Shape2(N, H)); |
| const int offset = bid ? H : 0; |
| const DType alpha = 1.0; |
| const DType beta0 = 0.0; |
| const DType beta1 = 1.0; |
| const DType beta2 = 2.0; |
| const index_t cell_size = N * H; |
| if (dhy_ptr != nullptr) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < cell_size; ++i) { |
| dh.dptr_[i] = dhy_ptr[i]; |
| } |
| } else { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < cell_size; ++i) { |
| dh.dptr_[i] = 0; |
| } |
| } |
| if (dcy_ptr != nullptr) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < cell_size; ++i) { |
| dc.dptr_[i] = dcy_ptr[i]; |
| } |
| } else { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < cell_size; ++i) { |
| dc.dptr_[i] = 0; |
| } |
| } |
| |
| for (index_t i = T - 1; i >= 0; --i) { |
| index_t t = bid ? T - 1 - i : i; |
| index_t tnext = bid ? t + 1 : t - 1; |
| const Tensor<cpu, 2, DType>& dhnext = i ? dh : dhx; |
| const Tensor<cpu, 2, DType>& dcnext = i ? dc : dcx; |
| const Tensor<cpu, 2, DType>& hnext = i ? htmp : hx; |
| const Tensor<cpu, 2, DType>& cnext = i ? c[i - 1] : cx; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t jk = 0; jk < cell_size; ++jk) { |
| index_t j = jk / H; |
| index_t k = jk % H; |
| DType tc = tanh(c[i][j][k]); |
| DType it = ifgo[i][j][k][0]; |
| DType ft = ifgo[i][j][k][1]; |
| DType gt = ifgo[i][j][k][2]; |
| DType ot = ifgo[i][j][k][3]; |
| dh[j][k] += dy[t][j][k + offset]; |
| dc[j][k] += dh[j][k] * ot * (1 - tc * tc); |
| difgo[t][j][0][k] = dc[j][k] * gt * it * (1 - it); |
| difgo[t][j][1][k] = dc[j][k] * cnext[j][k] * ft * (1 - ft); |
| difgo[t][j][2][k] = dc[j][k] * it * (1 - gt * gt); |
| difgo[t][j][3][k] = dh[j][k] * tc * ot * (1 - ot); |
| if (req_statecell != kNullOp || i > 0) { |
| dcnext[j][k] = dc[j][k] * ft; |
| } |
| if (i) { |
| htmp[j][k] = y[tnext][j][k + offset]; |
| } |
| } |
| Tensor<cpu, 2, DType> dyh(difgo[t].dptr_, Shape2(N, H * 4)); |
| if (req_state != kNullOp || i > 0) { |
| linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false); |
| } |
| if (req_params != kNullOp) { |
| if (req_params != kAddTo) { |
| linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false); |
| } else { |
| linalg_gemm(dyh, hnext, dwh, alpha, beta2, true, false); |
| |
| // accumulate dwx at every time step for AddTo |
| Tensor<cpu, 2, DType> x_t(x.dptr_ + i * N * I, Shape2(N, I)); |
| Tensor<cpu, 2, DType> dyx_t(difgo.dptr_ + i * N * H * 4, Shape2(N, H * 4)); |
| linalg_gemm(dyx_t, x_t, dwx, alpha, beta2, true, false); |
| } |
| } |
| } |
| Tensor<cpu, 2, DType> dyx(difgo.dptr_, Shape2(T * N, H * 4)); |
| if (req_data != kNullOp) { |
| linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false); |
| } |
| if (req_params != kNullOp && req_params != kAddTo) { |
| linalg_gemm(dyx, x, dwx, alpha, beta0, true, false); |
| } |
| const index_t row = T * N; |
| const index_t col = H * 4; |
| if (req_params != kNullOp) { |
| if (req_params != kAddTo) { |
| for (index_t i = 0; i < row; ++i) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t j = 0; j < col; ++j) { |
| dbx[j] += dyx[i][j]; |
| dbh[j] = dbx[j]; |
| } |
| } |
| } else { |
| const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf, Shape2(col, T)); |
| const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + col * T, Shape2(col, T)); |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < col * T; ++i) { |
| tmp_dbx.dptr_[i] = 0; |
| tmp_dbh.dptr_[i] = 0; |
| } |
| for (index_t t = T - 1; t >= 0; --t) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t j = 0; j < col; ++j) { |
| for (index_t i = 0; i < N; ++i) { |
| tmp_dbx[j][t] += dyx[t * N + i][j]; |
| tmp_dbh[j][t] = tmp_dbx[j][t]; |
| } |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t j = 0; j < col; ++j) { |
| dbx[j] += tmp_dbx[j][t] + dbx[j]; |
| dbh[j] += tmp_dbh[j][t] + dbh[j]; |
| } |
| } |
| } |
| } |
| } |
| |
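| /*! |
| * \brief Backward pass of an L-layer (optionally bidirectional) LSTM. |
| * |
| * Walks the layers from top to bottom, calling LstmBackwardSingleLayer once |
| * per direction, and replays the inter-layer dropout mask saved by the |
| * forward pass when dropout > 0. |
| */ |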
| template <typename DType> |
| void LstmBackward(DType* ws, |
| DType* rs, |
| const int L, |
| const int D, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| DType* x_ptr, |
| DType* hx_ptr, |
| DType* cx_ptr, |
| DType* w_ptr, |
| DType* y_ptr, |
| DType* dy_ptr, |
| DType* dhy_ptr, |
| DType* dcy_ptr, |
| DType* dx_ptr, |
| DType* dhx_ptr, |
| DType* dcx_ptr, |
| DType* dw_ptr, |
| DType* db_ptr, |
| int req_data, |
| int req_params, |
| int req_state, |
| int req_statecell, |
| const float dropout) { |
| DType* dropout_random = rs + (L - 1) * D * T * N * H; |
| DType* rs2 = rs + (L - 1) * D * T * N * H; |
| DType* tmp_buf = ws; |
| DType* ws2 = tmp_buf + 8 * T * H; |
| const int total_layers = D * L; |
| Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H)); |
| Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H)); |
| Tensor<cpu, 3, DType> dhx(dhx_ptr, Shape3(total_layers, N, H)); |
| Tensor<cpu, 3, DType> dcx(dcx_ptr, Shape3(total_layers, N, H)); |
| const index_t b_size = 2 * H * 4; |
| const index_t r_size = D * T * N * H * 6; |
| const index_t y_offset = T * N * H * 5; |
| const index_t w_size1 = (I + H) * H * 4; // first layer |
| const index_t w_size2 = (D * H + H) * H * 4; // other layers |
| const index_t cell_size = N * H; |
| const index_t y_size = T * N * H * D; |
| DType* dy_tmp_ptr = ws2 + T * cell_size * 4 + cell_size * 3; |
| for (int i = L - 1; i >= 0; --i) { |
| const index_t input_size = i ? H * D : I; |
| const index_t w_size = i ? w_size2 : w_size1; |
| int idx = i * D; |
| DType* w_cur_ptr = i ? w_ptr + (w_size1 + (i - 1) * w_size2) * D : w_ptr; |
| DType* dw_cur_ptr = i ? dw_ptr + (w_size1 + (i - 1) * w_size2) * D : dw_ptr; |
| DType* db_cur_ptr = db_ptr + i * b_size * D; |
| DType* rs_cur_ptr = rs2 + i * r_size; |
| DType* dhy_cur_ptr = dhy_ptr ? dhy_ptr + i * cell_size * D : nullptr; |
| DType* dcy_cur_ptr = dcy_ptr ? dcy_ptr + i * cell_size * D : nullptr; |
| Tensor<cpu, 3, DType> y(rs_cur_ptr + y_offset, Shape3(T, N, H * D)); |
| Tensor<cpu, 3, DType> dy(dy_ptr, Shape3(T, N, H * D)); |
| Tensor<cpu, 2, DType> x(i ? y.dptr_ - r_size : x_ptr, Shape2(T * N, input_size)); |
| Tensor<cpu, 2, DType> dx(i ? dy_tmp_ptr : dx_ptr, Shape2(T * N, input_size)); |
| LstmBackwardSingleLayer<DType>(ws2, |
| rs_cur_ptr, |
| tmp_buf, |
| false, |
| T, |
| N, |
| input_size, |
| H, |
| x, |
| hx[idx], |
| cx[idx], |
| y, |
| dy, |
| dx, |
| dhx[idx], |
| dcx[idx], |
| dhy_cur_ptr, |
| dcy_cur_ptr, |
| w_cur_ptr, |
| dw_cur_ptr, |
| db_cur_ptr, |
| req_data, |
| req_params, |
| req_state, |
| req_statecell); |
| if (D == 2) { |
| w_cur_ptr += w_size; |
| dw_cur_ptr += w_size; |
| db_cur_ptr += b_size; |
| ++idx; |
| dhy_cur_ptr = dhy_ptr ? dhy_cur_ptr + cell_size : nullptr; |
| dcy_cur_ptr = dcy_ptr ? dcy_cur_ptr + cell_size : nullptr; |
| LstmBackwardSingleLayer<DType>(ws2, |
| rs_cur_ptr, |
| tmp_buf, |
| true, |
| T, |
| N, |
| input_size, |
| H, |
| x, |
| hx[idx], |
| cx[idx], |
| y, |
| dy, |
| dx, |
| dhx[idx], |
| dcx[idx], |
| dhy_cur_ptr, |
| dcy_cur_ptr, |
| w_cur_ptr, |
| dw_cur_ptr, |
| db_cur_ptr, |
| req_data, |
| req_params, |
| req_state, |
| req_statecell); |
| |
| // Prevent overwriting dy while calculating dx in the left-to-right layer |
| const int loop_iteration = (L - 1) - i; |
| dy_tmp_ptr = loop_iteration % 2 ? dy_tmp_ptr - y_size : dy_tmp_ptr + y_size; |
| } |
| if (dropout > 0.0f && i > 0 && req_data != kNullOp) { |
| dropout_random = dropout_random - T * N * D * H; |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t j = 0; j < T * N * D * H; j++) { |
| if (dropout_random[j] == 0) { |
| dx.dptr_[j] = 0; |
| } else { |
| dx.dptr_[j] = dx.dptr_[j] / (1.0f - dropout); |
| } |
| } |
| } |
| dy_ptr = dx.dptr_; |
| } |
| } |
| |
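| /*! |
| * \brief Single GRU layer (both directions when D == 2), forward pass in |
| * inference mode. |
| * |
| * For each step t (cuDNN-style GRU: the reset gate multiplies the recurrent |
| * matmul result): |
| * r_t = sigmoid(x_t * Wxr^T + h_{t-1} * Whr^T + bxr + bhr) |
| * z_t = sigmoid(x_t * Wxz^T + h_{t-1} * Whz^T + bxz + bhz) |
| * n_t = tanh(x_t * Wxn^T + bxn + r_t * (h_{t-1} * Whn^T + bhn)) |
| * h_t = (1 - z_t) * n_t + z_t * h_{t-1} |
| */ |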
| template <typename DType> |
| void GruForwardInferenceSingleLayer(DType* ws, |
| DType* tmp_buf, |
| bool state_outputs, |
| const int D, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| const Tensor<cpu, 2, DType>& x, |
| const Tensor<cpu, 2, DType>& hx, |
| DType* wx_ptr, |
| DType* wh_ptr, |
| DType* bx_ptr, |
| DType* bh_ptr, |
| DType* y_ptr, |
| DType* hy_ptr) { |
| DType* ht = y_ptr; |
| DType* ht_1 = y_ptr; |
| DType* back_ht_1 = y_ptr + (T - 1) * N * H * D + H; |
| DType* back_ht = back_ht_1; |
| DType* gemmC1 = ws; // [D, T, N, 3 * H] |
| DType* gemmC2 = gemmC1 + D * T * N * 3 * H; // N * 3 * H |
| DType* rt = gemmC2 + N * 3 * H; |
| DType* zt = rt + N * H; |
| DType* nt = zt + N * H; |
| DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H; |
| DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H; |
| DType* back_bx_ptr = (bx_ptr != nullptr) ? bx_ptr + 3 * H * 2 : nullptr; |
| DType* back_bh_ptr = (bh_ptr != nullptr) ? bh_ptr + 3 * H * 2 : nullptr; |
| DType* back_gemmC1 = gemmC1 + T * N * 3 * H; |
| DType* gemmC1_t = gemmC1; |
| |
| const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I)); |
| const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H)); |
| const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(3, H)); |
| const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(3, H)); |
| const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I)); |
| const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H)); |
| const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(3, H)); |
| const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(3, H)); |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| if (D == 1) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| y_ptr[i * H + j] = hx[i][j]; |
| } |
| } else { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| y_ptr[i * D * H + j] = hx[i][j]; |
| back_ht_1[i * D * H + j] = hx[N + i][j]; |
| } |
| } |
| Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, 3 * H)); |
| Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, 3 * H)); |
| Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, 3 * H)); |
| |
| // x * wx.T : [T * N, I] * [I, 3 * H] |
| DType alpha = 1.0; |
| DType beta = 0.0; |
| linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true); |
| if (D == 2) { |
| linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true); |
| } |
| |
| for (index_t t = 0; t < T; t++) { |
| // first (forward) direction; x * wx was precomputed above, so only |
| // ht-1 * wh is computed per step: ht-1: [N, H], wh: [3 * H, H] |
| Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H)); |
| if (D == 1) { |
| linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true); |
| } else { |
| Tensor<cpu, 3, DType> dht_1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N)); |
| linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true); |
| } |
| gemmC1_t = gemmC1 + t * N * 3 * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t rtb = i * 3 * H; |
| index_t ztb = i * 3 * H + H; |
| index_t ntb = i * 3 * H + 2 * H; |
| rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j] + bx[0][j] + bh[0][j]); |
| zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j] + bx[1][j] + bh[1][j]); |
| nt[i * H + j] = |
| tanh(gemmC1_t[ntb + j] + bx[2][j] + rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j])); |
| ht[i * D * H + j] = |
| (1 - zt[i * H + j]) * nt[i * H + j] + zt[i * H + j] * ht_1[i * D * H + j]; |
| } |
| } |
| ht_1 = ht; |
| ht = ht + D * H * N; |
| // second (reverse) direction |
| if (D == 2) { |
| gemmC1_t = back_gemmC1 + (T - 1 - t) * N * 3 * H; |
| Tensor<cpu, 2, DType> dback_ht_1(back_ht_1 - H, Shape2(N, D * H)); |
| Tensor<cpu, 3, DType> dback_ht_1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N)); |
| linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true); |
| |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t rtb = i * 3 * H; |
| index_t ztb = i * 3 * H + H; |
| index_t ntb = i * 3 * H + 2 * H; |
| rt[i * H + j] = |
| sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]); |
| zt[i * H + j] = |
| sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j] + back_bx[1][j] + back_bh[1][j]); |
| nt[i * H + j] = tanh(gemmC1_t[ntb + j] + back_bx[2][j] + |
| rt[i * H + j] * (gemmC2[ntb + j] + back_bh[2][j])); |
| back_ht[i * D * H + j] = |
| (1 - zt[i * H + j]) * nt[i * H + j] + zt[i * H + j] * back_ht_1[i * D * H + j]; |
| } |
| } |
| back_ht_1 = back_ht; |
| back_ht = back_ht - D * H * N; |
| } |
| } |
| // copy the last state to hy, from (N, H * D) to (D, N, H) |
| if (state_outputs) { |
| if (D == 1) { |
| DType* y_start = y_ptr + (T - 1) * N * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| hy_ptr[i * H + j] = y_start[i * H + j]; |
| } |
| } else { |
| DType* y_start = y_ptr + (T - 1) * N * H * D; |
| DType* y_back_start = y_ptr + H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| hy_ptr[i * H + j] = y_start[i * D * H + j]; |
| hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j]; |
| } |
| } |
| } |
| } |
| |
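| /*! |
| * \brief Forward pass of an L-layer (optionally bidirectional) GRU in |
| * inference mode. |
| */ |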
| template <typename DType> |
| void GruForwardInference(DType* ws, |
| bool state_outputs, |
| const int L, |
| const int D, |
| const index_t T, |
| const index_t N, |
| index_t I, |
| const int H, |
| DType* x_ptr, |
| DType* hx_ptr, |
| DType* w_ptr, |
| DType* y_ptr, |
| DType* hy_ptr) { |
| DType* wx = w_ptr; |
| DType* wh = wx + I * H * 3; |
| DType* bx = |
| wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3) + (L - 1) * ((D + 1) * H) * H * 3 * D; |
| DType* bh = bx + H * 3; |
| |
| DType* y_tmp = ws; |
| DType* y_l = x_ptr; |
| DType* tmp_buf = y_tmp + D * T * N * H; |
| DType* ws2 = y_tmp + D * T * N * H + D * H * N; |
| |
| DType* wx_l = wx; |
| DType* wh_l = wh; |
| DType* bx_l = bx; |
| DType* bh_l = bh; |
| Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H)); |
| DType* hy_l = hy_ptr; |
| for (int l = 0; l < L; l++) { |
| Tensor<cpu, 2, DType> x_l(y_l, Shape2(T * N, I)); |
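| // Alternate the output buffer between y_tmp and y_ptr so that the last |
| // layer (l == L - 1) writes its output to y_ptr. |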
| if ((L + l) % 2) { |
| y_l = y_ptr; |
| } else { |
| y_l = y_tmp; |
| } |
| Tensor<cpu, 2, DType> hx_l = hx[D * l]; |
| GruForwardInferenceSingleLayer<DType>( |
| ws2, tmp_buf, state_outputs, D, T, N, I, H, x_l, hx_l, wx_l, wh_l, bx_l, bh_l, y_l, hy_l); |
| hy_l = hy_l + D * N * H; |
| bx_l = bx_l + 3 * H * D * 2; |
| bh_l = bh_l + 3 * H * D * 2; |
| wx_l = wx_l + I * H * 3 * D + H * H * 3 * D; |
| if (l == 0) { |
| I = D * H; |
| } |
| wh_l = wx_l + I * 3 * H; |
| } |
| } |
| |
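| /*! |
| * \brief Single GRU layer (both directions when D == 2), forward pass in |
| * training mode. |
| * |
| * Same cell computation as the inference variant, but the r/z/n gate |
| * activations and the pre-activation term Mnh = h_{t-1} * Whn^T + bhn are |
| * recorded for the backward pass. |
| */ |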
| template <typename DType> |
| void GruForwardTrainingSingleLayer(DType* ws, |
| DType* tmp_buf, |
| bool state_outputs, |
| const int D, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| const Tensor<cpu, 2, DType>& x, |
| const Tensor<cpu, 2, DType>& hx, |
| DType* wx_ptr, |
| DType* wh_ptr, |
| DType* bx_ptr, |
| DType* bh_ptr, |
| DType* gateR, |
| DType* gateZ, |
| DType* gateN, |
| DType* Mnh, |
| DType* y_ptr, |
| DType* hy_ptr) { |
| DType* ht = y_ptr; |
| DType* ht_1 = y_ptr; |
| DType* back_ht_1 = y_ptr + (T - 1) * N * H * D + H; |
| DType* back_ht = back_ht_1; |
| |
| DType* gemmC1 = ws; // [D, T, N, 3 * H] |
| DType* gemmC2 = gemmC1 + D * T * N * 3 * H; // N * 3 * H |
| DType* rt = gateR; |
| DType* zt = gateZ; |
| DType* nt = gateN; |
| DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H; |
| DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H; |
| DType* back_bx_ptr = (bx_ptr != nullptr) ? bx_ptr + 3 * H * 2 : nullptr; |
| DType* back_bh_ptr = (bh_ptr != nullptr) ? bh_ptr + 3 * H * 2 : nullptr; |
| DType* back_gateR = gateR + T * N * H; |
| DType* back_gateZ = gateZ + T * N * H; |
| DType* back_gateN = gateN + T * N * H; |
| DType* back_Mnh = Mnh + T * N * H; |
| DType* back_gemmC1 = gemmC1 + T * N * 3 * H; |
| DType* gemmC1_t = gemmC1; |
| |
| const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I)); |
| const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H)); |
| const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(3, H)); |
| const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(3, H)); |
| const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I)); |
| const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H)); |
| const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(3, H)); |
| const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(3, H)); |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| if (D == 1) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| y_ptr[i * H + j] = hx[i][j]; |
| } |
| } else { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| y_ptr[i * D * H + j] = hx[i][j]; |
| back_ht_1[i * D * H + j] = hx[N + i][j]; |
| } |
| } |
| |
| Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, 3 * H)); |
| Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, 3 * H)); |
| Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, 3 * H)); |
| |
| // x * wx.T : [T * N, I] * [I, 3 * H] |
| DType alpha = 1.0; |
| DType beta = 0.0; |
| linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true); |
| if (D == 2) { |
| linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true); |
| } |
| |
| for (index_t t = 0; t < T; t++) { |
| // first (forward) direction; x * wx was precomputed above, so only |
| // ht-1 * wh is computed per step: ht-1: [N, H], wh: [3 * H, H] |
| Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H)); |
| if (D == 1) { |
| linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true); |
| } else { |
| Tensor<cpu, 3, DType> dht_1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N)); |
| linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true); |
| } |
| rt = gateR + t * N * H; |
| zt = gateZ + t * N * H; |
| nt = gateN + t * N * H; |
| gemmC1_t = gemmC1 + t * N * 3 * H; |
| DType* Mnht = Mnh + t * N * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t rtb = i * 3 * H; |
| index_t ztb = i * 3 * H + H; |
| index_t ntb = i * 3 * H + 2 * H; |
| Mnht[i * H + j] = gemmC2[ntb + j] + bh[2][j]; |
| rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j] + bx[0][j] + bh[0][j]); |
| zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j] + bx[1][j] + bh[1][j]); |
| nt[i * H + j] = |
| tanh(gemmC1_t[ntb + j] + bx[2][j] + rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j])); |
| ht[i * D * H + j] = |
| (1 - zt[i * H + j]) * nt[i * H + j] + zt[i * H + j] * ht_1[i * D * H + j]; |
| } |
| } |
| ht_1 = ht; |
| ht = ht + D * H * N; |
| // second (reverse) direction |
| if (D == 2) { |
| rt = back_gateR + (T - 1 - t) * N * H; |
| zt = back_gateZ + (T - 1 - t) * N * H; |
| nt = back_gateN + (T - 1 - t) * N * H; |
| gemmC1_t = back_gemmC1 + (T - 1 - t) * N * 3 * H; |
| Tensor<cpu, 2, DType> dback_ht_1(back_ht_1 - H, Shape2(N, D * H)); |
| Tensor<cpu, 3, DType> dback_ht_1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N)); |
| linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true); |
| |
| DType* back_Mnht = back_Mnh + (T - 1 - t) * N * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t rtb = i * 3 * H; |
| index_t ztb = i * 3 * H + H; |
| index_t ntb = i * 3 * H + 2 * H; |
| back_Mnht[i * H + j] = gemmC2[ntb + j] + back_bh[2][j]; |
| rt[i * H + j] = |
| sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]); |
| zt[i * H + j] = |
| sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j] + back_bx[1][j] + back_bh[1][j]); |
| nt[i * H + j] = tanh(gemmC1_t[ntb + j] + back_bx[2][j] + |
| rt[i * H + j] * (gemmC2[ntb + j] + back_bh[2][j])); |
| back_ht[i * D * H + j] = |
| (1 - zt[i * H + j]) * nt[i * H + j] + zt[i * H + j] * back_ht_1[i * D * H + j]; |
| } |
| } |
| back_ht_1 = back_ht; |
| back_ht = back_ht - D * H * N; |
| } |
| } |
| |
| // copy the last state to hy, from (N, H * D) to (D, N, H) |
| if (state_outputs) { |
| if (D == 1) { |
| DType* y_start = y_ptr + (T - 1) * N * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| hy_ptr[i * H + j] = y_start[i * H + j]; |
| } |
| } else { |
| DType* y_start = y_ptr + (T - 1) * N * H * D; |
| DType* y_back_start = y_ptr + H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| hy_ptr[i * H + j] = y_start[i * D * H + j]; |
| hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j]; |
| } |
| } |
| } |
| } |
| |
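| /*! |
| * \brief Forward pass of an L-layer (optionally bidirectional) GRU in |
| * training mode, with inter-layer inverted dropout when dropout > 0. |
| */ |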
| template <typename DType> |
| void GruForwardTraining(DType* ws, |
| DType* rs, |
| bool state_outputs, |
| const int L, |
| const int D, |
| const index_t T, |
| const index_t N, |
| index_t I, |
| const int H, |
| DType* x_ptr, |
| DType* hx_ptr, |
| DType* w_ptr, |
| DType* y_ptr, |
| DType* hy_ptr, |
| const float dropout, |
| std::mt19937& rnd_engine) { // NOLINT(runtime/references) |
| DType* wx = w_ptr; |
| DType* wh = wx + I * H * 3; |
| DType* bx = |
| wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3) + (L - 1) * ((D + 1) * H) * H * 3 * D; |
| DType* bh = bx + H * 3; |
| Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H)); |
| DType* hy_l = hy_ptr; |
| DType* gateR_l = rs; |
| DType* gateZ_l = gateR_l + L * T * D * N * H; |
| DType* gateN_l = gateZ_l + L * T * D * N * H; |
| DType* y_l = gateN_l + L * T * D * N * H; |
| DType* Mnh_l = y_l + L * T * N * H * D; |
| DType* dropout_random = Mnh_l + L * D * T * N * H; |
| DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H; |
| DType* ws2 = tmp_buf + D * N * H; |
| DType* wx_l = wx; |
| DType* wh_l = wh; |
| DType* bx_l = bx; |
| DType* bh_l = bh; |
| DType* y_tmp = x_ptr; |
| for (int l = 0; l < L; l++) { |
| if (l != 0) { |
| y_tmp = y_l; |
| y_l = y_l + T * N * H * D; |
| } |
| if (dropout > 0.0f && l > 0) { |
| std::uniform_real_distribution<float> distribution(0, 1); |
| for (index_t i = 0; i < T * N * I; i++) { |
| if (distribution(rnd_engine) < dropout) { |
| dropout_random[(l - 1) * T * N * I + i] = 0; |
| y_tmp[i] = 0; |
| } else { |
| dropout_random[(l - 1) * T * N * I + i] = 1.0f - dropout; |
| y_tmp[i] = y_tmp[i] / (1.0f - dropout); |
| } |
| } |
| } |
| Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I)); |
| Tensor<cpu, 2, DType> hx_l = hx[D * l]; |
| GruForwardTrainingSingleLayer<DType>(ws2, |
| tmp_buf, |
| state_outputs, |
| D, |
| T, |
| N, |
| I, |
| H, |
| x_l, |
| hx_l, |
| wx_l, |
| wh_l, |
| bx_l, |
| bh_l, |
| gateR_l, |
| gateZ_l, |
| gateN_l, |
| Mnh_l, |
| y_l, |
| hy_l); |
| gateR_l = gateR_l + T * D * N * H; |
| gateZ_l = gateZ_l + T * D * N * H; |
| gateN_l = gateN_l + T * D * N * H; |
| Mnh_l = Mnh_l + T * D * N * H; |
| hy_l = hy_l + D * N * H; |
| bx_l = bx_l + 3 * H * D * 2; |
| bh_l = bh_l + 3 * H * D * 2; |
| |
| wx_l = wx_l + I * H * 3 * D + H * H * 3 * D; |
| if (l == 0) { |
| I = D * H; |
| } |
| wh_l = wx_l + I * 3 * H; |
| } |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < T * N * H * D; ++i) { |
| y_ptr[i] = y_l[i]; |
| } |
| } |
| |
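| /*! |
| * \brief Single GRU layer (both directions when D == 2), backward pass. |
| * |
| * Uses the gate activations and Mnh recorded by the forward training pass |
| * to compute dx, dhx, dwx, dwh, dbx and dbh, honouring the gradient request |
| * types (kNullOp / kWriteTo / kAddTo) per output. |
| */ |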
| template <typename DType> |
| void GruBackwardSingleLayer(DType* ws, |
| DType* tmp_buf, |
| const int D, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| const Tensor<cpu, 2, DType>& x, |
| const Tensor<cpu, 2, DType>& hx, |
| DType* wx_ptr, |
| DType* wh_ptr, |
| DType* y_ptr, |
| DType* dy_ptr, |
| DType* dhy_ptr, |
| DType* gateR, |
| DType* gateZ, |
| DType* gateN, |
| DType* Mnh, |
| DType* dx, |
| DType* dhx, |
| DType* dwx, |
| DType* dwh, |
| DType* dbx, |
| DType* dbh, |
| int req_data, |
| int req_params, |
| int req_state) { |
| DType* dyt; |
| DType* ht1; // [N, D, H] |
| DType* rt; |
| DType* zt; |
| DType* nt; |
| DType* dat; |
| DType* dart; |
| DType* dar = ws; // [T, N, 3 * H] |
| DType* da = dar + T * N * 3 * H; // [T, N, 3 * H] |
| DType* dht1 = da + T * N * 3 * H; // [D, N, H] |
| DType* hx_ = dht1 + D * N * H; // [N, D, H] |
| DType* Mnht = Mnh; |
| DType* back_ht1; |
| DType* back_dht1 = dht1 + N * H; // [N, H] |
| DType* back_Mnht = Mnh + T * N * H; |
| DType* back_gateR = gateR + T * N * H; |
| DType* back_gateZ = gateZ + T * N * H; |
| DType* back_gateN = gateN + T * N * H; |
| DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H; |
| DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H; |
| DType* back_dwx = dwx + I * 3 * H + H * 3 * H; |
| DType* back_dwh = dwh + I * 3 * H + H * 3 * H; |
| DType* back_dbx = dbx + 3 * H * 2; |
| DType* back_dbh = dbh + 3 * H * 2; |
| |
| DType alpha = 1.0; |
| DType beta = 0.0; |
| const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I)); |
| const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H)); |
| const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I)); |
| const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H)); |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| if (req_params != kNullOp && req_params != kAddTo) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < D * H * 3 * H; ++i) { |
| dwh[i] = 0; |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < D * 3 * H; ++i) { |
| dbx[i] = 0; |
| dbh[i] = 0; |
| } |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N * H; ++i) { |
| if (dhy_ptr) { |
| dht1[i] = dhy_ptr[i]; |
| } else { |
| dht1[i] = 0; |
| } |
| } |
| |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| hx_[i * D * H + j] = hx[i][j]; |
| } |
| } |
| |
| if (D == 2) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N * H; ++i) { |
| if (dhy_ptr) { |
| back_dht1[i] = dhy_ptr[N * H + i]; |
| } else { |
| back_dht1[i] = 0; |
| } |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| hx_[i * D * H + H + j] = hx[N + i][j]; |
| } |
| } |
| } |
| for (index_t t = T - 1; t >= 0; --t) { |
| if (t) { |
| ht1 = y_ptr + (t - 1) * N * D * H; |
| } else { |
| ht1 = hx_; |
| } |
| // add dy[T, N, D, H] to dhy[D, N, H] |
| dyt = dy_ptr + t * N * D * H; |
| |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| dht1[i * H + j] += dyt[i * D * H + j]; |
| } |
| } |
| |
| rt = gateR + t * N * H; |
| zt = gateZ + t * N * H; |
| nt = gateN + t * N * H; |
| Mnht = Mnh + t * N * H; |
| dat = da + t * N * 3 * H; |
| dart = dar + t * N * 3 * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| int nid = i * 3 * H + 2 * H + j; |
| int zid = i * 3 * H + H + j; |
| int rid = i * 3 * H + j; |
| int id = i * H + j; |
| dat[nid] = dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]); |
| dart[zid] = dat[zid] = dht1[id] * (ht1[i * D * H + j] - nt[id]) * zt[id] * (1 - zt[id]); |
| dart[rid] = dat[rid] = dat[nid] * Mnht[id] * rt[id] * (1 - rt[id]); |
| dart[nid] = dat[nid] * rt[id]; |
| dht1[id] = dht1[id] * zt[id]; |
| } |
| } |
| if (req_params != kNullOp) { |
| alpha = 1.0; |
| beta = 1.0; |
| // dht1 += dart * wh [N, H] = [N, 3 * H] * [3 * H, H] |
| Tensor<cpu, 2, DType> d_dht1(dht1, Shape2(N, H)); |
| Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H)); |
| linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false); |
| |
| if (req_params == kAddTo) { |
| beta = 2.0; |
| // dwx = da.T * x [3 * H, I] = [3 * H, N] * [N, I] for AddTo |
| Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I)); |
| Tensor<cpu, 2, DType> d_dat(dat, Shape2(N, 3 * H)); |
| Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I)); |
| linalg_gemm(d_dat, d_xt, d_dwx, alpha, beta, true, false); |
| } |
| // dwh += dart.T * ht1 [3 * H, H] = [3 * H, N] * [N, H] |
| Tensor<cpu, 2, DType> d_ht1(ht1, Shape2(N, D * H)); |
| Tensor<cpu, 2, DType> d_dwh(dwh, Shape2(3 * H, H)); |
| Tensor<cpu, 3, DType> d_ht1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N)); |
| linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true); |
| } |
| } |
| |
| if (req_params != kNullOp) { |
| // bias gradients: accumulate da into dbx and dar into dbh over all T * N rows |
| if (req_params != kAddTo) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < 3 * H; ++i) { |
| for (index_t j = 0; j < N * T; ++j) { |
| dbx[i] += da[j * 3 * H + i]; |
| dbh[i] += dar[j * 3 * H + i]; |
| } |
| } |
| } else { |
| const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T)); |
| const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T)); |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < H * T * 3; ++i) { |
| tmp_dbx.dptr_[i] = 0; |
| tmp_dbh.dptr_[i] = 0; |
| } |
| |
| for (index_t t = T - 1; t >= 0; --t) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < 3 * H; ++i) { |
| for (index_t j = 0; j < N; ++j) { |
| tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i]; |
| tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i]; |
| } |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < 3 * H; ++i) { |
| dbx[i] += tmp_dbx[i][t] + dbx[i]; |
| dbh[i] += tmp_dbh[i][t] + dbh[i]; |
| } |
| } |
| } |
| } |
| alpha = 1.0; |
| beta = 0.0; |
| |
| // dx = da * wx [T * N, I] = [T * N, 3 * H] * [3 * H, I] |
| Tensor<cpu, 2, DType> d_da(da, Shape2(T * N, 3 * H)); |
| if (req_data != kNullOp) { |
| Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I)); |
| linalg_gemm(d_da, wx, d_dx, alpha, beta, false, false); |
| } |
| |
| // dwx = da.T * x [3 * H, I] = [3 * H, T * N] * [T * N, I] |
| if (req_params != kNullOp && req_params != kAddTo) { |
| Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I)); |
| linalg_gemm(d_da, x, d_dwx, alpha, beta, true, false); |
| } |
| |
| if (D == 2) { |
| for (index_t t = 0; t < T; ++t) { |
| if (t == T - 1) { |
| back_ht1 = hx_; |
| } else { |
| back_ht1 = y_ptr + (t + 1) * N * D * H; |
| } |
| |
| // add dy[T, N, D, H] to dhy[D, N, H] |
| dyt = dy_ptr + t * N * D * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| back_dht1[i * H + j] += dyt[i * D * H + H + j]; |
| } |
| } |
| |
| rt = back_gateR + t * N * H; |
| zt = back_gateZ + t * N * H; |
| nt = back_gateN + t * N * H; |
| back_Mnht = Mnh + (T + t) * N * H; |
| dat = da + t * N * 3 * H; |
| dart = dar + t * N * 3 * H; |
| |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t nid = i * 3 * H + 2 * H + j; |
| index_t zid = i * 3 * H + H + j; |
| index_t rid = i * 3 * H + j; |
| index_t id = i * H + j; |
| dat[nid] = back_dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]); |
| dart[zid] = dat[zid] = |
| back_dht1[id] * (back_ht1[i * D * H + H + j] - nt[id]) * zt[id] * (1 - zt[id]); |
| dart[rid] = dat[rid] = dat[nid] * back_Mnht[id] * rt[id] * (1 - rt[id]); |
| dart[nid] = dat[nid] * rt[id]; |
| back_dht1[id] = back_dht1[id] * zt[id]; |
| } |
| } |
| |
| if (req_params != kNullOp) { |
| alpha = 1.0; |
| beta = 1.0; |
| // back_dht1 += dart * back_wh [N, H] = [N, 3 * H] * [3 * H, H] |
| Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H)); |
| Tensor<cpu, 2, DType> d_back_dht1(back_dht1, Shape2(N, H)); |
| linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false); |
| |
| // back_dwh += dart.T * back_ht1 [3 * H, H] = [3 * H, N] * [N, H] |
| Tensor<cpu, 2, DType> d_back_dwh(back_dwh, Shape2(3 * H, H)); |
| Tensor<cpu, 2, DType> d_back_ht1(back_ht1 + H, Shape2(N, D * H)); |
| Tensor<cpu, 3, DType> d_back_ht1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, N)); |
| if (req_params == kAddTo) { |
| beta = 2.0; |
| // back_dwx: accumulate dat.T * x_t [3 * H, I] = [3 * H, N] * [N, I] for AddTo |
| Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I)); |
| Tensor<cpu, 2, DType> d_dat(dat, Shape2(N, 3 * H)); |
| Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(3 * H, I)); |
| linalg_gemm(d_dat, d_xt, d_back_dwx, alpha, beta, true, false); |
| } |
| linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true); |
| } |
| } |
| |
| if (req_params != kNullOp) { |
| // bias gradients: accumulate da into back_dbx and dar into back_dbh over all T * N rows |
| if (req_params != kAddTo) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < 3 * H; ++i) { |
| for (index_t j = 0; j < N * T; ++j) { |
| back_dbx[i] += da[j * 3 * H + i]; |
| back_dbh[i] += dar[j * 3 * H + i]; |
| } |
| } |
| } else { |
| const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T)); |
| const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T)); |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < H * T * 3; ++i) { |
| tmp_dbx.dptr_[i] = 0; |
| tmp_dbh.dptr_[i] = 0; |
| } |
| for (index_t t = T - 1; t >= 0; --t) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < 3 * H; ++i) { |
| for (index_t j = 0; j < N; ++j) { |
| tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i]; |
| tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i]; |
| } |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < 3 * H; ++i) { |
| back_dbx[i] += tmp_dbx[i][t] + back_dbx[i]; |
| back_dbh[i] += tmp_dbh[i][t] + back_dbh[i]; |
| } |
| } |
| } |
| } |
| alpha = 1.0; |
| beta = 1.0; |
| // dx += da * back_wx [T * N, I] = [T * N, 3 * H] * [3 * H, I] |
| Tensor<cpu, 2, DType> d_da2(da, Shape2(T * N, 3 * H)); |
| if (req_data != kNullOp) { |
| Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I)); |
| linalg_gemm(d_da2, back_wx, d_dx, alpha, beta, false, false); |
| } |
| alpha = 1.0; |
| beta = 0.0; |
| // back_dwx = da.T * x [3 * H, I] = [3 * H, T * N] * [T * N, I] |
| if (req_params != kNullOp && req_params != kAddTo) { |
| Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(3 * H, I)); |
| linalg_gemm(d_da2, x, d_back_dwx, alpha, beta, true, false); |
| } |
| } |
| if (req_state != kNullOp) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N * H * D; ++i) { |
| dhx[i] = dht1[i]; |
| } |
| } |
| } |
| |
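| /*! |
| * \brief Backward pass of an L-layer (optionally bidirectional) GRU, |
| * replaying the inter-layer dropout mask saved by the forward pass when |
| * dropout > 0. |
| */ |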
| template <typename DType> |
| void GruBackward(DType* ws, |
| DType* rs, |
| const int L, |
| const int D, |
| const index_t T, |
| const index_t N, |
| index_t I, |
| const int H, |
| DType* x_ptr, |
| DType* hx_ptr, |
| DType* w_ptr, |
| DType* dy_ptr, |
| DType* dhy_ptr, |
| DType* dx_ptr, |
| DType* dhx_ptr, |
| DType* dw_ptr, |
| int req_data, |
| int req_params, |
| int req_state, |
| const float dropout) { |
| DType* wx = w_ptr; |
| DType* dwx = dw_ptr; |
| DType* dwh = dwx + I * H * 3; |
| DType* dbx = |
| dwh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3) + (L - 1) * ((D + 1) * H) * H * 3 * D; |
| DType* gateR_l = rs + (L - 1) * T * D * N * H; |
| DType* gateZ_l = gateR_l + L * T * D * N * H; |
| DType* gateN_l = gateZ_l + L * T * D * N * H; |
| DType* y_l = gateN_l + L * T * D * N * H; |
| DType* Mnh_l = y_l + L * T * N * H * D; |
| DType* dropout_random = Mnh_l + L * D * T * N * H; |
| DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H; |
| DType* dx_l = tmp_buf + T * N * D * H + 3 * H * T * 2; |
| DType* ws2 = dx_l + T * N * D * H; |
| DType* wx_l = |
| (L == 1) ? wx : wx + (L - 2) * D * (D + 1) * H * 3 * H + D * I * 3 * H + D * H * 3 * H; |
| DType* wh_l = wx_l; |
| if (L == 1) { |
| wh_l = wh_l + I * H * 3; |
| } else { |
| wh_l = wh_l + (D * H) * H * 3; |
| } |
| DType* dhy_l = nullptr; |
| if (dhy_ptr) |
| dhy_l = dhy_ptr + (L - 1) * D * N * H; |
| DType* dwx_l = |
| (L == 1) ? dwx : dwx + (L - 2) * D * (D + 1) * H * 3 * H + D * I * 3 * H + D * H * 3 * H; |
| DType* dwh_l = nullptr; |
| if (L == 1) { |
| dwh_l = dwx_l + I * H * 3; |
| } else { |
| dwh_l = dwx_l + (D * H) * H * 3; |
| } |
| DType* dbx_l = dbx + (L - 1) * D * 3 * H * 2; |
| DType* dbh_l = dbx_l + 3 * H; |
| DType* dhx_l = dhx_ptr + (L - 1) * D * N * H; |
| DType* dy_l = dy_ptr; |
| Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(L, D * N, H)); |
| index_t inputsize = I; |
| DType* y_tmp = y_l - T * N * H * D; |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| for (int l = L - 1; l >= 0; --l) { |
| if (l == 0) { |
| I = inputsize; |
| y_tmp = x_ptr; |
| dx_l = dx_ptr; |
| } else { |
| I = D * H; |
| } |
| Tensor<cpu, 2, DType> hx_l = hx[l]; |
| Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I)); |
| GruBackwardSingleLayer<DType>(ws2, |
| tmp_buf, |
| D, |
| T, |
| N, |
| I, |
| H, |
| x_l, |
| hx_l, |
| wx_l, |
| wh_l, |
| y_l, |
| dy_l, |
| dhy_l, |
| gateR_l, |
| gateZ_l, |
| gateN_l, |
| Mnh_l, |
| dx_l, |
| dhx_l, |
| dwx_l, |
| dwh_l, |
| dbx_l, |
| dbh_l, |
| req_data, |
| req_params, |
| req_state); |
| if (dropout > 0.0f && l > 0 && req_data != kNullOp) { |
| dropout_random = dropout_random - T * N * D * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < T * N * I; i++) { |
| if (dropout_random[i] == 0) { |
| dx_l[i] = 0; |
| } else { |
| dx_l[i] = dx_l[i] / (1.0f - dropout); |
| } |
| } |
| } |
| if (l > 0) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < T * N * H * D; ++i) { |
| dy_l[i] = dx_l[i]; |
| } |
| gateR_l = gateR_l - T * D * N * H; |
| gateZ_l = gateZ_l - T * D * N * H; |
| gateN_l = gateN_l - T * D * N * H; |
| Mnh_l = Mnh_l - T * D * N * H; |
| dhx_l = dhx_l - D * N * H; |
| if (dhy_l) |
| dhy_l = dhy_l - D * N * H; |
| y_l = y_l - T * N * H * D; |
| y_tmp = y_tmp - T * N * H * D; |
| if (l == 1) { |
| wx_l = wx_l - (inputsize + H) * H * 3 * D; |
| wh_l = wx_l + inputsize * 3 * H; |
| dwx_l = dwx_l - (inputsize + H) * H * 3 * D; |
| dwh_l = dwx_l + inputsize * 3 * H; |
| } else { |
| wx_l = wx_l - (I + H) * H * 3 * D; |
| wh_l = wx_l + I * 3 * H; |
| dwx_l = dwx_l - (I + H) * H * 3 * D; |
| dwh_l = dwx_l + I * 3 * H; |
| } |
| dbx_l = dbx_l - D * 3 * H * 2; |
| dbh_l = dbx_l + 3 * H; |
| } |
| } |
| } |
| |
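| /*! |
| * \brief Single vanilla RNN layer (both directions when D == 2), forward |
| * pass in inference mode. |
| * |
| * For each step t: |
| * h_t = act(x_t * Wx^T + bx + h_{t-1} * Wh^T + bh) |
| * where act is tanh when mode == 1 and relu otherwise. |
| */ |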
| template <typename DType> |
| void VanillaRNNForwardInferenceSingleLayer(DType* ws, |
| DType* tmp_buf, |
| bool state_outputs, |
| const int D, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| const Tensor<cpu, 2, DType>& x, |
| const Tensor<cpu, 2, DType>& hx, |
| DType* wx_ptr, |
| DType* wh_ptr, |
| DType* bx_ptr, |
| DType* bh_ptr, |
| DType* y_ptr, |
| DType* hy_ptr, |
| int mode) { |
| DType* ht = y_ptr; |
| DType* ht_1 = y_ptr; |
| DType* back_ht_1 = y_ptr + (T - 1) * N * H * D + H; |
| DType* back_ht = back_ht_1; |
| DType* gemmC1 = ws; // [D, T, N, H] |
| DType* gemmC2 = gemmC1 + D * T * N * H; // N * H |
| DType* back_wx_ptr = wx_ptr + I * H + H * H; |
| DType* back_wh_ptr = wh_ptr + I * H + H * H; |
| DType* back_bx_ptr = (bx_ptr != nullptr) ? bx_ptr + H * 2 : nullptr; |
| DType* back_bh_ptr = (bh_ptr != nullptr) ? bh_ptr + H * 2 : nullptr; |
| DType* back_gemmC1 = gemmC1 + T * N * H; |
| DType* gemmC1_t = gemmC1; |
| |
| const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H, I)); |
| const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H, H)); |
| const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(1, H)); |
| const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(1, H)); |
| const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H, I)); |
| const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H, H)); |
| const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(1, H)); |
| const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(1, H)); |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| if (D == 1) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| y_ptr[i * H + j] = hx[i][j]; |
| } |
| } else { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| y_ptr[i * D * H + j] = hx[i][j]; |
| back_ht_1[i * D * H + j] = hx[N + i][j]; |
| } |
| } |
| Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, H)); |
| Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, H)); |
| Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, H)); |
| |
| // x * wx.T : [T * N, I] * [I, H] |
| DType alpha = 1.0; |
| DType beta = 0.0; |
| linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true); |
| if (D == 2) { |
| linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true); |
| } |
| |
| for (index_t t = 0; t < T; t++) { |
| // first (forward) direction; x * wx was precomputed above, so only |
| // ht-1 * wh is computed per step: ht-1: [N, H], wh: [H, H] |
| Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H)); |
| if (D == 1) { |
| linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true); |
| } else { |
| Tensor<cpu, 3, DType> dht_1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N)); |
| linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true); |
| } |
| gemmC1_t = gemmC1 + t * N * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t tb = i * H; |
| if (mode == 1) { |
| ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + bx[0][j] + gemmC2[tb + j] + bh[0][j]); |
| } else { |
| ht[i * D * H + j] = relu(gemmC1_t[tb + j] + bx[0][j] + gemmC2[tb + j] + bh[0][j]); |
| } |
| } |
| } |
| ht_1 = ht; |
| ht = ht + D * H * N; |
| // second (reverse) direction |
| if (D == 2) { |
| gemmC1_t = back_gemmC1 + (T - 1 - t) * N * H; |
| Tensor<cpu, 2, DType> dback_ht_1(back_ht_1 - H, Shape2(N, D * H)); |
| Tensor<cpu, 3, DType> dback_ht_1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N)); |
| linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true); |
| |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t tb = i * H; |
| if (mode == 1) { |
| back_ht[i * D * H + j] = |
| tanh(gemmC1_t[tb + j] + back_bx[0][j] + gemmC2[tb + j] + back_bh[0][j]); |
| } else { |
| back_ht[i * D * H + j] = |
| relu(gemmC1_t[tb + j] + back_bx[0][j] + gemmC2[tb + j] + back_bh[0][j]); |
| } |
| } |
| } |
| back_ht_1 = back_ht; |
| back_ht = back_ht - D * H * N; |
| } |
| } |
  // copy the last state to hy, from (N, H * D) to (D, N, H)
| if (state_outputs) { |
| if (D == 1) { |
| DType* y_start = y_ptr + (T - 1) * N * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| hy_ptr[i * H + j] = y_start[i * H + j]; |
| } |
| } else { |
| DType* y_start = y_ptr + (T - 1) * N * H * D; |
| DType* y_back_start = y_ptr + H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| hy_ptr[i * H + j] = y_start[i * D * H + j]; |
| hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j]; |
| } |
| } |
| } |
| } |
| |
| template <typename DType> |
| void VanillaRNNForwardInference(DType* ws, |
| bool state_outputs, |
| const int L, |
| const int D, |
| const index_t T, |
| const index_t N, |
| index_t I, |
| const int H, |
| DType* x_ptr, |
| DType* hx_ptr, |
| DType* w_ptr, |
| DType* y_ptr, |
| DType* hy_ptr, |
| int mode) { |
| DType* wx = w_ptr; |
| DType* wh = wx + I * H; |
| DType* bx = wh + H * H + (D - 1) * (H * H + I * H) + (L - 1) * ((D + 1) * H) * H * D; |
| DType* bh = bx + H; |
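  // The packed parameter blob stores every layer's wx/wh matrices first and
  // all biases afterwards: layer 0 contributes D * (I * H + H * H) weight
  // elements and each deeper layer (input size D * H) contributes
  // D * (D + 1) * H * H, which is exactly the offset skipped to reach bx.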
| |
| DType* y_tmp = ws; |
| DType* y_l = x_ptr; |
| DType* tmp_buf = y_tmp + D * T * N * H; |
| DType* ws2 = y_tmp + D * T * N * H + D * H * N; |
| |
| DType* wx_l = wx; |
| DType* wh_l = wh; |
| DType* bx_l = bx; |
| DType* bh_l = bh; |
| Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H)); |
| DType* hy_l = hy_ptr; |
| for (int l = 0; l < L; l++) { |
| Tensor<cpu, 2, DType> x_l(y_l, Shape2(T * N, I)); |
| if ((L + l) % 2) { |
| y_l = y_ptr; |
| } else { |
| y_l = y_tmp; |
| } |
| Tensor<cpu, 2, DType> hx_l = hx[D * l]; |
| VanillaRNNForwardInferenceSingleLayer<DType>(ws2, |
| tmp_buf, |
| state_outputs, |
| D, |
| T, |
| N, |
| I, |
| H, |
| x_l, |
| hx_l, |
| wx_l, |
| wh_l, |
| bx_l, |
| bh_l, |
| y_l, |
| hy_l, |
| mode); |
| hy_l = hy_l + D * N * H; |
| bx_l = bx_l + H * D * 2; |
| bh_l = bh_l + H * D * 2; |
| wx_l = wx_l + I * H * D + H * H * D; |
| if (l == 0) { |
| I = D * H; |
| } |
| wh_l = wx_l + I * H; |
| } |
| } |
| |
| template <typename DType> |
| void VanillaRNNForwardTrainingSingleLayer(DType* ws, |
| DType* tmp_buf, |
| bool state_outputs, |
| const int D, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| const Tensor<cpu, 2, DType>& x, |
| const Tensor<cpu, 2, DType>& hx, |
| DType* wx_ptr, |
| DType* wh_ptr, |
| DType* bx_ptr, |
| DType* bh_ptr, |
| DType* gateN, |
| DType* y_ptr, |
| DType* hy_ptr, |
| int mode) { |
| DType* ht = y_ptr; |
| DType* ht_1 = y_ptr; |
| DType* back_ht_1 = y_ptr + (T - 1) * N * H * D + H; |
| DType* back_ht = back_ht_1; |
| |
| DType* gemmC1 = ws; // [D, T, N, H] |
| DType* gemmC2 = gemmC1 + D * T * N * H; // N * H |
| DType* nt = gateN; |
| DType* back_wx_ptr = wx_ptr + I * H + H * H; |
| DType* back_wh_ptr = wh_ptr + I * H + H * H; |
| DType* back_bx_ptr = (bx_ptr != nullptr) ? bx_ptr + H * 2 : nullptr; |
| DType* back_bh_ptr = (bh_ptr != nullptr) ? bh_ptr + H * 2 : nullptr; |
| DType* back_gateN = gateN + T * N * H; |
| DType* back_gemmC1 = gemmC1 + T * N * H; |
| DType* gemmC1_t = gemmC1; |
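  // gateN records, for every step, the value the backward pass needs to
  // differentiate the activation: for tanh (mode == 1) it stores the
  // activation h itself (tanh'(a) = 1 - h^2), while for relu it stores the
  // pre-activation, whose sign selects where the gradient passes.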
| |
| const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H, I)); |
| const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H, H)); |
| const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(1, H)); |
| const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(1, H)); |
  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H, I));
  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H, H));
| const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(1, H)); |
| const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(1, H)); |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| if (D == 1) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| y_ptr[i * H + j] = hx[i][j]; |
| } |
| } else { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| y_ptr[i * D * H + j] = hx[i][j]; |
| back_ht_1[i * D * H + j] = hx[N + i][j]; |
| } |
| } |
| |
| Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, H)); |
| Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, H)); |
| Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, H)); |
| |
| // x * wx.T : [T * N, I] * [I, H] |
| DType alpha = 1.0; |
| DType beta = 0.0; |
| linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true); |
| if (D == 2) { |
| linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true); |
| } |
| |
| for (index_t t = 0; t < T; t++) { |
    // forward direction: x * wx was precomputed for all steps above, so
    // only h(t-1) * wh is computed here; h(t-1): [N, H], wh: [H, H]
| Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H)); |
| if (D == 1) { |
| linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true); |
| } else { |
| Tensor<cpu, 3, DType> dht_1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N)); |
| linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true); |
| } |
| nt = gateN + t * N * H; |
| gemmC1_t = gemmC1 + t * N * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t tb = i * H; |
| if (mode == 1) { |
| nt[tb + j] = ht[i * D * H + j] = |
| tanh(gemmC1_t[tb + j] + bx[0][j] + gemmC2[tb + j] + bh[0][j]); |
| } else { |
| nt[tb + j] = gemmC1_t[tb + j] + bx[0][j] + gemmC2[tb + j] + bh[0][j]; |
| ht[i * D * H + j] = relu(nt[tb + j]); |
| } |
| } |
| } |
| ht_1 = ht; |
| ht = ht + D * H * N; |
| // perform the second direction |
| if (D == 2) { |
| nt = back_gateN + (T - 1 - t) * N * H; |
| gemmC1_t = back_gemmC1 + (T - 1 - t) * N * H; |
| Tensor<cpu, 2, DType> dback_ht_1(back_ht_1 - H, Shape2(N, D * H)); |
| Tensor<cpu, 3, DType> dback_ht_1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N)); |
| linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true); |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t tb = i * H; |
| if (mode == 1) { |
| nt[tb + j] = back_ht[i * D * H + j] = |
| tanh(gemmC1_t[tb + j] + back_bx[0][j] + gemmC2[tb + j] + back_bh[0][j]); |
| } else { |
| nt[tb + j] = gemmC1_t[tb + j] + back_bx[0][j] + gemmC2[tb + j] + back_bh[0][j]; |
| back_ht[i * D * H + j] = relu(nt[tb + j]); |
| } |
| } |
| } |
| back_ht_1 = back_ht; |
| back_ht = back_ht - D * H * N; |
| } |
| } |
| |
  // copy the last state to hy, from (N, H * D) to (D, N, H)
| if (state_outputs) { |
| if (D == 1) { |
| DType* y_start = y_ptr + (T - 1) * N * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| hy_ptr[i * H + j] = y_start[i * H + j]; |
| } |
| } else { |
| DType* y_start = y_ptr + (T - 1) * N * H * D; |
| DType* y_back_start = y_ptr + H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; i++) |
| for (int j = 0; j < H; j++) { |
| hy_ptr[i * H + j] = y_start[i * D * H + j]; |
| hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j]; |
| } |
| } |
| } |
| } |
| |
| template <typename DType> |
| void VanillaRNNForwardTraining(DType* ws, |
| DType* rs, |
| bool state_outputs, |
| const int L, |
| const int D, |
| const index_t T, |
| const index_t N, |
| index_t I, |
| const int H, |
| DType* x_ptr, |
| DType* hx_ptr, |
| DType* w_ptr, |
| DType* y_ptr, |
| DType* hy_ptr, |
| const float dropout, |
| int mode, |
| std::mt19937& rnd_engine) { // NOLINT(runtime/references) |
| DType* wx = w_ptr; |
| DType* wh = wx + I * H; |
| DType* bx = wh + H * H + (D - 1) * (H * H + I * H) + (L - 1) * ((D + 1) * H) * H * D; |
| DType* bh = bx + H; |
| Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H)); |
| DType* hy_l = hy_ptr; |
| DType* gateN_l = rs; |
| DType* y_l = gateN_l + L * T * D * N * H; |
| DType* dropout_random = y_l + L * D * T * N * H; |
| DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H; |
| DType* ws2 = tmp_buf + D * N * H; |
| DType* wx_l = wx; |
| DType* wh_l = wh; |
| DType* bx_l = bx; |
| DType* bh_l = bh; |
| DType* y_tmp = x_ptr; |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| for (int l = 0; l < L; l++) { |
| if (l != 0) { |
| y_tmp = y_l; |
| y_l = y_l + T * N * H * D; |
| } |
| if (dropout > 0.0f && l > 0) { |
| std::uniform_real_distribution<float> distribution(0, 1); |
| for (index_t i = 0; i < T * N * I; i++) { |
| if (distribution(rnd_engine) < dropout) { |
| dropout_random[(l - 1) * T * N * I + i] = 0; |
| y_tmp[i] = 0; |
| } else { |
| dropout_random[(l - 1) * T * N * I + i] = 1.0f - dropout; |
| y_tmp[i] = y_tmp[i] / (1.0f - dropout); |
| } |
| } |
| } |
| Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I)); |
| Tensor<cpu, 2, DType> hx_l = hx[D * l]; |
| VanillaRNNForwardTrainingSingleLayer<DType>(ws2, |
| tmp_buf, |
| state_outputs, |
| D, |
| T, |
| N, |
| I, |
| H, |
| x_l, |
| hx_l, |
| wx_l, |
| wh_l, |
| bx_l, |
| bh_l, |
| gateN_l, |
| y_l, |
| hy_l, |
| mode); |
| gateN_l = gateN_l + T * D * N * H; |
| hy_l = hy_l + D * N * H; |
| bx_l = bx_l + H * D * 2; |
| bh_l = bh_l + H * D * 2; |
| |
| wx_l = wx_l + I * H * D + H * H * D; |
| if (l == 0) { |
| I = D * H; |
| } |
| wh_l = wx_l + I * H; |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < T * N * H * D; ++i) { |
| y_ptr[i] = y_l[i]; |
| } |
| } |
| |
| template <typename DType> |
| void VanillaRNNBackwardSingleLayer(DType* ws, |
| DType* tmp_buf, |
| const int D, |
| const index_t T, |
| const index_t N, |
| const index_t I, |
| const int H, |
| const Tensor<cpu, 2, DType>& x, |
| const Tensor<cpu, 2, DType>& hx, |
| DType* wx_ptr, |
| DType* wh_ptr, |
| DType* y_ptr, |
| DType* dy_ptr, |
| DType* dhy_ptr, |
| DType* gateN, |
| DType* dx, |
| DType* dhx, |
| DType* dwx, |
| DType* dwh, |
| DType* dbx, |
| DType* dbh, |
| int req_data, |
| int req_params, |
| int req_state, |
| int mode) { |
| DType* dyt; |
| DType* ht1; // [N, D, H] |
| DType* dart; |
| DType* nt; |
| DType* dar = ws; // [T, N, H] |
| DType* dht1 = dar + T * N * H; // [D, N, H] |
| DType* hx_ = dht1 + D * N * H; // [N, D, H] |
| |
| DType* back_ht1; |
| DType* back_dht1 = dht1 + N * H; // [N, H] |
| DType* back_gateN = gateN + T * N * H; |
| DType* back_wx_ptr = wx_ptr + I * H + H * H; |
| DType* back_wh_ptr = wh_ptr + I * H + H * H; |
| DType* back_dwx = dwx + I * H + H * H; |
| DType* back_dwh = dwh + I * H + H * H; |
| DType* back_dbx = dbx + H * 2; |
| DType* back_dbh = dbh + H * 2; |
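  // dar accumulates dL/d(pre-activation) for every step, dht1 carries the
  // recurrent gradient between steps, and hx_ is hx repacked into the
  // interleaved [N, D, H] layout of y. Per step the kernel applies the
  // chain rule
  //   da_t = (dy_t + dcarry_t) * act'(a_t)
  //   dwh += da_t.T * h(t-1),  db += colsum(da_t),  dcarry_(t-1) = da_t * wh
  // while dx and (outside kAddTo) dwx are computed afterwards in single
  // GEMMs over all T steps.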
| |
| DType alpha = 1.0; |
| DType beta = 0.0; |
| const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H, I)); |
| const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H, H)); |
| const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H, I)); |
| const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H, H)); |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
  if (req_params != kNullOp && req_params != kAddTo) {
    // clear both directions' dwh explicitly: back_dwh does not immediately
    // follow dwh in the packed gradient blob (back_dwx sits in between), so
    // a single D * H * H sweep starting at dwh would miss back_dwh whenever
    // I != H
#pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < H * H; ++i) {
      dwh[i] = 0;
      if (D == 2) {
        back_dwh[i] = 0;
      }
    }
#pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < D * H; ++i) {
      dbx[i] = 0;
      dbh[i] = 0;
    }
  }
| |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N * H; ++i) { |
| if (dhy_ptr) { |
| dht1[i] = dhy_ptr[i]; |
| } else { |
| dht1[i] = 0; |
| } |
| } |
| |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| hx_[i * D * H + j] = hx[i][j]; |
| } |
| } |
| |
| if (D == 2) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N * H; ++i) { |
| if (dhy_ptr) { |
| back_dht1[i] = dhy_ptr[N * H + i]; |
| } else { |
| back_dht1[i] = 0; |
| } |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| hx_[i * D * H + H + j] = hx[N + i][j]; |
| } |
| } |
| } |
| for (index_t t = T - 1; t >= 0; --t) { |
| if (t) { |
| ht1 = y_ptr + (t - 1) * N * D * H; |
| } else { |
| ht1 = hx_; |
| } |
    // add the step's dy slice ([N, D, H]) to the carried hidden-state
    // gradient dht1 ([N, H]), which was initialized from dhy
| dyt = dy_ptr + t * N * D * H; |
| |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| dht1[i * H + j] += dyt[i * D * H + j]; |
| } |
| } |
| |
| nt = gateN + t * N * H; |
| dart = dar + t * N * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t id = i * H + j; |
| if (mode == 1) { |
| dart[id] = dht1[id] * (1 - nt[id] * nt[id]); |
| } else { |
| dart[id] = nt[id] > 0.0f ? static_cast<float>(dht1[id]) : 0.0f; |
| } |
| dht1[id] = 0; |
| } |
| } |
| if (req_params != kNullOp) { |
| alpha = 1.0; |
| beta = 1.0; |
| // dht1 = dart * wh [N, H] = [N, H] * [H, H] |
| Tensor<cpu, 2, DType> d_dht1(dht1, Shape2(N, H)); |
| Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, H)); |
| linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false); |
| |
| if (req_params == kAddTo) { |
| beta = 2.0; |
| // dwx = da.T * x [H, I] = [H, N] * [N, I] for AddTo |
| Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I)); |
| Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(H, I)); |
| linalg_gemm(d_dart, d_xt, d_dwx, alpha, beta, true, false); |
| } |
| // dwh = dart.T * ht1 [H, H] = [H, N] * [N, H] |
| Tensor<cpu, 2, DType> d_ht1(ht1, Shape2(N, D * H)); |
| Tensor<cpu, 2, DType> d_dwh(dwh, Shape2(H, H)); |
| Tensor<cpu, 3, DType> d_ht1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N)); |
| linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true); |
| } |
| } |
| |
| if (req_params != kNullOp) { |
    // dbx = ones.T * dar : [1, H] = [1, T * N] * [T * N, H]
    // (sum pre-activation gradients over all steps and the batch)
| if (req_params != kAddTo) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < H; ++i) { |
| for (index_t j = 0; j < N * T; ++j) { |
| dbx[i] += dar[j * H + i]; |
| dbh[i] = dbx[i]; |
| } |
| } |
| } else { |
| const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H, T)); |
| const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + H * T, Shape2(H, T)); |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < H * T; ++i) { |
| tmp_dbx.dptr_[i] = 0; |
| tmp_dbh.dptr_[i] = 0; |
| } |
| |
| for (index_t t = T - 1; t >= 0; --t) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < H; ++i) { |
| for (index_t j = 0; j < N; ++j) { |
| tmp_dbx[i][t] += dar[t * N * H + j * H + i]; |
| tmp_dbh[i][t] = tmp_dbx[i][t]; |
| } |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < H; ++i) { |
          dbx[i] += tmp_dbx[i][t];
| dbh[i] = dbx[i]; |
| } |
| } |
| } |
| } |
| alpha = 1.0; |
| beta = 0.0; |
| |
| // dx = da * wx [T * N, I] = [T * N, H] * [H, I] |
| Tensor<cpu, 2, DType> d_dar(dar, Shape2(T * N, H)); |
| if (req_data != kNullOp) { |
| Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I)); |
| linalg_gemm(d_dar, wx, d_dx, alpha, beta, false, false); |
| } |
| |
| // dwx = da.T * x [H, I] = [H, T * N] * [T * N, I] |
| if (req_params != kNullOp && req_params != kAddTo) { |
| Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(H, I)); |
| linalg_gemm(d_dar, x, d_dwx, alpha, beta, true, false); |
| } |
| |
| if (D == 2) { |
| for (index_t t = 0; t < T; ++t) { |
| if (t == T - 1) { |
| back_ht1 = hx_; |
| } else { |
| back_ht1 = y_ptr + (t + 1) * N * D * H; |
| } |
| |
      // add the backward half of the step's dy slice to the carried
      // hidden-state gradient back_dht1 ([N, H])
| dyt = dy_ptr + t * N * D * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| back_dht1[i * H + j] += dyt[i * D * H + H + j]; |
| } |
| } |
| |
| nt = back_gateN + t * N * H; |
| dart = dar + t * N * H; |
| |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N; ++i) { |
| for (int j = 0; j < H; ++j) { |
| index_t id = i * H + j; |
| if (mode == 1) { |
| dart[id] = back_dht1[id] * (1 - nt[id] * nt[id]); |
| } else { |
| dart[id] = nt[id] > 0.0f ? static_cast<float>(back_dht1[id]) : 0.0f; |
| } |
| back_dht1[id] = 0; |
| } |
| } |
| |
| if (req_params != kNullOp) { |
| alpha = 1.0; |
| beta = 1.0; |
        // back_dht1 += dart * back_wh : [N, H] = [N, H] * [H, H]
| Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, H)); |
| Tensor<cpu, 2, DType> d_back_dht1(back_dht1, Shape2(N, H)); |
| linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false); |
| |
| // dwh = da.T * ht1 [H, H] = [H, N] * [N, H] |
| Tensor<cpu, 2, DType> d_back_dwh(back_dwh, Shape2(H, H)); |
| Tensor<cpu, 2, DType> d_back_ht1(back_ht1 + H, Shape2(N, D * H)); |
| Tensor<cpu, 3, DType> d_back_ht1_tmp = |
| Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N)); |
| d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, N)); |
| if (req_params == kAddTo) { |
| beta = 2.0; |
          // dwx = da.T * x [H, I] = [H, N] * [N, I] for AddTo
| Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I)); |
| Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(H, I)); |
| linalg_gemm(d_dart, d_xt, d_back_dwx, alpha, beta, true, false); |
| } |
| linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true); |
| } |
| } |
| |
| if (req_params != kNullOp) { |
      // back_dbx = ones.T * dar : [1, H] = [1, T * N] * [T * N, H]
      // (sum backward-direction pre-activation gradients over steps and batch)
| if (req_params != kAddTo) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < H; ++i) { |
| for (index_t j = 0; j < N * T; ++j) { |
| back_dbx[i] += dar[j * H + i]; |
| back_dbh[i] = back_dbx[i]; |
| } |
| } |
| } else { |
| const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H, T)); |
| const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + H * T, Shape2(H, T)); |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < H * T; ++i) { |
| tmp_dbx.dptr_[i] = 0; |
| tmp_dbh.dptr_[i] = 0; |
| } |
| |
| for (index_t t = T - 1; t >= 0; --t) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < H; ++i) { |
| for (index_t j = 0; j < N; ++j) { |
| tmp_dbx[i][t] += dar[t * N * H + j * H + i]; |
| tmp_dbh[i][t] = tmp_dbx[i][t]; |
| } |
| } |
| #pragma omp parallel for num_threads(omp_threads) |
| for (int i = 0; i < H; ++i) { |
          back_dbx[i] += tmp_dbx[i][t];
| back_dbh[i] = back_dbx[i]; |
| } |
| } |
| } |
| } |
| alpha = 1.0; |
| beta = 1.0; |
    // dx += dar * back_wx : [T * N, I] += [T * N, H] * [H, I]
    // (beta = 1 accumulates onto the forward-direction dx computed above)
| Tensor<cpu, 2, DType> d_dar2(dar, Shape2(T * N, H)); |
| if (req_data != kNullOp) { |
| Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I)); |
| linalg_gemm(d_dar2, back_wx, d_dx, alpha, beta, false, false); |
| } |
| alpha = 1.0; |
| beta = 0.0; |
| // dwx = da.T * x [H, I] = [H, T * N] * [T * N, I] |
| if (req_params != kNullOp && req_params != kAddTo) { |
| Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(H, I)); |
| linalg_gemm(d_dar2, x, d_back_dwx, alpha, beta, true, false); |
| } |
| } |
| if (req_state != kNullOp) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < N * H * D; ++i) { |
| dhx[i] = dht1[i]; |
| } |
| } |
| } |
| |
| template <typename DType> |
| void VanillaRNNBackward(DType* ws, |
| DType* rs, |
| const int L, |
| const int D, |
| const index_t T, |
| const index_t N, |
| index_t I, |
| const int H, |
| DType* x_ptr, |
| DType* hx_ptr, |
| DType* w_ptr, |
| DType* dy_ptr, |
| DType* dhy_ptr, |
| DType* dx_ptr, |
| DType* dhx_ptr, |
| DType* dw_ptr, |
| int req_data, |
| int req_params, |
| int req_state, |
| const float dropout, |
| int mode) { |
| DType* wx = w_ptr; |
| DType* dwx = dw_ptr; |
| DType* dwh = dwx + I * H; |
| DType* dbx = dwh + H * H + (D - 1) * (H * H + I * H) + (L - 1) * ((D + 1) * H) * H * D; |
| DType* gateN_l = rs + (L - 1) * T * D * N * H; |
| DType* y_l = gateN_l + L * T * D * N * H; |
| DType* dropout_random = y_l + L * D * T * N * H; |
| DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H; |
| DType* dx_l = tmp_buf + T * N * D * H + H * T * 2; |
| DType* ws2 = dx_l + T * N * D * H; |
| DType* wx_l = (L == 1) ? wx : wx + (L - 2) * D * (D + 1) * H * H + D * I * H + D * H * H; |
| DType* wh_l = wx_l; |
| if (L == 1) { |
| wh_l = wh_l + I * H; |
| } else { |
| wh_l = wh_l + (D * H) * H; |
| } |
| DType* dhy_l = nullptr; |
| if (dhy_ptr) |
| dhy_l = dhy_ptr + (L - 1) * D * N * H; |
| DType* dwx_l = (L == 1) ? dwx : dwx + (L - 2) * D * (D + 1) * H * H + D * I * H + D * H * H; |
| DType* dwh_l = nullptr; |
| if (L == 1) { |
| dwh_l = dwx_l + I * H; |
| } else { |
| dwh_l = dwx_l + (D * H) * H; |
| } |
| DType* dbx_l = dbx + (L - 1) * D * H * 2; |
| DType* dbh_l = dbx_l + H; |
| DType* dhx_l = dhx_ptr + (L - 1) * D * N * H; |
| DType* dy_l = dy_ptr; |
| Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(L, D * N, H)); |
| index_t inputsize = I; |
| DType* y_tmp = y_l - T * N * H * D; |
| const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); |
| for (int l = L - 1; l >= 0; --l) { |
| if (l == 0) { |
| I = inputsize; |
| y_tmp = x_ptr; |
| dx_l = dx_ptr; |
| } else { |
| I = D * H; |
| } |
| Tensor<cpu, 2, DType> hx_l = hx[l]; |
| Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I)); |
| VanillaRNNBackwardSingleLayer<DType>(ws2, |
| tmp_buf, |
| D, |
| T, |
| N, |
| I, |
| H, |
| x_l, |
| hx_l, |
| wx_l, |
| wh_l, |
| y_l, |
| dy_l, |
| dhy_l, |
| gateN_l, |
| dx_l, |
| dhx_l, |
| dwx_l, |
| dwh_l, |
| dbx_l, |
| dbh_l, |
| req_data, |
| req_params, |
| req_state, |
| mode); |
| if (dropout > 0.0f && l > 0 && req_data != kNullOp) { |
| dropout_random = dropout_random - T * N * D * H; |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < T * N * I; i++) { |
| if (dropout_random[i] == 0) { |
| dx_l[i] = 0; |
| } else { |
| dx_l[i] = dx_l[i] / (1.0f - dropout); |
| } |
| } |
| } |
| if (l > 0) { |
| #pragma omp parallel for num_threads(omp_threads) |
| for (index_t i = 0; i < T * N * H * D; ++i) { |
| dy_l[i] = dx_l[i]; |
| } |
| gateN_l = gateN_l - T * D * N * H; |
| dhx_l = dhx_l - D * N * H; |
| if (dhy_l) |
| dhy_l = dhy_l - D * N * H; |
| y_l = y_l - T * N * H * D; |
| y_tmp = y_l; |
| if (l == 1) { |
| wx_l = wx_l - (inputsize + H) * H * D; |
| wh_l = wx_l + inputsize * H; |
| dwx_l = dwx_l - (inputsize + H) * H * D; |
| dwh_l = dwx_l + inputsize * H; |
| } else { |
| wx_l = wx_l - (I + H) * H * D; |
| wh_l = wx_l + I * H; |
| dwx_l = dwx_l - (I + H) * H * D; |
| dwh_l = dwx_l + I * H; |
| } |
| dbx_l = dbx_l - D * H * 2; |
| dbh_l = dbx_l + H; |
| } |
| } |
| } |
| |
| } // namespace op |
| } // namespace mxnet |
| #endif // MXNET_OPERATOR_RNN_IMPL_H_ |