| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file transformer.cc |
| * \brief CPU implementation of the operators used in Transformer |
| */ |
| #include <mxnet/base.h> |
| #include "./transformer-inl.h" |
| #include "../tensor/elemwise_unary_op.h" |
| |
| namespace mxnet { |
| namespace op { |
| |
| DMLC_REGISTER_PARAMETER(InterleavedMatMulParam); |
| |
| static bool InterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs, |
| mxnet::ShapeVector* in_shape, |
| mxnet::ShapeVector* out_shape) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| CHECK_EQ(in_shape->size(), 1U) << "Input:[queries_keys_values] currently have, " |
| << in_shape->size() << " inputs"; |
| auto qkv_shape = in_shape->at(0); |
| CHECK_EQ(qkv_shape.ndim(), 3U) |
| << "Input queries_keys_values should be 3D in seq_length-batch-proj_dim, " |
| << "currently is: " << qkv_shape.ndim() << "D"; |
| out_shape->resize(1); |
| SHAPE_ASSIGN_CHECK( |
| *out_shape, 0, mxnet::TShape({params.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]})); |
| return true; |
| } |
| |
| static bool InterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs, |
| mxnet::ShapeVector* in_shape, |
| mxnet::ShapeVector* out_shape) { |
| CHECK_EQ(in_shape->size(), 2U) << "Input:[queries_keys_values, attention] currently have, " |
| << in_shape->size() << " inputs"; |
| auto qkv_shape = in_shape->at(0); |
| auto att_shape = in_shape->at(1); |
| CHECK_EQ(qkv_shape.ndim(), 3U) |
| << "Input queries_keys_values should be 3D in seq_length-batch-3*proj_dim, " |
| << "currently is: " << qkv_shape.ndim() << "D"; |
| CHECK_EQ(att_shape.ndim(), 3U) << "Input attention should be 3D in batch-seq_length-seq_length, " |
| << "currently is: " << att_shape.ndim() << "D"; |
| CHECK_EQ(qkv_shape[0], att_shape[1]) |
| << "queries_keys_values.shape[0] and attention.shape[1] should be the same, " |
| << "currently are " << qkv_shape[0] << " and " << att_shape[1]; |
| CHECK_EQ(qkv_shape[0], att_shape[2]) |
| << "queries_keys_values.shape[0] and attention.shape[2] should be the same, " |
| << "currently are " << qkv_shape[0] << " and " << att_shape[2]; |
| CHECK_EQ(qkv_shape[2] % 3, 0) << "queries_keys_values.shape[2] should be a multiple of 3, " |
| << "currently is " << qkv_shape[2]; |
| SHAPE_ASSIGN_CHECK(*out_shape, 0, mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3})); |
| return true; |
| } |
| |
| static bool InterleavedMatMulEncDecQKShape(const NodeAttrs& attrs, |
| mxnet::ShapeVector* in_shape, |
| mxnet::ShapeVector* out_shape) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| CHECK_EQ(in_shape->size(), 2U) << "Input:[queries, keys_values], currently have " |
| << in_shape->size() << " inputs"; |
| auto q_shape = in_shape->at(0); |
| auto kv_shape = in_shape->at(1); |
| CHECK_EQ(q_shape.ndim(), 3U) << "Input queries should be 3D in seq_length-batch-proj_dim, " |
| << "currently is " << q_shape.ndim() << "D"; |
| CHECK_EQ(kv_shape.ndim(), 3U) << "Input queries should be 3D in seq_length-batch-2*proj_dim, " |
| << "currently is " << kv_shape.ndim() << "D"; |
| CHECK_EQ(q_shape[2] * 2, kv_shape[2]) |
| << "keys_values.shape[2] should be equal to queries.shape[2] * 2, " |
| << "currently are: " << kv_shape[2] << " and " << q_shape[2]; |
| CHECK_EQ(q_shape[1], kv_shape[1]) << "queries.shape[1] should be equal to keys_values.shape[1], " |
| << "currently are: " << q_shape[1] << " and " << kv_shape[1]; |
| SHAPE_ASSIGN_CHECK( |
| *out_shape, 0, mxnet::TShape({q_shape[1] * params.heads, q_shape[0], kv_shape[0]})); |
| return true; |
| } |
| |
| static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs, |
| mxnet::ShapeVector* in_shape, |
| mxnet::ShapeVector* out_shape) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| CHECK_EQ(in_shape->size(), 2U) << "Input: [keys_values, attention], currently have " |
| << in_shape->size() << " inputs"; |
| auto kv_shape = in_shape->at(0); |
| auto att_shape = in_shape->at(1); |
| CHECK_EQ(kv_shape.ndim(), 3U) << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, " |
| << "currently is " << kv_shape.ndim() << "D"; |
| CHECK_EQ(att_shape.ndim(), 3U) << "Input attention should be 3D in batch-seq_length-seq_length, " |
| << "currently is " << att_shape.ndim() << "D"; |
| CHECK_EQ(kv_shape[0], att_shape[2]) |
| << "keys_values.shape[0] should be equal to attention.shape[2], currently are " << kv_shape[0] |
| << " and " << att_shape[2]; |
| CHECK_EQ(kv_shape[1] * params.heads, att_shape[0]) |
| << "attention.shape[0] " |
| << "should be equal to keys_values.shape[1] * heads, currently are: " << att_shape[2] |
| << " and " << kv_shape[1]; |
| SHAPE_ASSIGN_CHECK(*out_shape, 0, mxnet::TShape({att_shape[1], kv_shape[1], kv_shape[2] / 2})); |
| return true; |
| } |
| |
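| /*! |
| * \brief Strided, batched single-precision GEMM in column-major layout. |
| * For each batch i in [0, batchCount), computes |
| * C_i = alpha * op(A_i) * op(B_i) + beta * C_i, |
| * where A_i = a + i * strideA, B_i = b + i * strideB and C_i = c + i * strideC. |
| * Uses MKL's grouped cblas_sgemm_batch when available, otherwise falls back |
| * to a loop of cblas_sgemm calls. |
| */ |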
| void strided_batch_sgemm(bool transA, |
| bool transB, |
| index_t m, |
| index_t n, |
| index_t k, |
| float alpha, |
| const float* a, |
| index_t lda, |
| index_t strideA, |
| const float* b, |
| index_t ldb, |
| index_t strideB, |
| float beta, |
| float* c, |
| index_t ldc, |
| index_t strideC, |
| int32_t batchCount) { |
| std::vector<const float*> pp_A(batchCount, nullptr); |
| std::vector<const float*> pp_B(batchCount, nullptr); |
| std::vector<float*> pp_C(batchCount, nullptr); |
| |
| for (int i = 0; i < batchCount; i++) { |
| pp_A[i] = a + i * strideA; |
| pp_B[i] = b + i * strideB; |
| pp_C[i] = c + i * strideC; |
| } |
| |
| #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) |
| const int GROUP_SIZE = 1; |
| MKL_INT p_m[GROUP_SIZE] = {static_cast<MKL_INT>(m)}; |
| MKL_INT p_n[GROUP_SIZE] = {static_cast<MKL_INT>(n)}; |
| MKL_INT p_k[GROUP_SIZE] = {static_cast<MKL_INT>(k)}; |
| MKL_INT p_lda[GROUP_SIZE] = {static_cast<MKL_INT>(lda)}; |
| MKL_INT p_ldb[GROUP_SIZE] = {static_cast<MKL_INT>(ldb)}; |
| MKL_INT p_ldc[GROUP_SIZE] = {static_cast<MKL_INT>(ldc)}; |
| |
| float p_alpha[GROUP_SIZE] = {alpha}; |
| float p_beta[GROUP_SIZE] = {beta}; |
| |
| CBLAS_TRANSPOSE cblas_a_trans = transA ? CblasTrans : CblasNoTrans; |
| CBLAS_TRANSPOSE cblas_b_trans = transB ? CblasTrans : CblasNoTrans; |
| |
| MKL_INT p_group_sizeb[GROUP_SIZE] = {static_cast<MKL_INT>(batchCount)}; |
| CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans}; |
| CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans}; |
| |
| cblas_sgemm_batch(CblasColMajor, |
| p_transa, |
| p_transb, |
| p_m, |
| p_n, |
| p_k, |
| p_alpha, |
| pp_A.data(), |
| p_lda, |
| pp_B.data(), |
| p_ldb, |
| p_beta, |
| pp_C.data(), |
| p_ldc, |
| GROUP_SIZE, |
| p_group_sizeb); |
| #else |
| for (int i = 0; i < batchCount; ++i) { |
| cblas_sgemm(CblasColMajor, |
| transA ? CblasTrans : CblasNoTrans, |
| transB ? CblasTrans : CblasNoTrans, |
| m, |
| n, |
| k, |
| alpha, |
| pp_A[i], |
| lda, |
| pp_B[i], |
| ldb, |
| beta, |
| pp_C[i], |
| ldc); |
| } |
| #endif |
| } |
| |
| void InterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| |
| if (req[0] == kNullOp) |
| return; |
| |
| CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) |
| << "Only FP32 is supported on CPU at the moment"; |
| |
| mshadow::Stream<cpu>* s = ctx.get_stream<cpu>(); |
| const float* queries_keys_values = inputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| |
| const index_t qkv_seq_len = inputs[0].shape_[0]; |
| const index_t sequences = inputs[0].shape_[1]; |
| const index_t output_lin_dim = inputs[0].shape_[2]; |
| const index_t embed_dim = output_lin_dim / 3; |
| const index_t head_dim = embed_dim / params.heads; |
| const index_t attn_batches = params.heads * sequences; |
| const index_t lead_dim = attn_batches * 3 * head_dim; |
| const index_t batch_stride = 3 * head_dim; |
| const float beta = req[0] == kAddTo ? 1.f : 0.f; |
| const float scale = 1.0 / sqrt(static_cast<float>(head_dim)); |
| |
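| // Within each interleaved block of size 3 * head_dim, queries start at offset 0 |
| // and keys at offset head_dim. This batched GEMM computes the scaled scores |
| // Q * K^T for every (batch, head) pair, giving (attn_batches, seq_len, seq_len). |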
| strided_batch_sgemm(true, |
| false, |
| qkv_seq_len, |
| qkv_seq_len, |
| head_dim, |
| scale, |
| queries_keys_values + head_dim, |
| lead_dim, |
| batch_stride, |
| queries_keys_values, |
| lead_dim, |
| batch_stride, |
| beta, |
| output, |
| qkv_seq_len, |
| qkv_seq_len * qkv_seq_len, |
| attn_batches); |
| } |
| |
| void BackwardInterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| if (req[0] == kNullOp) |
| return; |
| |
| mshadow::Stream<cpu>* s = ctx.get_stream<cpu>(); |
| CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) |
| << "Only FP32 is supported on CPU at the moment"; |
| |
| const float* output_grads = inputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const float* queries_keys_values = inputs[1].FlatTo2D<cpu, float>(s).dptr_; |
| float* queries_keys_values_grads = outputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const index_t qkv_seq_len = inputs[1].shape_[0]; |
| const index_t sequences = inputs[1].shape_[1]; |
| const index_t output_lin_dim = inputs[1].shape_[2]; |
| const index_t embed_dim = output_lin_dim / 3; |
| const index_t head_dim = embed_dim / params.heads; |
| const index_t attn_batches = params.heads * sequences; |
| const index_t lead_dim = attn_batches * 3 * head_dim; |
| const index_t batch_stride = 3 * head_dim; |
| const float scale = 1.0 / sqrt(static_cast<float>(head_dim)); |
| const float beta = req[0] == kAddTo ? 1.f : 0.f; |
| |
| if (req[0] == kWriteTo) { |
| memset(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(float)); |
| } |
| |
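| // output_grads holds the gradient of the attention scores. The first GEMM |
| // accumulates dQ = scale * K * dScores into the query slot of each interleaved |
| // block; the second accumulates dK = scale * Q * dScores^T into the key slot. |
| // The value slot is not written by either GEMM; it is cleared by the memset |
| // above when req[0] is kWriteTo. |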
| strided_batch_sgemm(false, |
| false, |
| head_dim, |
| qkv_seq_len, |
| qkv_seq_len, |
| scale, |
| queries_keys_values + head_dim, |
| lead_dim, |
| batch_stride, |
| output_grads, |
| qkv_seq_len, |
| qkv_seq_len * qkv_seq_len, |
| beta, |
| queries_keys_values_grads, |
| lead_dim, |
| batch_stride, |
| attn_batches); |
| |
| strided_batch_sgemm(false, |
| true, |
| head_dim, |
| qkv_seq_len, |
| qkv_seq_len, |
| scale, |
| queries_keys_values, |
| lead_dim, |
| batch_stride, |
| output_grads, |
| qkv_seq_len, |
| qkv_seq_len * qkv_seq_len, |
| beta, |
| queries_keys_values_grads + head_dim, |
| lead_dim, |
| batch_stride, |
| attn_batches); |
| } |
| |
| void InterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| if (req[0] == kNullOp) |
| return; |
| |
| CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) |
| << "Only FP32 is supported on CPU at the moment"; |
| |
| mshadow::Stream<cpu>* s = ctx.get_stream<cpu>(); |
| const float* queries_keys_values = inputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const float* attention_maps = inputs[1].FlatTo2D<cpu, float>(s).dptr_; |
| float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const index_t qkv_seq_len = inputs[0].shape_[0]; |
| const index_t sequences = inputs[0].shape_[1]; |
| const index_t output_lin_dim = inputs[0].shape_[2]; |
| const index_t embed_dim = output_lin_dim / 3; |
| const index_t head_dim = embed_dim / params.heads; |
| const index_t attn_batches = params.heads * sequences; |
| const index_t lead_dim = attn_batches * 3 * head_dim; |
| const index_t batch_stride = 3 * head_dim; |
| const float alpha = 1.f; |
| const float beta = req[0] == kAddTo ? 1.f : 0.f; |
| |
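| // Values occupy the last head_dim elements (offset 2 * head_dim) of each |
| // interleaved block. This multiplies the value projections with the attention |
| // maps to produce the per-head context vectors. |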
| strided_batch_sgemm(false, |
| false, |
| head_dim, |
| qkv_seq_len, |
| qkv_seq_len, |
| alpha, |
| queries_keys_values + 2 * head_dim, |
| lead_dim, |
| batch_stride, |
| attention_maps, |
| qkv_seq_len, |
| qkv_seq_len * qkv_seq_len, |
| beta, |
| output, |
| head_dim * attn_batches, |
| head_dim, |
| attn_batches); |
| } |
| |
| void BackwardInterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| if (req[0] == kNullOp) |
| return; |
| |
| CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) |
| << "Only FP32 is supported on CPU at the moment"; |
| |
| mshadow::Stream<cpu>* s = ctx.get_stream<cpu>(); |
| const float* output_grads = inputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const float* queries_keys_values = inputs[1].FlatTo2D<cpu, float>(s).dptr_; |
| const float* attention_maps = inputs[2].FlatTo2D<cpu, float>(s).dptr_; |
| float* queries_keys_values_grads = outputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| float* attention_maps_grads = outputs[1].FlatTo2D<cpu, float>(s).dptr_; |
| const index_t qkv_seq_len = inputs[1].shape_[0]; |
| const index_t sequences = inputs[1].shape_[1]; |
| const index_t output_lin_dim = inputs[1].shape_[2]; |
| const index_t embed_dim = output_lin_dim / 3; |
| const index_t head_dim = embed_dim / params.heads; |
| const index_t attn_batches = params.heads * sequences; |
| const index_t lead_dim = attn_batches * 3 * head_dim; |
| const index_t batch_stride = 3 * head_dim; |
| const float alpha = 1.f; |
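| // First branch: gradient w.r.t. the value projections, written into the value |
| // slot (offset 2 * head_dim) of each interleaved block of the gradient buffer. |
| // Second branch: gradient w.r.t. the attention maps. |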
| if (req[0] != kNullOp) { |
| if (req[0] == kWriteTo) { |
| memset(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(float)); |
| } |
| |
| const float beta = req[0] == kAddTo ? 1.f : 0.f; |
| strided_batch_sgemm(false, |
| true, |
| head_dim, |
| qkv_seq_len, |
| qkv_seq_len, |
| alpha, |
| output_grads, |
| head_dim * attn_batches, |
| head_dim, |
| attention_maps, |
| qkv_seq_len, |
| qkv_seq_len * qkv_seq_len, |
| beta, |
| queries_keys_values_grads + 2 * head_dim, |
| lead_dim, |
| batch_stride, |
| attn_batches); |
| } |
| if (req[1] != kNullOp) { |
| const float beta = req[1] == kAddTo ? 1.f : 0.f; |
| strided_batch_sgemm(true, |
| false, |
| qkv_seq_len, |
| qkv_seq_len, |
| head_dim, |
| alpha, |
| queries_keys_values + 2 * head_dim, |
| lead_dim, |
| batch_stride, |
| output_grads, |
| head_dim * attn_batches, |
| head_dim, |
| beta, |
| attention_maps_grads, |
| qkv_seq_len, |
| qkv_seq_len * qkv_seq_len, |
| attn_batches); |
| } |
| } |
| |
| void InterleavedMatMulEncDecQKCPU(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| if (req[0] == kNullOp) |
| return; |
| |
| CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) |
| << "Only FP32 is supported on CPU at the moment"; |
| |
| mshadow::Stream<cpu>* s = ctx.get_stream<cpu>(); |
| const float* queries = inputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const float* keys_values = inputs[1].FlatTo2D<cpu, float>(s).dptr_; |
| float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const index_t q_seq_len = inputs[0].shape_[0]; |
| const index_t sequences = inputs[0].shape_[1]; |
| const index_t output_lin_q_dim = inputs[0].shape_[2]; |
| const index_t kv_seq_len = inputs[1].shape_[0]; |
| const index_t embed_dim = output_lin_q_dim; |
| const index_t head_dim = embed_dim / params.heads; |
| const index_t attn_batches = params.heads * sequences; |
| const index_t lead_dim_q = attn_batches * head_dim; |
| const index_t lead_dim_kv = attn_batches * 2 * head_dim; |
| const index_t batch_stride_q = head_dim; |
| const index_t batch_stride_kv = head_dim * 2; |
| const float beta = req[0] == kAddTo ? 1.f : 0.f; |
| const float scale = 1.f / sqrt(static_cast<float>(head_dim)); |
| |
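| // Keys occupy the first head_dim elements of each interleaved [key|value] |
| // block. This computes the scaled scores Q * K^T, giving an output of shape |
| // (attn_batches, q_seq_len, kv_seq_len). |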
| strided_batch_sgemm(true, |
| false, |
| kv_seq_len, |
| q_seq_len, |
| head_dim, |
| scale, |
| keys_values, |
| lead_dim_kv, |
| batch_stride_kv, |
| queries, |
| lead_dim_q, |
| batch_stride_q, |
| beta, |
| output, |
| kv_seq_len, |
| kv_seq_len * q_seq_len, |
| attn_batches); |
| } |
| |
| void BackwardInterleavedMatMulEncDecQKCPU(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| if (req[0] == kNullOp) |
| return; |
| |
| CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) |
| << "Only FP32 is supported on CPU at the moment"; |
| |
| mshadow::Stream<cpu>* s = ctx.get_stream<cpu>(); |
| const float* output_grads = inputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const float* queries = inputs[1].FlatTo2D<cpu, float>(s).dptr_; |
| const float* keys_values = inputs[2].FlatTo2D<cpu, float>(s).dptr_; |
| float* queries_grads = outputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| float* keys_values_grads = outputs[1].FlatTo2D<cpu, float>(s).dptr_; |
| const index_t q_seq_len = inputs[1].shape_[0]; |
| const index_t sequences = inputs[1].shape_[1]; |
| const index_t output_lin_q_dim = inputs[1].shape_[2]; |
| const index_t kv_seq_len = inputs[2].shape_[0]; |
| const index_t embed_dim = output_lin_q_dim; |
| const index_t head_dim = embed_dim / params.heads; |
| const index_t attn_batches = params.heads * sequences; |
| const index_t lead_dim_q = attn_batches * head_dim; |
| const index_t lead_dim_kv = attn_batches * 2 * head_dim; |
| const index_t batch_stride_q = head_dim; |
| const index_t batch_stride_kv = head_dim * 2; |
| const float scale = 1.f / sqrt(static_cast<float>(head_dim)); |
| |
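| // output_grads holds the gradient of the attention scores. The first branch |
| // computes dQueries = scale * K * dScores; the second writes |
| // dKeys = scale * Q * dScores^T into the key slot of keys_values_grads. The |
| // value slot is not written; it is cleared by the memset when req[1] is kWriteTo. |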
| if (req[0] != kNullOp) { |
| const float beta = req[0] == kAddTo ? 1.f : 0.f; |
| strided_batch_sgemm(false, |
| false, |
| head_dim, |
| q_seq_len, |
| kv_seq_len, |
| scale, |
| keys_values, |
| lead_dim_kv, |
| batch_stride_kv, |
| output_grads, |
| kv_seq_len, |
| kv_seq_len * q_seq_len, |
| beta, |
| queries_grads, |
| lead_dim_q, |
| batch_stride_q, |
| attn_batches); |
| } |
| if (req[1] != kNullOp) { |
| if (req[1] == kWriteTo) { |
| memset(keys_values_grads, 0, outputs[1].shape_.Size() * sizeof(float)); |
| } |
| const float beta = req[1] == kAddTo ? 1.f : 0.f; |
| strided_batch_sgemm(false, |
| true, |
| head_dim, |
| kv_seq_len, |
| q_seq_len, |
| scale, |
| queries, |
| lead_dim_q, |
| batch_stride_q, |
| output_grads, |
| kv_seq_len, |
| kv_seq_len * q_seq_len, |
| beta, |
| keys_values_grads, |
| lead_dim_kv, |
| batch_stride_kv, |
| attn_batches); |
| } |
| } |
| |
| void InterleavedMatMulEncDecValAttCPU(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| if (req[0] == kNullOp) |
| return; |
| |
| CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) |
| << "Only FP32 is supported on CPU at the moment"; |
| |
| mshadow::Stream<cpu>* s = ctx.get_stream<cpu>(); |
| const float* keys_values = inputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const float* attention_maps = inputs[1].FlatTo2D<cpu, float>(s).dptr_; |
| float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const index_t kv_seq_len = inputs[0].shape_[0]; |
| const index_t output_lin_kv_dim = inputs[0].shape_[2]; |
| const index_t attn_batches = inputs[1].shape_[0]; |
| const index_t q_seq_len = inputs[1].shape_[1]; |
| const index_t embed_dim = output_lin_kv_dim / 2; |
| const index_t head_dim = embed_dim / params.heads; |
| const index_t lead_dim_kv = attn_batches * head_dim * 2; |
| const index_t batch_stride_kv = 2 * head_dim; |
| const float alpha = 1.f; |
| const float beta = req[0] == kAddTo ? 1.f : 0.f; |
| |
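| // Values occupy the second head_dim elements (offset head_dim) of each |
| // interleaved [key|value] block. This multiplies the value projections with |
| // the attention maps to produce the per-head context vectors. |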
| strided_batch_sgemm(false, |
| false, |
| head_dim, |
| q_seq_len, |
| kv_seq_len, |
| alpha, |
| keys_values + head_dim, |
| lead_dim_kv, |
| batch_stride_kv, |
| attention_maps, |
| kv_seq_len, |
| kv_seq_len * q_seq_len, |
| beta, |
| output, |
| head_dim * attn_batches, |
| head_dim, |
| attn_batches); |
| } |
| |
| void BackwardInterleavedMatMulEncDecValAttCPU(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); |
| CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) |
| << "Only FP32 is supported on CPU at the moment"; |
| |
| mshadow::Stream<cpu>* s = ctx.get_stream<cpu>(); |
| const float* output_grads = inputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| const float* keys_values = inputs[1].FlatTo2D<cpu, float>(s).dptr_; |
| const float* attention_maps = inputs[2].FlatTo2D<cpu, float>(s).dptr_; |
| float* keys_values_grads = outputs[0].FlatTo2D<cpu, float>(s).dptr_; |
| float* attention_maps_grads = outputs[1].FlatTo2D<cpu, float>(s).dptr_; |
| const index_t kv_seq_len = inputs[1].shape_[0]; |
| const index_t output_lin_kv_dim = inputs[1].shape_[2]; |
| const index_t attn_batches = inputs[2].shape_[0]; |
| const index_t q_seq_len = inputs[2].shape_[1]; |
| const index_t embed_dim = output_lin_kv_dim / 2; |
| const index_t head_dim = embed_dim / params.heads; |
| const index_t lead_dim_kv = attn_batches * head_dim * 2; |
| const index_t batch_stride_kv = 2 * head_dim; |
| const float alpha = 1.f; |
| |
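| // First branch: gradient w.r.t. the value projections, written into the value |
| // slot (offset head_dim) of keys_values_grads; the key slot is not written and |
| // is cleared by the memset when req[0] is kWriteTo. |
| // Second branch: gradient w.r.t. the attention maps. |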
| if (req[0] != kNullOp) { |
| if (req[0] == kWriteTo) { |
| memset(keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(float)); |
| } |
| const float beta = req[0] == kAddTo ? 1.f : 0.f; |
| strided_batch_sgemm(false, |
| true, |
| head_dim, |
| kv_seq_len, |
| q_seq_len, |
| alpha, |
| output_grads, |
| head_dim * attn_batches, |
| head_dim, |
| attention_maps, |
| kv_seq_len, |
| kv_seq_len * q_seq_len, |
| beta, |
| keys_values_grads + head_dim, |
| lead_dim_kv, |
| batch_stride_kv, |
| attn_batches); |
| } |
| if (req[1] != kNullOp) { |
| const float beta = req[1] == kAddTo ? 1.f : 0.f; |
| strided_batch_sgemm(true, |
| false, |
| kv_seq_len, |
| q_seq_len, |
| head_dim, |
| alpha, |
| keys_values + head_dim, |
| lead_dim_kv, |
| batch_stride_kv, |
| output_grads, |
| head_dim * attn_batches, |
| head_dim, |
| beta, |
| attention_maps_grads, |
| kv_seq_len, |
| kv_seq_len * q_seq_len, |
| attn_batches); |
| } |
| } |
| |
| NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_qk) |
| .add_alias("_npx_interleaved_matmul_selfatt_qk") |
| .describe(R"code(Compute the matrix multiplication between the projections of |
| queries and keys in multi-head attention, used as self-attention. |
| |
| The input must be a single tensor of interleaved projections |
| of queries, keys and values following the layout: |
| (seq_length, batch_size, num_heads * head_dim * 3) |
| |
| The equivalent code would be:: |
| |
| tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1)) |
| q_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3)) |
| q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True) |
| q_proj = mx.nd.contrib.div_sqrt_dim(q_proj) |
| k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3)) |
| k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True) |
| output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True) |
| |
| )code" ADD_FILELINE) |
| .set_num_inputs(1) |
| .set_num_outputs(1) |
| .set_attr_parser(ParamParser<InterleavedMatMulParam>) |
| .set_attr<nnvm::FListInputNames>("FListInputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"queries_keys_values"}; |
| }) |
| .set_attr<nnvm::FListOutputNames>("FListOutputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"output"}; |
| }) |
| .set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttQKShape) |
| .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) |
| .set_attr<FCompute>("FCompute<cpu>", InterleavedMatMulSelfAttQKCPU) |
| .set_attr<nnvm::FGradient>("FGradient", |
| ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_qk"}) |
| .add_argument("queries_keys_values", |
| "NDArray-or-Symbol", |
| "Interleaved queries, keys and values") |
| .add_arguments(InterleavedMatMulParam::__FIELDS__()); |
| |
| NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk) |
| .set_num_inputs(2) |
| .set_num_outputs(1) |
| .set_attr<nnvm::TIsBackward>("TIsBackward", true) |
| .set_attr_parser(ParamParser<InterleavedMatMulParam>) |
| .set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulSelfAttQKCPU); |
| |
| NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_valatt) |
| .add_alias("_npx_interleaved_matmul_selfatt_valatt") |
| .describe(R"code(Compute the matrix multiplication between the projections of |
| values and the attention weights in multi-head attention, used as self-attention. |
| |
| The inputs must be a tensor of interleaved projections |
| of queries, keys and values following the layout: |
| (seq_length, batch_size, num_heads * head_dim * 3) |
| |
| and the attention weights following the layout: |
| (batch_size, seq_length, seq_length) |
| |
| The equivalent code would be:: |
| |
| tmp = mx.nd.reshape(queries_keys_values, shape=(0, 0, num_heads, 3, -1)) |
| v_proj = mx.nd.transpose(tmp[:,:,:,2,:], axes=(1, 2, 0, 3)) |
| v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True) |
| output = mx.nd.batch_dot(attention, v_proj) |
| output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True) |
| output = mx.nd.transpose(output, axes=(2, 0, 1, 3)) |
| output = mx.nd.reshape(output, shape=(0, 0, -1)) |
| |
| )code" ADD_FILELINE) |
| .set_num_inputs(2) |
| .set_num_outputs(1) |
| .set_attr_parser(ParamParser<InterleavedMatMulParam>) |
| .set_attr<nnvm::FListInputNames>( |
| "FListInputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"queries_keys_values", "attention"}; |
| }) |
| .set_attr<nnvm::FListOutputNames>("FListOutputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"output"}; |
| }) |
| .set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttValAttShape) |
| .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) |
| .set_attr<FCompute>("FCompute<cpu>", InterleavedMatMulSelfAttValAttCPU) |
| .set_attr<nnvm::FGradient>("FGradient", |
| ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_valatt"}) |
| .add_argument("queries_keys_values", |
| "NDArray-or-Symbol", |
| "Queries, keys and values interleaved") |
| .add_argument("attention", "NDArray-or-Symbol", "Attention maps") |
| .add_arguments(InterleavedMatMulParam::__FIELDS__()); |
| |
| NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt) |
| .set_num_inputs(3) |
| .set_num_outputs(2) |
| .set_attr<nnvm::TIsBackward>("TIsBackward", true) |
| .set_attr_parser(ParamParser<InterleavedMatMulParam>) |
| .set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulSelfAttValAttCPU); |
| |
| NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_qk) |
| .add_alias("_npx_interleaved_matmul_encdec_qk") |
| .describe(R"code(Compute the matrix multiplication between the projections of |
| queries and keys in multi-head attention, used as encoder-decoder attention. |
| |
| The inputs must be a tensor of projections of queries following the layout: |
| (seq_length, batch_size, num_heads * head_dim) |
| |
| and a tensor of interleaved projections of keys and values following the layout: |
| (seq_length, batch_size, num_heads * head_dim * 2) |
| |
| The equivalent code would be:: |
| |
| q_proj = mx.nd.transpose(queries, axes=(1, 2, 0, 3)) |
| q_proj = mx.nd.reshape(q_proj, shape=(-1, 0, 0), reverse=True) |
| q_proj = mx.nd.contrib.div_sqrt_dim(q_proj) |
| tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1)) |
| k_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3)) |
| k_proj = mx.nd.reshape(k_proj, shape=(-1, 0, 0), reverse=True) |
| output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True) |
| |
| )code" ADD_FILELINE) |
| .set_num_inputs(2) |
| .set_num_outputs(1) |
| .set_attr_parser(ParamParser<InterleavedMatMulParam>) |
| .set_attr<nnvm::FListInputNames>("FListInputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"queries", "keys_values"}; |
| }) |
| .set_attr<nnvm::FListOutputNames>("FListOutputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"output"}; |
| }) |
| .set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulEncDecQKShape) |
| .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) |
| .set_attr<FCompute>("FCompute<cpu>", InterleavedMatMulEncDecQKCPU) |
| .set_attr<nnvm::FGradient>("FGradient", |
| ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_qk"}) |
| .add_argument("queries", "NDArray-or-Symbol", "Queries") |
| .add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved") |
| .add_arguments(InterleavedMatMulParam::__FIELDS__()); |
| |
| NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk) |
| .set_num_inputs(3) |
| .set_num_outputs(2) |
| .set_attr<nnvm::TIsBackward>("TIsBackward", true) |
| .set_attr_parser(ParamParser<InterleavedMatMulParam>) |
| .set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulEncDecQKCPU); |
| |
| NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_valatt) |
| .add_alias("_npx_interleaved_matmul_encdec_valatt") |
| .describe(R"code(Compute the matrix multiplication between the projections of |
| values and the attention weights in multi-head attention, used as encoder-decoder attention. |
| |
| The inputs must be a tensor of interleaved projections of |
| keys and values following the layout: |
| (seq_length, batch_size, num_heads * head_dim * 2) |
| |
| and the attention weights following the layout: |
| (batch_size, seq_length, seq_length) |
| |
| The equivalent code would be:: |
| |
| tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1)) |
| v_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3)) |
| v_proj = mx.nd.reshape(v_proj, shape=(-1, 0, 0), reverse=True) |
| output = mx.nd.batch_dot(attention, v_proj) |
| output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True) |
| output = mx.nd.transpose(output, axes=(2, 0, 1, 3)) |
| output = mx.nd.reshape(output, shape=(0, 0, -1)) |
| |
| )code" ADD_FILELINE) |
| .set_num_inputs(2) |
| .set_num_outputs(1) |
| .set_attr_parser(ParamParser<InterleavedMatMulParam>) |
| .set_attr<nnvm::FListInputNames>("FListInputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"keys_values", "attention"}; |
| }) |
| .set_attr<nnvm::FListOutputNames>("FListOutputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"output"}; |
| }) |
| .set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulEncDecValAttShape) |
| .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) |
| .set_attr<FCompute>("FCompute<cpu>", InterleavedMatMulEncDecValAttCPU) |
| .set_attr<nnvm::FGradient>("FGradient", |
| ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_valatt"}) |
| .add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved") |
| .add_argument("attention", "NDArray-or-Symbol", "Attention maps") |
| .add_arguments(InterleavedMatMulParam::__FIELDS__()); |
| |
| NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt) |
| .set_num_inputs(3) |
| .set_num_outputs(2) |
| .set_attr<nnvm::TIsBackward>("TIsBackward", true) |
| .set_attr_parser(ParamParser<InterleavedMatMulParam>) |
| .set_attr<FCompute>("FCompute<cpu>", BackwardInterleavedMatMulEncDecValAttCPU); |
| |
| // div_sqrt_dim |
| MXNET_OPERATOR_REGISTER_UNARY(_contrib_div_sqrt_dim) |
| .describe(R"code(Rescale the input by the square root of the channel dimension. |
| |
| out = data / sqrt(data.shape[-1]) |
| |
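| For example, with a last dimension of 16 the input is divided by 4 |
| (a small usage sketch, calling the operator through ``mx.nd.contrib`` |
| as in the examples above):: |
| |
| x = mx.nd.ones((2, 3, 16)) |
| y = mx.nd.contrib.div_sqrt_dim(x)  # every entry equals 1 / 4 = 0.25 |
| |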
| )code" ADD_FILELINE) |
| .set_attr<FCompute>("FCompute<cpu>", DivSqrtDimForward_<cpu>) |
| .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_contrib_div_sqrt_dim"}); |
| |
| DMLC_REGISTER_PARAMETER(SldWinAttenParam); |
| |
| NNVM_REGISTER_OP(_contrib_sldwin_atten_mask_like) |
| .add_alias("_npx_sldwin_atten_mask_like") |
| .describe(R"code(Compute the mask for the sliding window attention score, used in |
| Longformer (https://arxiv.org/pdf/2004.05150.pdf). |
| |
| In this attention pattern, |
| given a fixed window size *2w*, each token attends to *w* tokens on the left side |
| if we use causal attention (setting *symmetric* to *False*), |
| otherwise each token attends to *w* tokens on each side. |
| |
| The shapes of the inputs are: |
| - *score* : |
| |
| - (batch_size, seq_length, num_heads, w + w + 1) if symmetric is True, |
| - (batch_size, seq_length, num_heads, w + 1) otherwise. |
| |
| - *dilation* : (num_heads,) |
| - *valid_length* : (batch_size,) |
| |
| The shape of the output is: |
| - *mask* : same as the shape of *score* |
| |
| )code" ADD_FILELINE) |
| .set_num_inputs(3) |
| .set_num_outputs(1) |
| .set_attr_parser(ParamParser<SldWinAttenParam>) |
| .set_attr<nnvm::FListInputNames>( |
| "FListInputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"score", "dilation", "valid_length"}; |
| }) |
| .set_attr<nnvm::FListOutputNames>("FListOutputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"mask"}; |
| }) |
| .set_attr<mxnet::FInferShape>("FInferShape", |
| [](const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 3U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| const mxnet::TShape& dshape = (*in_attrs)[0]; |
| if (!shape_is_known(dshape)) |
| return false; |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape); |
| return true; |
| }) |
| .set_attr<nnvm::FInferType>("FInferType", |
| [](const nnvm::NodeAttrs& attrs, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 3U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); |
| return out_attrs->at(0) != -1; |
| }) |
| .set_attr<FCompute>("FCompute<cpu>", SldWinAttenMaskLikeForward<cpu>) |
| .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) |
| .add_argument("score", "NDArray-or-Symbol", "sliding window attention score") |
| .add_argument("dilation", "NDArray-or-Symbol", "dilation") |
| .add_argument("valid_length", "NDArray-or-Symbol", "valid length") |
| .add_arguments(SldWinAttenParam::__FIELDS__()); |
| |
| NNVM_REGISTER_OP(_contrib_sldwin_atten_score) |
| .add_alias("_npx_sldwin_atten_score") |
| .describe(R"code(Compute the sliding window attention score, which is used in |
| Longformer (https://arxiv.org/pdf/2004.05150.pdf). In this attention pattern, |
| given a fixed window size *2w*, each token attends to *w* tokens on the left side |
| if we use causal attention (setting *symmetric* to *False*), |
| otherwise each token attends to *w* tokens on each side. |
| |
| The shapes of the inputs are: |
| - *query* : (batch_size, seq_length, num_heads, num_head_units) |
| - *key* : (batch_size, seq_length, num_heads, num_head_units) |
| - *dilation* : (num_heads,) |
| |
| The shape of the output is: |
| - *score* : |
| |
| - (batch_size, seq_length, num_heads, w + w + 1) if symmetric is True, |
| - (batch_size, seq_length, num_heads, w + 1) otherwise. |
| |
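| A shape-level sketch (assuming the operator is exposed under ``mx.nd.contrib`` |
| like the other contrib operators in this file, and that float32 input is |
| acceptable for *dilation*):: |
| |
| batch_size, seq_length, num_heads, num_head_units, w = 2, 8, 4, 16, 3 |
| query = mx.nd.ones((batch_size, seq_length, num_heads, num_head_units)) |
| key = mx.nd.ones((batch_size, seq_length, num_heads, num_head_units)) |
| dilation = mx.nd.ones((num_heads,))  # per-head dilation rates; dtype assumed |
| score = mx.nd.contrib.sldwin_atten_score(query, key, dilation, w=w, symmetric=True) |
| # score.shape == (batch_size, seq_length, num_heads, w + w + 1) |
| |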
| )code" ADD_FILELINE) |
| .set_num_inputs(3) |
| .set_num_outputs(1) |
| .set_attr_parser(ParamParser<SldWinAttenParam>) |
| .set_attr<nnvm::FListInputNames>("FListInputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"query", "key", "dilation"}; |
| }) |
| .set_attr<nnvm::FListOutputNames>("FListOutputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"score"}; |
| }) |
| .set_attr<mxnet::FInferShape>("FInferShape", |
| [](const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* inshapes, |
| mxnet::ShapeVector* outshapes) { |
| unsigned int batch_size = inshapes->at(0)[0]; |
| unsigned int seq_length = inshapes->at(0)[1]; |
| unsigned int num_heads = inshapes->at(0)[2]; |
| unsigned int lhs_last_dim = inshapes->at(0)[3]; |
| unsigned int num_hidden = inshapes->at(1)[3]; |
| CHECK_EQ(lhs_last_dim, num_hidden); |
| CHECK_EQ(inshapes->at(2)[0], num_heads); |
| const SldWinAttenParam& param = |
| nnvm::get<SldWinAttenParam>(attrs.parsed); |
| unsigned int w_len = |
| param.symmetric ? (param.w + param.w + 1) : (param.w + 1); |
| outshapes->at(0) = |
| mshadow::Shape4(batch_size, seq_length, num_heads, w_len); |
| return true; |
| }) |
| .set_attr<nnvm::FInferType>("FInferType", |
| [](const nnvm::NodeAttrs& attrs, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 3U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); |
| return out_attrs->at(0) != -1; |
| }) |
| .set_attr<FCompute>("FCompute<cpu>", SldWinAttenScoreForward<cpu>) |
| .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_sldwin_atten_score"}) |
| .add_argument("query", "NDArray-or-Symbol", "query") |
| .add_argument("key", "NDArray-or-Symbol", "key") |
| .add_argument("dilation", "NDArray-or-Symbol", "dilation") |
| .add_arguments(SldWinAttenParam::__FIELDS__()); |
| |
| NNVM_REGISTER_OP(_backward_sldwin_atten_score) |
| .set_num_inputs(4) |
| .set_num_outputs(3) |
| .set_attr<nnvm::TIsBackward>("TIsBackward", true) |
| .set_attr_parser(ParamParser<SldWinAttenParam>) |
| .set_attr<FCompute>("FCompute<cpu>", SldWinAttenScoreBackward<cpu>); |
| |
| NNVM_REGISTER_OP(_contrib_sldwin_atten_context) |
| .add_alias("_npx_sldwin_atten_context") |
| .describe(R"code(Compute the context vector for sliding window attention, used in |
| Longformer (https://arxiv.org/pdf/2004.05150.pdf). |
| |
| In this attention pattern, |
| given a fixed window size *2w*, each token attends to *w* tokens on the left side |
| if we use causal attention (setting *symmetric* to *False*), |
| otherwise each token attends to *w* tokens on each side. |
| |
| The shapes of the inputs are: |
| - *score* : |
| |
| - (batch_size, seq_length, num_heads, w + w + 1) if symmetric is True, |
| - (batch_size, seq_length, num_heads, w + 1) otherwise |
| |
| - *value* : (batch_size, seq_length, num_heads, num_head_units) |
| - *dilation* : (num_heads,) |
| |
| The shape of the output is: |
| - *context_vec* : (batch_size, seq_length, num_heads, num_head_units) |
| |
| )code" ADD_FILELINE) |
| .set_num_inputs(3) |
| .set_num_outputs(1) |
| .set_attr_parser(ParamParser<SldWinAttenParam>) |
| .set_attr<nnvm::FListInputNames>( |
| "FListInputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"score", "value", "dilation"}; |
| }) |
| .set_attr<nnvm::FListOutputNames>("FListOutputNames", |
| [](const NodeAttrs& attrs) { |
| return std::vector<std::string>{"context_vec"}; |
| }) |
| .set_attr<mxnet::FInferShape>("FInferShape", |
| [](const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* inshapes, |
| mxnet::ShapeVector* outshapes) { |
| unsigned int batch_size = inshapes->at(0)[0]; |
| unsigned int seq_length = inshapes->at(0)[1]; |
| unsigned int num_heads = inshapes->at(0)[2]; |
| unsigned int lhs_last_dim = inshapes->at(0)[3]; |
| unsigned int num_hidden = inshapes->at(1)[3]; |
| CHECK_EQ(inshapes->at(2)[0], num_heads); |
| const SldWinAttenParam& param = |
| nnvm::get<SldWinAttenParam>(attrs.parsed); |
| unsigned int w_len = |
| param.symmetric ? (param.w + param.w + 1) : (param.w + 1); |
| CHECK_EQ(lhs_last_dim, w_len); |
| |
| outshapes->at(0) = mshadow::Shape4( |
| batch_size, seq_length, num_heads, num_hidden); |
| |
| return true; |
| }) |
| .set_attr<nnvm::FInferType>("FInferType", |
| [](const nnvm::NodeAttrs& attrs, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 3U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); |
| return out_attrs->at(0) != -1; |
| }) |
| .set_attr<FCompute>("FCompute<cpu>", SldWinAttenContextForward<cpu>) |
| .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_sldwin_atten_context"}) |
| .add_argument("score", "NDArray-or-Symbol", "score") |
| .add_argument("value", "NDArray-or-Symbol", "value") |
| .add_argument("dilation", "NDArray-or-Symbol", "dilation") |
| .add_arguments(SldWinAttenParam::__FIELDS__()); |
| |
| NNVM_REGISTER_OP(_backward_sldwin_atten_context) |
| .set_num_inputs(4) |
| .set_num_outputs(3) |
| .set_attr<nnvm::TIsBackward>("TIsBackward", true) |
| .set_attr_parser(ParamParser<SldWinAttenParam>) |
| .set_attr<FCompute>("FCompute<cpu>", SldWinAttenContextBackward<cpu>); |
| |
| } // namespace op |
| } // namespace mxnet |