| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file indexing_op.h |
| * \brief Function definition of indexing operator |
| * \author Bing Xu, Siyi Li, Chi Zhang, Haibin Lin |
| */ |
| #ifndef MXNET_OPERATOR_TENSOR_INDEXING_OP_H_ |
| #define MXNET_OPERATOR_TENSOR_INDEXING_OP_H_ |
| |
| #include <dmlc/logging.h> |
| #include <dmlc/parameter.h> |
| #include <mxnet/operator.h> |
| #include <mxnet/operator_util.h> |
| #include <map> |
| #include <vector> |
| #include <string> |
| #include <utility> |
| #include <algorithm> |
| #include <type_traits> |
| #include "../operator_common.h" |
| #include "../mshadow_op.h" |
| #include "../elemwise_op_common.h" |
| #include "./util/tensor_util-inl.h" |
| #include "../mxnet_op.h" |
| #include "./sort_op.h" |
| #include "./init_op.h" |
| #include "../../engine/openmp.h" |
| #include "../../common/utils.h" |
| #ifdef __CUDACC__ |
| #include "./indexing_op-inl.cuh" |
| #endif |
| |
| namespace mxnet { |
| namespace op { |
| |
| namespace embedding { |
| enum EmbeddingOpInputs { kData, kWeight }; |
| enum EmbeddingOpOutputs { kOut }; |
| enum EmbeddingOpResource { kTempSpace }; |
| } // namespace embedding |
| |
| namespace quantized_embedding { |
| enum QuantizedEmbeddingOpInputs { kData, kWeight, kWeightMin, kWeightMax }; |
| enum QuantizedEmbeddingOpOutputs { kOut, kOutMin, kOutMax }; |
| enum QuantizedEmbeddingOpResource { kTempSpace }; |
| } // namespace quantized_embedding |
| |
| struct EmbeddingParam : public dmlc::Parameter<EmbeddingParam> { |
| index_t input_dim; |
| index_t output_dim; |
| int dtype; |
| bool sparse_grad; |
| DMLC_DECLARE_PARAMETER(EmbeddingParam) { |
| DMLC_DECLARE_FIELD(input_dim).set_lower_bound(1).describe( |
| "Vocabulary size of the input indices."); |
| DMLC_DECLARE_FIELD(output_dim) |
| .set_lower_bound(1) |
| .describe("Dimension of the embedding vectors."); |
| DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32) |
| MXNET_ADD_ALL_TYPES.describe("Data type of weight."); |
| DMLC_DECLARE_FIELD(sparse_grad) |
| .set_default(false) |
| .describe( |
| "Compute row sparse gradient in the backward calculation. If set to True, " |
| "the grad's storage type is row_sparse."); |
| } |
| void SetAttrDict(std::unordered_map<std::string, std::string>* dict) { |
| std::ostringstream input_dim_s, output_dim_s, dtype_s, sparse_grad_s; |
| input_dim_s << input_dim; |
| output_dim_s << output_dim; |
| dtype_s << dtype; |
| sparse_grad_s << sparse_grad; |
| (*dict)["input_dim"] = input_dim_s.str(); |
| (*dict)["output_dim"] = output_dim_s.str(); |
| (*dict)["sparse_grad"] = sparse_grad_s.str(); |
| (*dict)["dtype"] = MXNetTypeWithBool2String(dtype); |
| } |
| }; |
| |
| /*! |
| * \brief CPU/GPU: Return the amount of temporary storage in bytes required by |
 *        AddTakeGradLargeBatch
 * \param num_keys number of keys
| */ |
| template <typename IndexType, typename xpu> |
| inline typename std::enable_if<std::is_same<xpu, cpu>::value, size_t>::type |
| AddTakeGradLargeBatchWorkspaceSize(size_t num_keys) { |
| return 0; |
| } |
| /*! |
| * \brief CPU/GPU: Return the amount of temporary storage in bytes required by |
 *        AddTakeGradLargeBatch
 * \param num_keys number of keys
| */ |
| template <typename IndexType, typename xpu> |
| inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type |
| AddTakeGradLargeBatchWorkspaceSize(size_t num_keys); |
| /*! |
 * \brief CPU/GPU: Gradient accumulation for the embedding matrix:
 *        dst[sorted[i]] += src[index[i]].
 *        Called when the batch size of src is larger than the feature dimension.
 * \param dst destination gradient
 * \param sorted the sorted indices
 * \param index original positions of the sorted indices
 * \param src source gradient
| * \param workspace (optional) temporary storage |
| */ |
| template <typename IndexType, typename DType> |
| inline void AddTakeGradLargeBatch(mshadow::Tensor<cpu, 2, DType> dst, |
| const mshadow::Tensor<cpu, 1, IndexType>& sorted, |
| const mshadow::Tensor<cpu, 1, IndexType>& index, |
| const mshadow::Tensor<cpu, 2, DType>& src, |
| mshadow::Tensor<cpu, 1, char>* workspace = nullptr) { |
| for (index_t y = 0; y < sorted.size(0); ++y) { |
| dst[sorted[y]] += src[index[y]]; |
| } |
| } |
| template <typename ParamType> |
| inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| using namespace mshadow; |
| const mxnet::TShape& dshape = (*in_attrs)[embedding::kData]; |
| if (!ndim_is_known(dshape)) |
| return false; |
| const ParamType& param = nnvm::get<ParamType>(attrs.parsed); |
| SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, Shape2(param.input_dim, param.output_dim)); |
| out_attrs->clear(); |
| |
| mxnet::TShape oshape(dshape.ndim() + 1, -1); |
| for (int i = 0; i < dshape.ndim(); ++i) { |
| oshape[i] = dshape[i]; |
| } |
| oshape[dshape.ndim()] = param.output_dim; |
| |
| out_attrs->push_back(oshape); |
| return shape_is_known(oshape); |
| } |
| |
| template <typename ParamType> |
| inline bool EmbeddingOpType(const nnvm::NodeAttrs& attrs, |
| std::vector<int>* in_type, |
| std::vector<int>* out_type) { |
| const ParamType& param = nnvm::get<ParamType>(attrs.parsed); |
| CHECK_EQ(in_type->size(), 2U); |
| CHECK_GE(out_type->size(), 1U); |
| int itype = (*in_type)[0]; |
| CHECK_NE(itype, -1) << "First input must have specified type"; |
| int dtype_in = (*in_type)[1]; |
| int dtype_out = (*out_type)[0]; |
| int dtype = param.dtype; |
| if (dtype_in != -1 && dtype_out != -1) { |
| // Both types defined, make sure they are the same |
| CHECK_EQ(dtype_in, dtype_out) << "Input and output weights must have same type"; |
| dtype = dtype_in; |
| } else if (dtype_in != -1 || dtype_out != -1) { |
| // One of the types defined, choose the one that was defined |
| dtype = (dtype_in != -1) ? dtype_in : dtype_out; |
| } |
| if ((*in_type)[1] == -1) |
| (*in_type)[1] = dtype; |
| out_type->clear(); |
| out_type->push_back(dtype); |
| return true; |
| } |
| |
| // storage type inference function for _backward_Embedding |
| inline bool EmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs, |
| const int dev_mask, |
| DispatchMode* dispatch_mode, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 2U); |
| CHECK_EQ(out_attrs->size(), 2U); |
| const bool sparse_grad = nnvm::get<EmbeddingParam>(attrs.parsed).sparse_grad; |
| const NDArrayStorageType target_stype = sparse_grad ? kRowSparseStorage : kDefaultStorage; |
| const auto target_mode = sparse_grad ? DispatchMode::kFComputeEx : DispatchMode::kFCompute; |
| |
| const int ograd_stype = in_attrs->at(0); |
| const int data_stype = in_attrs->at(1); |
| int& data_grad_stype = out_attrs->at(0); |
| int& weight_grad_stype = out_attrs->at(1); |
| bool dispatched = false; |
| if (!dispatched && ograd_stype == kDefaultStorage && data_stype == kDefaultStorage) { |
| // dns, dns -> dns, dns/rsp |
| if (type_assign(&data_grad_stype, kDefaultStorage) && |
| type_assign(&weight_grad_stype, target_stype)) { |
| dispatched = dispatch_mode_assign(dispatch_mode, target_mode); |
| } |
| } |
  // Print a user-friendly error message for misuse of sparse_grad
| if (weight_grad_stype != target_stype) { |
| LOG(FATAL) << "Cannot use sparse_grad = " << sparse_grad |
| << ", while stype of gradients w.r.t embedding weight is " |
| << common::stype_string(weight_grad_stype); |
| } |
| return dispatched; |
| } |
| |
/*! \brief TakeNonzeroAxis handles the general take case when the
 * axis is not zero (for the CPU-optimized version use TakeNonZeroAxisCPU;
 * for axis zero use TakeZeroAxisGPU or TakeZeroAxisCPU)
 */
| template <bool clip = true> |
| struct TakeNonzeroAxis { |
| /*! |
| * \brief Map function for take operator |
| * \param i global thread id |
| * \param out_data ptr to output buffer |
| * \param in_data ptr to input buffer |
   * \param idx ptr to indices buffer
   * \param out_prev_stride stride of output dimension (axis - 1)
   * \param in_prev_stride stride of input dimension (axis - 1)
   * \param in_stride stride of the axis dimension in the input
   * \param in_ndims # of dims of input tensor
| * \param out_ndims # of dims of output tensor |
| * \param idx_ndims # of dims of indices tensor |
| * \param axis_dim dim size of the axis dimension |
| * \param axis axis id |
| */ |
| template <typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(index_t i, |
| DType* out_data, |
| const DType* in_data, |
| const IType* idx, |
| const index_t out_prev_stride, |
| const index_t in_prev_stride, |
| const index_t in_stride, |
| const int in_ndims, |
| const int out_ndims, |
| const int idx_ndims, |
| const int axis_dim, |
| const int axis) { |
| // i is the global flattened index in the output |
| const index_t out_head_index = i / out_prev_stride; |
| const index_t out_rest_index = i % out_prev_stride; |
| const index_t out_mid_index = out_rest_index / in_stride; |
| const index_t out_tail_index = (axis == in_ndims - 1) ? 0 : (out_rest_index % in_stride); |
| index_t idx_index = static_cast<index_t>(idx[out_mid_index]); |
| if (clip) { |
| idx_index = (idx_index < 0) ? 0 : idx_index; |
| idx_index = (idx_index > axis_dim - 1) ? (axis_dim - 1) : idx_index; |
| } else { |
| idx_index %= axis_dim; |
| idx_index += (idx_index < 0) ? axis_dim : 0; |
| } |
| const index_t in_tail_index = out_tail_index; |
| const index_t in_head_index = out_head_index; |
| index_t in_src_index = in_tail_index + idx_index * in_stride; |
| in_src_index += in_head_index * in_prev_stride; |
| out_data[i] = in_data[in_src_index]; |
| } |
| }; |
| |
| // Embedding forward implementation with dense weight |
| template <typename xpu> |
| void EmbeddingOpForwardDnsImpl(mshadow::Stream<xpu>* s, |
| const TBlob& data, |
| const TBlob& weight, |
| const OpReqType req, |
| const TBlob& output); |
| |
| template <int req> |
| struct TakeRspKernel { |
| /*! |
   * \brief Map function for take with a row_sparse weight
| * \param i thread id |
| * \param data input data |
| * \param out output |
| * \param weight_idx indices of rsp weight |
| * \param weight_data data of rsp weight |
| * \param row_length number of elements per row |
| * \param nnr number of non-zero rows |
| */ |
| template <typename DType, typename IType, typename RType> |
| MSHADOW_XINLINE static void Map(index_t i, |
| const IType* data, |
| DType* out, |
| const RType* weight_idx, |
| const DType* weight_data, |
| const nnvm::dim_t row_length, |
| const nnvm::dim_t nnr) { |
| using nnvm::dim_t; |
| const dim_t val = static_cast<dim_t>(data[i]); |
| const DType zero = 0; |
| // Use binary search to find the lower_bound of val in weight_idx array |
| // (adapted based on the binary search in dot kernel) |
| const RType* first = weight_idx; |
| const RType* last = weight_idx + nnr; |
| const RType* it; |
| dim_t count = last - first, step; |
| while (count > 0) { |
| it = first; |
| step = count / 2; |
| it += step; |
| if (*it < val) { |
| first = ++it; |
| count -= step + 1; |
| } else { |
| count = step; |
| } |
| } |
| // end of binary search |
| const dim_t idx_offset = first - weight_idx; |
| const dim_t out_offset = i * row_length; |
| const dim_t weight_offset = idx_offset * row_length; |
| // target idx might be missing in weight.idx. For example, |
| // weight.idx = [5,10] and data = [3,7], so binary search fails to |
| // find any matching indices in weight_idx. |
| if (idx_offset >= nnr || *(weight_idx + idx_offset) > val) { |
| // val not found, fill zeros |
| for (int j = 0; j < row_length; j++) { |
| KERNEL_ASSIGN(out[out_offset + j], req, zero); |
| } |
| } else { |
| for (int j = 0; j < row_length; j++) { |
| KERNEL_ASSIGN(out[out_offset + j], req, weight_data[weight_offset + j]); |
| } |
| } |
| } |
| }; |
| |
| template <typename xpu> |
| inline void EmbeddingOpForwardRspImpl(mshadow::Stream<xpu>* s, |
| const TBlob& data, |
| const NDArray& weight, |
| const OpReqType req, |
| const TBlob& output) { |
| using namespace mxnet_op; |
| using namespace rowsparse; |
| MSHADOW_TYPE_SWITCH(output.type_flag_, DType, { |
| MSHADOW_TYPE_SWITCH(data.type_flag_, IType, { |
| MSHADOW_TYPE_SWITCH(weight.aux_type(kIdx), RType, { |
| MXNET_ASSIGN_REQ_SWITCH(req, req_t, { |
| size_t data_size = data.shape_.Size(); |
| // only using the second dim since weight.ndim() == 2 |
| const nnvm::dim_t row_length = weight.shape()[1]; |
| Kernel<TakeRspKernel<req_t>, xpu>::Launch(s, |
| data_size, |
| data.dptr<IType>(), |
| output.dptr<DType>(), |
| weight.aux_data(kIdx).dptr<RType>(), |
| weight.data().dptr<DType>(), |
| row_length, |
| weight.aux_shape(kIdx)[0]); |
| }); |
| }); |
| }); |
| }); |
| } |
| |
| // Embedding forward implementation with row_sparse weight |
| template <typename xpu> |
| void SparseEmbeddingOpForwardRspImpl(const OpContext& ctx, |
| const TBlob& data, |
| const NDArray& weight, |
| const OpReqType req, |
| const TBlob& output); |
| |
| template <typename xpu> |
| void EmbeddingOpForward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| CHECK_EQ(req[embedding::kOut], kWriteTo); |
| CHECK_EQ(inputs.size(), 2U); |
| CHECK_EQ(outputs.size(), 1U); |
| CHECK_EQ(inputs[embedding::kWeight].ndim(), 2U) |
| << "Embedding layer expects its weight to be two-dimensional. " |
| << inputs[embedding::kWeight].ndim() << " dimensional input is given instead"; |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| EmbeddingOpForwardDnsImpl<xpu>(s, |
| inputs[embedding::kData], |
| inputs[embedding::kWeight], |
| req[embedding::kOut], |
| outputs[embedding::kOut]); |
| } |
| |
| /*! \brief cast to type and clip to range [0, K - 1] |
| */ |
| struct tcast_clip { |
| template <typename OType, typename IType> |
| MSHADOW_XINLINE static void Map(int i, OType* out_data, const IType* in_data, const OType K) { |
| OType j = static_cast<OType>(in_data[i]); |
| if (j <= 0) |
| j = 0; |
| else if (j >= K) |
| j = K - 1; |
| out_data[i] = j; |
| } |
| }; |
| |
| template <typename xpu> |
| void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| CHECK_EQ(inputs.size(), 2U); |
| CHECK_EQ(outputs.size(), 2U); |
| CHECK_EQ(req[embedding::kData], kNullOp) |
| << "Embedding layer doesn't support calculate data gradient"; |
| CHECK_EQ(outputs[1].type_flag_, inputs[0].type_flag_); |
| |
| const mxnet::TShape& ishape = inputs[1].shape_; |
| const mxnet::TShape& oshape = inputs[0].shape_; |
| |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| |
| bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); |
| if (!safe_acc && outputs[1].type_flag_ == mshadow::kFloat16) { |
| common::LogOnce( |
| "MXNET_SAFE_ACCUMULATION=1 is recommended for EmbeddingOpBackward " |
| "with float16 inputs. " |
| "See https://mxnet.apache.org/api/faq/env_var " |
| "for more details."); |
| } |
| MXNET_REAL_ACC_TYPE_SWITCH(outputs[1].type_flag_, DType, AType, { |
| MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { |
| Tensor<xpu, 1, IType> data = |
| inputs[1].get_with_shape<xpu, 1, IType>(Shape1(ishape.ProdShape(0, ishape.ndim())), s); |
| Tensor<xpu, 2, DType> grad_out = inputs[0].get_with_shape<xpu, 2, DType>( |
| Shape2(oshape.ProdShape(0, oshape.ndim() - 1), oshape[oshape.ndim() - 1]), s); |
| Tensor<xpu, 2, DType> grad_in = outputs[1].get<xpu, 2, DType>(s); |
| |
| if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) { |
| if (req[embedding::kWeight] == kWriteTo) { |
| grad_in = scalar<DType>(0.0f); |
| } |
| if (safe_acc) { |
| // Temporary storage for safe accumulation |
| size_t temp_space_size = grad_in.size(0) * grad_in.size(1) * sizeof(AType); |
| Tensor<xpu, 1, char> temp_space = |
| ctx.requested[embedding::kTempSpace].get_space_typed<xpu, 1, char>( |
| Shape1(temp_space_size), s); |
| Tensor<xpu, 2, AType> temp_grad_in( |
| reinterpret_cast<AType*>(temp_space.dptr_), grad_in.shape_, s); |
| AddTakeGrad(grad_in, temp_grad_in, data, grad_out); |
| } else { |
| AddTakeGrad(grad_in, data, grad_out); |
| } |
| } else { |
| LOG(FATAL) << "wrong req"; |
| } |
| }); |
| }); |
| } |
| |
| struct AddTakeGradRspKernel { |
| /*! |
   * \brief Each thread is responsible for the row slices in [segment_start, segment_end)
   *        of the result gradient
| * \param tid global thread id |
| * \param grad the gradient to calculate |
| * \param prefix_sum the inclusive prefix sum of row ids of the gradient |
| * \param ograd output gradient |
| * \param row_length the length of the row slices of the gradient |
| * \param data_val the values of input data |
| * \param data_size number of values of input data |
| * \param segment_length the length of row segment to process for each thread |
| * \param nnr total number of non-zero rows of result gradient |
| */ |
| template <typename DType, typename IType> |
| MSHADOW_CINLINE static void Map(int tid, |
| DType* grad, |
| const nnvm::dim_t* prefix_sum, |
| const DType* ograd, |
| const nnvm::dim_t row_length, |
| const IType* data_val, |
| const nnvm::dim_t data_size, |
| const nnvm::dim_t segment_length, |
| const nnvm::dim_t nnr) { |
| using nnvm::dim_t; |
| dim_t segment_start = tid * segment_length; |
| dim_t segment_end = std::min(nnr, segment_start + segment_length); |
| // scan all data |
| for (dim_t data_i = 0; data_i < data_size; data_i++) { |
| dim_t data = static_cast<dim_t>(data_val[data_i]); |
| dim_t grad_row_id = prefix_sum[data] - 1; |
| if (grad_row_id < segment_start || grad_row_id >= segment_end) |
| continue; |
| // no projection is performed |
| dim_t ograd_i = data_i * row_length; |
| dim_t grad_i = grad_row_id * row_length; |
| for (dim_t offset = 0; offset < row_length; offset++) { |
| grad[grad_i + offset] += ograd[ograd_i + offset]; |
| } |
| } |
| } |
| }; |
| |
| template <typename xpu> |
| inline void SparseEmbeddingOpBackwardRspImpl(const bool deterministic, |
| const OpContext& ctx, |
| const TBlob& ograd, |
| const TBlob& data, |
| const OpReqType req, |
| const NDArray& output); |
| |
| template <typename xpu> |
| void EmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<NDArray>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<NDArray>& outputs) { |
| CHECK_EQ(inputs.size(), 2U); |
| CHECK_EQ(outputs.size(), 2U); |
| const NDArray& weight_grad = outputs[1]; |
| const NDArray& ograd = inputs[0]; |
| const NDArray& data = inputs[1]; |
| // check dtype |
| CHECK_EQ(weight_grad.dtype(), ograd.dtype()); |
| // check req |
| CHECK_EQ(req[embedding::kData], kNullOp) |
| << "Embedding layer doesn't support calculate data gradient"; |
| if (data.storage_type() == kDefaultStorage && ograd.storage_type() == kDefaultStorage && |
| weight_grad.storage_type() == kRowSparseStorage) { |
| SparseEmbeddingOpBackwardRspImpl<xpu>( |
| true, ctx, ograd.data(), data.data(), req[embedding::kWeight], weight_grad); |
| } else { |
| LogUnimplementedOp(attrs, ctx, inputs, req, outputs); |
| } |
| } |
| |
| namespace take_ { // to avoid name conflict |
| enum TakeOpInputs { kArr, kIdx }; |
| enum TakeOpOutputs { kOut }; |
| enum TakeOpResource { kTempSpace }; |
| enum TakeOpMode { kRaise, kWrap, kClip }; |
| } // namespace take_ |
| |
| // TODO(somebody): behaviors specified by params |
| struct TakeParam : public dmlc::Parameter<TakeParam> { |
| int axis; |
| int mode; |
| DMLC_DECLARE_PARAMETER(TakeParam) { |
| DMLC_DECLARE_FIELD(axis).set_default(0).describe( |
| "The axis of input array to be taken." |
| "For input tensor of rank r, it could be in the range of [-r, r-1]"); |
| DMLC_DECLARE_FIELD(mode) |
| .add_enum("raise", take_::kRaise) |
| .add_enum("wrap", take_::kWrap) |
| .add_enum("clip", take_::kClip) |
| .set_default(take_::kClip) |
| .describe( |
| "Specify how out-of-bound indices bahave. Default is \"clip\"." |
| " \"clip\" means clip to the range. So, if all indices mentioned are too large," |
| " they are replaced by the index that addresses the last element along an axis." |
| " \"wrap\" means to wrap around." |
| " \"raise\" means to raise an error when index out of range."); |
| } |
| |
| void SetAttrDict(std::unordered_map<std::string, std::string>* dict) { |
| std::ostringstream axis_s, mode_s; |
| axis_s << axis; |
| mode_s << mode; |
| (*dict)["axis"] = axis_s.str(); |
| (*dict)["mode"] = mode_s.str(); |
| switch (mode) { |
| case take_::kRaise: |
| (*dict)["mode"] = "raise"; |
| break; |
| case take_::kClip: |
| (*dict)["mode"] = "clip"; |
| break; |
| case take_::kWrap: |
| (*dict)["mode"] = "wrap"; |
| break; |
| default: |
| (*dict)["mode"] = mode_s.str(); |
| } |
| } |
| }; |
| |
| inline bool TakeOpShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| using namespace mshadow; |
| const mxnet::TShape& arrshape = (*in_attrs)[take_::kArr]; |
| const mxnet::TShape& idxshape = (*in_attrs)[take_::kIdx]; |
| if (!shape_is_known(idxshape)) |
| return false; |
| const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed); |
| CHECK(param.axis >= -1 * arrshape.ndim() && param.axis < arrshape.ndim()) |
| << "Axis should be in the range of [-r, r-1] where r is the rank of input tensor"; |
| |
| out_attrs->clear(); |
| |
| const index_t actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); |
| mxnet::TShape oshape(idxshape.ndim() + arrshape.ndim() - 1, -1); |
| for (index_t i = 0; i < idxshape.ndim(); ++i) { |
| oshape[i + actual_axis] = idxshape[i]; |
| } |
| for (index_t i = 0; i < arrshape.ndim(); i++) { |
| if (i < actual_axis) { |
| oshape[i] = arrshape[i]; |
| } else if (i > actual_axis) { |
| oshape[i + idxshape.ndim() - 1] = arrshape[i]; |
| } |
| } |
| out_attrs->push_back(oshape); |
| return shape_is_known(oshape); |
| } |
| |
| inline bool TakeOpType(const nnvm::NodeAttrs& attrs, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 2U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| CHECK_NE((*in_attrs)[1], -1) << "Index type must be set for take operator"; |
| |
| TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); |
| TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); |
| return (*in_attrs)[0] != -1; |
| } |
| |
| // storage type inference function for take |
| inline bool TakeOpForwardStorageType(const nnvm::NodeAttrs& attrs, |
| const int dev_mask, |
| DispatchMode* dispatch_mode, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 2U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| const int& idx_stype = in_attrs->at(take_::kIdx); |
| const int& arr_stype = in_attrs->at(take_::kArr); |
| int& out_stype = out_attrs->at(take_::kOut); |
| bool dispatched = false; |
| const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed); |
| if (!dispatched && idx_stype == kDefaultStorage && arr_stype == kDefaultStorage) { |
| // dns, dns -> dns |
| dispatched = |
| storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); |
| } |
| if (!dispatched && idx_stype == kDefaultStorage && arr_stype == kCSRStorage && param.axis == 0 && |
| (param.mode == take_::kWrap || param.mode == take_::kClip)) { |
| // take(dns, csr, axis=0) -> csr |
| dispatched = |
| storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, DispatchMode::kFComputeEx); |
| } |
| if (!dispatched) { |
| dispatched = dispatch_fallback(out_attrs, dispatch_mode); |
| } |
| return dispatched; |
| } |
| |
| template <typename xpu> |
| void TakeOpForwardCsrImpl(const TakeParam& params, |
| const OpContext& ctx, |
| const TBlob& idx, |
| const NDArray& arr, |
| OpReqType req, |
| const NDArray& output); |
| |
| template <typename xpu> |
| void TakeOpForwardEx(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<NDArray>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<NDArray>& outputs) { |
| CHECK_EQ(req[take_::kOut], kWriteTo); |
| CHECK_EQ(inputs.size(), 2U); |
| CHECK_EQ(outputs.size(), 1U); |
| const NDArray& idx = inputs[take_::kIdx]; |
| const NDArray& arr = inputs[take_::kArr]; |
| const NDArray& out = outputs[take_::kOut]; |
| const auto idx_stype = idx.storage_type(); |
| const auto arr_stype = arr.storage_type(); |
| const auto out_stype = out.storage_type(); |
| const auto params = nnvm::get<TakeParam>(attrs.parsed); |
| if (idx_stype == kDefaultStorage && arr_stype == kCSRStorage && out_stype == kCSRStorage) { |
| // dns, csr -> csr |
| TakeOpForwardCsrImpl<xpu>(params, ctx, idx.data(), arr, req[0], out); |
| } else { |
| LogUnimplementedOp(attrs, ctx, inputs, req, outputs); |
| } |
| } |
| |
| template <typename xpu> |
| void TakeOpForward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs); |
| |
| struct TakeGradGeneralKernel { |
| /*! |
| * \brief Map function for general case of take grad |
| * \param tid global thread id |
| * \param arr_grad ptr to in_grad |
| * \param ograd ptr to out_grad |
   * \param src_indptr ptr to the CSR-style index pointer over the sorted indices
   * \param original_idx ptr to the original positions of the sorted indices
   * \param in_strides strides of the input tensor
   * \param out_strides strides of the output tensor
   * \param in_ndims # of dims of input tensor
   * \param out_ndims # of dims of output tensor
   * \param idx_ndims # of dims of indices tensor
   * \param axis axis id
   * \param K dim size of the axis dimension
| */ |
| template <typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(int tid, |
| DType* arr_grad, |
| const DType* ograd, |
| const IType* src_indptr, |
| const IType* original_idx, |
| mshadow::Shape<10> in_strides, |
| mshadow::Shape<10> out_strides, |
| const int in_ndims, |
| const int out_ndims, |
| const int idx_ndims, |
| const int axis, |
| const int K) { |
| const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1]; |
| const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1]; |
| const int in_mid_index = in_rest_index / in_strides[axis]; |
| const int in_tail_index = (axis == in_ndims - 1) ? 0 : (in_rest_index % in_strides[axis]); |
| for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) { |
| int out_mid_index = original_idx[i]; |
| out_mid_index = (out_mid_index < 0) ? out_mid_index + K : out_mid_index; |
| int target = in_tail_index + out_mid_index * in_strides[axis]; |
| target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1]; |
| arr_grad[tid] += ograd[target]; |
| } |
| } |
| }; |
| |
| struct TakeGradGeneralKernelSafeAccumulation { |
| /*! |
   * \brief Map function for the general case of take grad with safe accumulation
   * \param tid global thread id
   * \param arr_grad ptr to in_grad
   * \param temp ptr to temporary space used for accumulation
   * \param ograd ptr to out_grad
   * \param src_indptr ptr to the CSR-style index pointer over the sorted indices
   * \param original_idx ptr to the original positions of the sorted indices
   * \param in_strides strides of the input tensor
   * \param out_strides strides of the output tensor
   * \param in_ndims # of dims of input tensor
   * \param out_ndims # of dims of output tensor
   * \param idx_ndims # of dims of indices tensor
   * \param axis axis id
   * \param K dim size of the axis dimension
| */ |
| template <typename DType, typename IType, typename AType> |
| MSHADOW_XINLINE static void Map(int tid, |
| DType* arr_grad, |
| AType* temp, |
| const DType* ograd, |
| const IType* src_indptr, |
| const IType* original_idx, |
| mshadow::Shape<10> in_strides, |
| mshadow::Shape<10> out_strides, |
| const int in_ndims, |
| const int out_ndims, |
| const int idx_ndims, |
| const int axis, |
| const int K) { |
| const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1]; |
| const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1]; |
| const int in_mid_index = in_rest_index / in_strides[axis]; |
| const int in_tail_index = (axis == in_ndims - 1) ? 0 : (in_rest_index % in_strides[axis]); |
| temp[tid] = static_cast<AType>(arr_grad[tid]); |
| for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) { |
| int out_mid_index = original_idx[i]; |
| out_mid_index = (out_mid_index < 0) ? out_mid_index + K : out_mid_index; |
| int target = in_tail_index + out_mid_index * in_strides[axis]; |
| target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1]; |
| temp[tid] += ograd[target]; |
| } |
| arr_grad[tid] = temp[tid]; |
| } |
| }; |
| |
| template <bool clip = true, bool safe_acc = false, typename AType> |
| void TakeOpBackwardImpl(mshadow::Stream<cpu>* s, |
| const OpContext& ctx, |
| const TBlob& arr, |
| const TBlob& idx, |
| const TBlob& ograd, |
| const int axis) { |
| using namespace mxnet_op; |
| using namespace mshadow; |
| CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; |
| const mxnet::TShape& arrshape = arr.shape_; |
| const mxnet::TShape& idxshape = idx.shape_; |
| const mxnet::TShape& oshape = ograd.shape_; |
| MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, { |
| // get size of temporary storage for sort |
| int* src_indptr_ptr = nullptr; |
| size_t temp_storage_bytes = SortByKeyWorkspaceSize<int, int, cpu>(idxshape.Size()); |
| size_t original_idx_bytes = idxshape.Size() * sizeof(int); |
| size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int); |
| size_t temp_accumulation_arrgrad_bytes = 0; |
| if (safe_acc) { |
| temp_accumulation_arrgrad_bytes = arr.Size() * sizeof(AType); |
| } |
| size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes + |
| temp_accumulation_arrgrad_bytes; |
| Tensor<cpu, 1, char> workspace = |
| ctx.requested[0].get_space_typed<cpu, 1, char>(Shape1(workspace_bytes), s); |
| AType* temp_accum_arrgrad_ptr = reinterpret_cast<AType*>(workspace.dptr_); |
| int* sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + temp_accumulation_arrgrad_bytes); |
| int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + original_idx_bytes + |
| temp_accumulation_arrgrad_bytes); |
| src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * original_idx_bytes + |
| temp_accumulation_arrgrad_bytes); |
| Tensor<cpu, 1, char> temp_storage(workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes + |
| temp_accumulation_arrgrad_bytes, |
| Shape1(temp_storage_bytes), |
| s); |
| // Reset indptr to zero |
| Kernel<set_zero, cpu>::Launch(s, arrshape[axis] + 1, src_indptr_ptr); |
| // Fill original_idx |
| Kernel<range_fwd, cpu>::Launch(s, idxshape.Size(), 1, 0, 1, kWriteTo, original_idx_ptr); |
| // Fill sorted_idx_ptr with unsorted copy of idx |
| Kernel<mshadow_op::identity_with_cast, cpu>::Launch( |
| s, idxshape.Size(), sorted_idx_ptr, idx.dptr<IType>()); |
| if (clip) { |
| Kernel<op_with_req<mshadow_op::clip, kWriteTo>, cpu>::Launch( |
| s, |
| idxshape.Size(), |
| sorted_idx_ptr, |
| sorted_idx_ptr, |
| 0, |
| static_cast<int>(arrshape[axis] - 1)); |
| } else { |
| Kernel<op_with_req<mshadow_op::mod, kWriteTo>, cpu>::Launch( |
| s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast<int>(arrshape[axis])); |
| } |
| Tensor<cpu, 1, int> original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); |
| int num_bits = common::ilog2ui(static_cast<unsigned int>(idxshape.Size()) - 1); |
| Tensor<cpu, 1, int> sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); |
| SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits); |
| for (size_t i = 0; i < idxshape.Size(); ++i) { |
| src_indptr_ptr[sorted_idx_ptr[i] + 1] += 1; |
| } |
| for (int i = 0; i < arrshape[axis]; ++i) { |
| src_indptr_ptr[i + 1] += src_indptr_ptr[i]; |
| } |
| Shape<10> in_strides; |
| int stride = 1; |
| for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { |
| in_strides[i] = stride; |
| } |
| Shape<10> out_strides; |
| stride = 1; |
| for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { |
| out_strides[i] = stride; |
| } |
| MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { |
| if (safe_acc) { |
| Kernel<TakeGradGeneralKernelSafeAccumulation, cpu>::Launch( |
| s, |
| arrshape.Size(), |
| arr.dptr<DType>(), |
| temp_accum_arrgrad_ptr, |
| ograd.dptr<DType>(), |
| src_indptr_ptr, |
| original_idx_ptr, |
| in_strides, |
| out_strides, |
| arrshape.ndim(), |
| oshape.ndim(), |
| idxshape.ndim(), |
| axis, |
| static_cast<int>(arrshape[axis])); |
| } else { |
| Kernel<TakeGradGeneralKernel, cpu>::Launch(s, |
| arrshape.Size(), |
| arr.dptr<DType>(), |
| ograd.dptr<DType>(), |
| src_indptr_ptr, |
| original_idx_ptr, |
| in_strides, |
| out_strides, |
| arrshape.ndim(), |
| oshape.ndim(), |
| idxshape.ndim(), |
| axis, |
| static_cast<int>(arrshape[axis])); |
| } |
| }); |
| }); |
| } |
| |
| #ifdef __CUDACC__ |
| template <bool clip = true, bool safe_acc = false, typename AType> |
| void TakeOpBackwardImpl(mshadow::Stream<gpu>* s, |
| const OpContext& ctx, |
| const TBlob& arr, |
| const TBlob& idx, |
| const TBlob& ograd, |
| const int axis) { |
| using namespace mxnet_op; |
| using namespace mshadow; |
| CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; |
| const mxnet::TShape& arrshape = arr.shape_; |
| const mxnet::TShape& idxshape = idx.shape_; |
| const mxnet::TShape& oshape = ograd.shape_; |
| MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, { |
| // get size of temporary storage for sort |
| char* temp_storage_ptr = nullptr; |
| size_t scan_temp_storage_bytes = 0; |
| int* src_indptr_ptr = nullptr; |
| cub::DeviceScan::ExclusiveSum(temp_storage_ptr, |
| scan_temp_storage_bytes, |
| src_indptr_ptr, |
| src_indptr_ptr, |
| arrshape[axis] + 1, |
| mshadow::Stream<gpu>::GetStream(s)); |
| size_t sort_temp_storage_bytes = SortByKeyWorkspaceSize<int, int, gpu>(idxshape.Size()); |
| size_t histo_temp_storage_bytes = 0; |
| int* sorted_idx_ptr = nullptr; |
| cub::DeviceHistogram::HistogramEven(temp_storage_ptr, |
| histo_temp_storage_bytes, |
| sorted_idx_ptr, |
| src_indptr_ptr, |
| static_cast<int>(arrshape[axis] + 1), |
| 0, |
| static_cast<int>(arrshape[axis] + 1), |
| static_cast<int>(idxshape.Size()), |
| mshadow::Stream<gpu>::GetStream(s)); |
| size_t temp_storage_bytes = max(scan_temp_storage_bytes, sort_temp_storage_bytes); |
| temp_storage_bytes = max(temp_storage_bytes, histo_temp_storage_bytes); |
| size_t original_idx_bytes = idxshape.Size() * sizeof(int); |
| size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int); |
| size_t temp_accumulation_igrad_bytes = 0; |
| if (safe_acc) { |
| temp_accumulation_igrad_bytes = arr.Size() * sizeof(AType); |
| } |
| size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes + |
| temp_accumulation_igrad_bytes; |
| Tensor<gpu, 1, char> workspace = |
| ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_bytes), s); |
| AType* temp_accum_igrad_ptr = reinterpret_cast<AType*>(workspace.dptr_); |
| sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + temp_accumulation_igrad_bytes); |
| int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + original_idx_bytes + |
| temp_accumulation_igrad_bytes); |
| src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * original_idx_bytes + |
| temp_accumulation_igrad_bytes); |
| temp_storage_ptr = |
| workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes + temp_accumulation_igrad_bytes; |
| |
| // Reset indptr to zero |
| Kernel<set_zero, gpu>::Launch(s, arrshape[axis] + 1, src_indptr_ptr); |
| // Fill original_idx |
| Kernel<range_fwd, gpu>::Launch(s, |
| idxshape.Size(), |
| 1, |
| static_cast<int>(0), |
| static_cast<int>(1), |
| kWriteTo, |
| original_idx_ptr); |
| // Fill sorted_idx_ptr with unsorted copy of idx |
| Kernel<mshadow_op::identity_with_cast, gpu>::Launch( |
| s, idxshape.Size(), sorted_idx_ptr, idx.dptr<IType>()); |
| if (clip) { |
| Kernel<op_with_req<mshadow_op::clip, kWriteTo>, gpu>::Launch( |
| s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, 0, static_cast<int>(arrshape[axis])); |
| } else { |
| Kernel<op_with_req<mshadow_op::mod, kWriteTo>, gpu>::Launch( |
| s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast<int>(arrshape[axis])); |
| } |
| Tensor<gpu, 1, int> original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); |
| Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); |
| int num_bits = common::ilog2ui(static_cast<unsigned int>(idxshape.Size()) - 1); |
| Tensor<gpu, 1, int> sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); |
| SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits); |
| cub::DeviceHistogram::HistogramEven(temp_storage_ptr, |
| temp_storage_bytes, |
| sorted_idx_ptr, |
| src_indptr_ptr, |
| static_cast<int>(arrshape[axis] + 1), |
| 0, |
| static_cast<int>(arrshape[axis] + 1), |
| static_cast<int>(idxshape.Size()), |
| mshadow::Stream<gpu>::GetStream(s)); |
| cub::DeviceScan::ExclusiveSum(temp_storage_ptr, |
| temp_storage_bytes, |
| src_indptr_ptr, |
| src_indptr_ptr, |
| arrshape[axis] + 1, |
| mshadow::Stream<gpu>::GetStream(s)); |
| |
| Shape<10> in_strides; |
| int stride = 1; |
| for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { |
| in_strides[i] = stride; |
| } |
| Shape<10> out_strides; |
| stride = 1; |
| for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { |
| out_strides[i] = stride; |
| } |
| MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { |
| if (safe_acc) { |
| Kernel<TakeGradGeneralKernelSafeAccumulation, gpu>::Launch( |
| s, |
| arrshape.Size(), |
| arr.dptr<DType>(), |
| temp_accum_igrad_ptr, |
| ograd.dptr<DType>(), |
| src_indptr_ptr, |
| original_idx_ptr, |
| in_strides, |
| out_strides, |
| arrshape.ndim(), |
| oshape.ndim(), |
| idxshape.ndim(), |
| axis, |
| static_cast<int>(arrshape[axis])); |
| } else { |
| Kernel<TakeGradGeneralKernel, gpu>::Launch(s, |
| arrshape.Size(), |
| arr.dptr<DType>(), |
| ograd.dptr<DType>(), |
| src_indptr_ptr, |
| original_idx_ptr, |
| in_strides, |
| out_strides, |
| arrshape.ndim(), |
| oshape.ndim(), |
| idxshape.ndim(), |
| axis, |
| static_cast<int>(arrshape[axis])); |
| } |
| }); |
| }); |
| } |
| #endif |
| |
| template <typename xpu> |
| void TakeOpBackward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using namespace mshadow::expr; |
| CHECK_EQ(inputs.size(), 2U); |
| CHECK_EQ(outputs.size(), 2U); |
| CHECK_NE(req[take_::kIdx], kAddTo) |
| << "take layer doesn't support gradient of req type kAddTo to index"; |
| |
| const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed); |
| |
| // grad_out is the gradient of the outputs in the feed-forward |
| // grad_in is the gradient of the inputs in the feed-forward |
| Stream<xpu>* s = ctx.get_stream<xpu>(); |
| |
| bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", true); |
| if (!safe_acc && outputs[0].type_flag_ == mshadow::kFloat16) { |
| common::LogOnce( |
| "MXNET_SAFE_ACCUMULATION=1 is recommended for TakeOpBackward " |
| "with float16 inputs. " |
| "See https://mxnet.apache.org/api/faq/env_var " |
| "for more details."); |
| } |
| MXNET_REAL_ACC_TYPE_SWITCH(outputs[0].type_flag_, DType, AType, { |
| MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type |
| // inputs are specified in the .cc file, which are the gradients from |
| // the upper layer and the input index |
| // outputs are the gradients of inputs in the feed-forward pass |
| const mxnet::TShape& idxshape = inputs[1].shape_; |
| const mxnet::TShape& arrshape = outputs[0].shape_; |
| const mxnet::TShape& oshape = inputs[0].shape_; |
| |
| Tensor<xpu, 2, DType> grad_in = outputs[0].get_with_shape<xpu, 2, DType>( |
| Shape2(arrshape[0], arrshape.ProdShape(1, arrshape.ndim())), s); |
| |
| if (req[take_::kArr] == kWriteTo) { |
| grad_in = scalar<DType>(0.0f); |
| } |
| |
| if (idxshape.Size() == 0) { |
| return; |
| } |
| |
| if (req[take_::kIdx] != kNullOp) { |
| mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch( |
| s, idxshape.Size(), outputs[take_::kIdx].dptr<IType>()); |
| } |
| |
| const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); |
| |
| int idxndim = idxshape.ndim(); |
| Tensor<xpu, 1, IType> idx = |
| inputs[1].get_with_shape<xpu, 1, IType>(Shape1(idxshape.ProdShape(0, idxndim)), s); |
| Tensor<xpu, 2, DType> grad_out = inputs[0].get_with_shape<xpu, 2, DType>( |
| Shape2(oshape.ProdShape(0, idxndim), oshape.ProdShape(idxndim, oshape.ndim())), s); |
| |
| // re-using the previous code for axis = 0 case |
| if (actual_axis == 0) { |
| if (req[take_::kArr] == kWriteTo || req[take_::kArr] == kAddTo) { |
| if (safe_acc) { |
| // Temporary storage for safe accumulation |
| size_t temp_space_size = grad_in.size(0) * grad_in.size(1) * sizeof(AType); |
| Tensor<xpu, 1, char> temp_space = |
| ctx.requested[take_::kTempSpace].get_space_typed<xpu, 1, char>( |
| Shape1(temp_space_size), s); |
| Tensor<xpu, 2, AType> temp_grad_in( |
| reinterpret_cast<AType*>(temp_space.dptr_), grad_in.shape_, s); |
| if (param.mode == take_::kClip) { |
| AddTakeGrad(grad_in, temp_grad_in, idx, grad_out); |
| } else { |
| AddTakeGrad<false>(grad_in, temp_grad_in, idx, grad_out); |
| } |
| } else { |
| if (param.mode == take_::kClip) { |
| AddTakeGrad(grad_in, idx, grad_out); |
| } else { |
| AddTakeGrad<false>(grad_in, idx, grad_out); |
| } |
| } |
| } else { |
| LOG(FATAL) << "wrong req"; |
| } |
| // for all other cases |
| } else { |
| const TBlob& idx = inputs[1]; |
| const TBlob& arr = outputs[0]; |
| const TBlob& ograd = inputs[0]; |
| |
| if (safe_acc) { |
| if (param.mode == take_::kClip) { |
| TakeOpBackwardImpl<true, true, AType>(s, ctx, arr, idx, ograd, actual_axis); |
| } else { |
| TakeOpBackwardImpl<false, true, AType>(s, ctx, arr, idx, ograd, actual_axis); |
| } |
| } else { |
| if (param.mode == take_::kClip) { |
| TakeOpBackwardImpl<true, false, AType>(s, ctx, arr, idx, ograd, actual_axis); |
| } else { |
| TakeOpBackwardImpl<false, false, AType>(s, ctx, arr, idx, ograd, actual_axis); |
| } |
| } |
| } |
| }); |
| }); |
| } |
| |
| inline bool BatchTakeOpShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| LOG(INFO) << "batch_take is deprecated. Please use pick instead."; |
| CHECK_EQ(in_attrs->size(), 2U) << "BatchTake op requires two inputs"; |
| if ((*in_attrs)[1].ndim() != 0) { |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[1]); |
| } else if ((*out_attrs)[0].ndim() != 0) { |
| SHAPE_ASSIGN_CHECK(*in_attrs, 1, (*out_attrs)[0]); |
| } |
| if ((*in_attrs)[0].ndim() == 0) |
| return false; |
  CHECK_GE((*in_attrs)[0].ndim(), 2) << "Data array must be at least 2-dimensional";
| if ((*out_attrs)[0].ndim() == 0) |
| return false; |
| CHECK_EQ((*in_attrs)[0].Size() / (*in_attrs)[0][(*in_attrs)[0].ndim() - 1], |
| (*out_attrs)[0].Size()) |
| << "Index array's size must be the same as data array's size excluding the first dimension"; |
| return true; |
| } |
| |
| inline bool BatchTakeOpType(const nnvm::NodeAttrs& attrs, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 2U); |
| if ((*in_attrs)[0] != -1) { |
| TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); |
| } else if ((*out_attrs)[0] != -1) { |
| TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); |
| } |
| TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kInt32); |
| return true; |
| } |
| |
| /*! \brief take scalar value from 2d data array */ |
| template <int req> |
| struct batch_take { |
| template <typename DType> |
| MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a, const int* idx, int M) { |
| int j = idx[i]; |
| if (j < 0) |
| j = 0; |
| else if (j >= M) |
| j = M - 1; |
| KERNEL_ASSIGN(out[i], req, a[i * M + j]); |
| } |
| }; |
| |
| template <typename xpu> |
| void BatchTakeOpForward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| CHECK_EQ(inputs.size(), 2U); |
| CHECK_EQ(outputs.size(), 1U); |
| CHECK_EQ(req.size(), 1U); |
| using namespace mxnet_op; |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { |
| MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { |
| Kernel<batch_take<req_type>, xpu>::Launch(s, |
| outputs[0].Size(), |
| outputs[0].dptr<DType>(), |
| inputs[0].dptr<DType>(), |
| inputs[1].dptr<int>(), |
| inputs[0].Size() / inputs[0].shape_[0]); |
| }); |
| }); |
| } |
| |
| /*! |
| * \brief The parameters of the one_hot operator. |
| */ |
| struct OneHotParam : public dmlc::Parameter<OneHotParam> { |
| index_t depth; |
| double on_value; |
| double off_value; |
| int axis; |
| int dtype; |
| DMLC_DECLARE_PARAMETER(OneHotParam) { |
| DMLC_DECLARE_FIELD(depth).describe("Depth of the one hot dimension."); |
| DMLC_DECLARE_FIELD(on_value).set_default(1.0f).describe( |
| "The value assigned to the locations represented by indices."); |
| DMLC_DECLARE_FIELD(off_value).set_default(0.0f).describe( |
| "The value assigned to the locations not represented by indices."); |
| DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32) |
| MXNET_ADD_ALL_TYPES.describe("DType of the output"); |
| } |
| void SetAttrDict(std::unordered_map<std::string, std::string>* dict) { |
| std::ostringstream depth_s, on_value_s, off_value_s, axis_s, dtype_s; |
| depth_s << depth; |
| on_value_s << on_value; |
| off_value_s << off_value; |
| dtype_s << dtype; |
| (*dict)["depth"] = depth_s.str(); |
| (*dict)["on_value"] = on_value_s.str(); |
| (*dict)["off_value"] = off_value_s.str(); |
| (*dict)["dtype"] = MXNetTypeWithBool2String(dtype); |
| } |
| }; |
| |
| inline void GetOneHotParams(const OneHotParam& param, |
| index_t* depth, |
| double* on_value, |
| double* off_value, |
| int* dtype) { |
| *depth = param.depth; |
| CHECK_GE(*depth, 0) << "Dimension size, depth, must be a non-negative integer"; |
| *on_value = param.on_value; |
| *off_value = param.off_value; |
| *dtype = param.dtype; |
| } |
| |
| inline bool OneHotOpShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| const OneHotParam& param = nnvm::get<OneHotParam>(attrs.parsed); |
| CHECK_EQ(in_attrs->size(), 1U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| // The shape of indices |
| const mxnet::TShape& ishape = (*in_attrs)[0]; |
| if (!shape_is_known(ishape)) |
| return false; |
| |
| index_t depth = 0; |
| double on_value = 1.0; |
| double off_value = 0.0; |
| int dtype = mshadow::kFloat32; |
| GetOneHotParams(param, &depth, &on_value, &off_value, &dtype); |
| |
| mxnet::TShape oshape(ishape.ndim() + 1, -1); |
| for (index_t i = 0; i < ishape.ndim(); ++i) { |
| oshape[i] = ishape[i]; |
| } |
| oshape[oshape.ndim() - 1] = depth; |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); |
| return shape_is_known(oshape); |
| } |
| |
| inline bool OneHotOpType(const nnvm::NodeAttrs& attrs, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 1U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| CHECK_NE((*in_attrs)[0], -1) << "Index type must be set for one_hot operator"; |
| index_t depth = 0; |
| double on_value = 1.0; |
| double off_value = 0.0; |
| int dtype = -1; |
| const OneHotParam& param = nnvm::get<OneHotParam>(attrs.parsed); |
| GetOneHotParams(param, &depth, &on_value, &off_value, &dtype); |
| TYPE_ASSIGN_CHECK(*out_attrs, 0, dtype); // assign output type |
| |
| return true; |
| } |
| |
| template <int req> |
| struct one_hot { |
| template <typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(index_t i, |
| DType* out, |
| const IType* indices, |
| index_t depth, |
| DType on_value) { |
| index_t offset = i * depth; |
| index_t j = static_cast<index_t>(indices[i]); |
| if (j >= 0 && j < depth) { |
| KERNEL_ASSIGN(out[offset + j], req, on_value); |
| } |
| } |
| }; |
| |
| template <typename xpu> |
| void OneHotOpForward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| CHECK_EQ(inputs.size(), 1U); |
| CHECK_EQ(outputs.size(), 1U); |
| CHECK_EQ(req.size(), 1U); |
| // The following line is needed to guard the situation when |
| // an output array is empty on GPU. In that case, out.dptr() = 0x0 |
| if (outputs[0].Size() == 0) |
| return; |
| index_t depth = 0; |
| double on_value = 1.0; |
| double off_value = 0.0; |
| int dtype = mshadow::kFloat32; |
| const OneHotParam& param = nnvm::get<OneHotParam>(attrs.parsed); |
| GetOneHotParams(param, &depth, &on_value, &off_value, &dtype); |
| using namespace mxnet_op; |
| using namespace mshadow::expr; |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type switch |
| mshadow::Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s); |
| ASSIGN_DISPATCH(out, req[0], static_cast<DType>(off_value)); |
| MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { // request type switch |
| MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { // indices data type switch |
| Kernel<one_hot<req_type>, xpu>::Launch(s, |
| inputs[0].Size(), |
| outputs[0].dptr<DType>(), |
| inputs[0].dptr<IType>(), |
| depth, |
| static_cast<DType>(on_value)); |
| }); |
| }); |
| }); |
| } |
| |
| struct gather_nd { |
| template <typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(index_t i, |
| OpReqType req, |
| index_t N, |
| index_t M, |
| index_t K, |
| const mshadow::Shape<10> strides, |
| const mshadow::Shape<10> mshape, |
| DType* out, |
| const DType* data, |
| const IType* indices) { |
| index_t offset = 0; |
| for (index_t j = 0; j < M; ++j) { |
| offset += strides[j] * (static_cast<index_t>(indices[j * N + i] + mshape[j]) % mshape[j]); |
| } |
| for (index_t j = 0; j < K; ++j) { |
| KERNEL_ASSIGN(out[i * K + j], req, data[offset + j]); |
| } |
| } |
| }; |
| |
| /*! |
 * \brief For each indexed dimension, if any index along that dimension is out of
 *        bounds, is_valid_dim_ptr for that dimension is set to the out-of-bound index
| */ |
| struct is_valid_check_gather_nd { |
| template <typename DType> |
| MSHADOW_XINLINE static void Map(int i, |
| DType* is_valid_dim_ptr, |
| const DType* idx_ptr, |
| const index_t N, |
| const mshadow::Shape<10> mshape) { |
| index_t n = N - 1; |
| while (n >= 0) { |
| if (idx_ptr[i * N + n] < -mshape[i] || idx_ptr[i * N + n] > mshape[i] - 1) { |
| is_valid_dim_ptr[i] = idx_ptr[i * N + n]; |
| break; |
| } |
| n--; |
| } |
| } |
| }; |
| |
| inline bool GatherNDShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 2U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| // The shape of indices |
| const mxnet::TShape& dshape = (*in_attrs)[0]; |
| const mxnet::TShape& ishape = (*in_attrs)[1]; |
| |
| if (shape_is_none(dshape) || shape_is_none(ishape)) |
| return false; |
| |
| CHECK_GT(ishape.ndim(), 1) << "gather_nd requires index tensor to have at least 2 dimensions"; |
| |
| CHECK_LE(ishape[0], dshape.ndim()) << "Number of indices exceeds data dimension"; |
| |
| CHECK_LE(ishape[0], 10) << "gather_nd supports indexing along at most 10 dimensions."; |
| |
| mxnet::TShape oshape(ishape.ndim() - 1 + dshape.ndim() - ishape[0], -1); |
| |
| for (int i = 0; i < ishape.ndim() - 1; ++i) { |
| oshape[i] = ishape[i + 1]; |
| } |
| for (int i = 0; i < dshape.ndim() - ishape[0]; ++i) { |
| oshape[ishape.ndim() - 1 + i] = dshape[ishape[0] + i]; |
| } |
| |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); |
| return shape_is_known(oshape); |
| } |
| |
| inline bool GatherNDType(const nnvm::NodeAttrs& attrs, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 2U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); |
| TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); |
| return true; |
| } |
| |
| struct ScatterNDParam : public dmlc::Parameter<ScatterNDParam> { |
| mxnet::TShape shape; |
| DMLC_DECLARE_PARAMETER(ScatterNDParam) { |
| DMLC_DECLARE_FIELD(shape).describe("Shape of output."); |
| } |
| }; |
| |
| inline bool ScatterNDShape(const nnvm::NodeAttrs& attrs, |
| mxnet::ShapeVector* in_attrs, |
| mxnet::ShapeVector* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 2U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| const auto& params = dmlc::get<ScatterNDParam>(attrs.parsed); |
| |
| SHAPE_ASSIGN_CHECK(*out_attrs, 0, params.shape); |
| |
| const mxnet::TShape& dshape = (*in_attrs)[0]; |
| const mxnet::TShape& ishape = (*in_attrs)[1]; |
| const mxnet::TShape& oshape = (*out_attrs)[0]; |
| |
| if (shape_is_none(dshape) || shape_is_none(ishape) || shape_is_none(oshape)) |
| return false; |
| |
| CHECK_GT(ishape.ndim(), 1) << "scatter_nd requires index tensor to have at least 2 dimensions"; |
| |
| CHECK_LE(ishape[0], oshape.ndim()) |
| << "Number of indices exceeds output dimension in operator scatter_nd"; |
| |
| CHECK_LE(ishape[0], 10) << "scatter_nd supports indexing along at most 10 dimensions."; |
| |
| bool valid = dshape.ndim() == ishape.ndim() - 1 + oshape.ndim() - ishape[0]; |
| |
| for (int i = 0; i < ishape.ndim() - 1; ++i) { |
| valid = valid && dshape[i] == ishape[i + 1]; |
| } |
| for (int i = 0; i < oshape.ndim() - ishape[0]; ++i) { |
| valid = valid && dshape[ishape.ndim() - 1 + i] == oshape[ishape[0] + i]; |
| } |
| |
| CHECK(valid) << "Invalid data, indices, and output shape combination for scatter_nd: " << dshape |
| << ", " << ishape << ", " << oshape; |
| |
| return true; |
| } |
| |
| inline bool ScatterNDType(const nnvm::NodeAttrs& attrs, |
| std::vector<int>* in_attrs, |
| std::vector<int>* out_attrs) { |
| CHECK_EQ(in_attrs->size(), 2U); |
| CHECK_EQ(out_attrs->size(), 1U); |
| TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); |
| TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]); |
| return in_attrs->at(0) != -1 && in_attrs->at(1) != -1; |
| } |
| |
| struct scatter_nd { |
| template <typename DType, typename IType> |
| MSHADOW_XINLINE static void Map(index_t i, |
| OpReqType req, |
| index_t N, |
| index_t M, |
| index_t K, |
| const mshadow::Shape<10> strides, |
| DType* out, |
| const DType* data, |
| const IType* indices) { |
| index_t offset = 0; |
| for (index_t j = 0; j < M; ++j) { |
| offset += strides[j] * static_cast<index_t>(indices[j * N + i]); |
| } |
| for (index_t j = 0; j < K; ++j) { |
| KERNEL_ASSIGN(out[offset + j], req, data[i * K + j]); |
| } |
| } |
| }; |
| |
| template <typename xpu> |
| void ScatterNDForward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using nnvm::dim_t; |
| CHECK_EQ(inputs.size(), 2U); |
| CHECK_EQ(outputs.size(), 1U); |
| if (req[0] == kNullOp) |
| return; |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| const mxnet::TShape& oshape = outputs[0].shape_; |
| const mxnet::TShape& ishape = inputs[1].shape_; |
| dim_t M = ishape[0]; |
| dim_t N = ishape.Size() / M; |
| dim_t K = oshape.ProdShape(M, oshape.ndim()); |
| mshadow::Shape<10> strides; |
| for (dim_t i = M - 1, stride = K; i >= 0; stride *= oshape[i], --i) |
| strides[i] = stride; |
| if (kWriteTo == req[0]) { |
| Fill<true>(s, outputs[0], req[0], 0); |
| } |
| MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { // output data type switch |
| MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[1].type_flag_, IType, { // indices data type switch |
| mxnet_op::Kernel<scatter_nd, xpu>::Launch(s, |
| N, |
| req[0], |
| N, |
| M, |
| K, |
| strides, |
| outputs[0].dptr<DType>(), |
| inputs[0].dptr<DType>(), |
| inputs[1].dptr<IType>()); |
| }); |
| }); |
| } |
| |
| template <typename DType, typename IType> |
| inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type |
| GatherNDBackwardImpl(index_t N, |
| index_t M, |
| index_t K, |
| const mshadow::Shape<10> strides, |
| DType* out, |
| const DType* data, |
| const IType* indices, |
| mshadow::Stream<cpu>* s); |
| |
| template <typename DType, typename IType> |
| inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type |
| GatherNDBackwardImpl(index_t N, |
| index_t M, |
| index_t K, |
| const mshadow::Shape<10> strides, |
| DType* out, |
| const DType* data, |
| const IType* indices, |
| mshadow::Stream<cpu>* s); |
| |
| template <typename DType, typename IType> |
| inline void GatherNDBackwardImpl(index_t N, |
| index_t M, |
| index_t K, |
| const mshadow::Shape<10> strides, |
| DType* out, |
| const DType* data, |
| const IType* indices, |
| mshadow::Stream<gpu>* s); |
| |
| template <typename xpu> |
| void GatherNDBackward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| using namespace mshadow; |
| using nnvm::dim_t; |
| CHECK_EQ(inputs.size(), 2U); |
| CHECK_EQ(outputs.size(), 1U); |
| if (req[0] == kNullOp) |
| return; |
| mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); |
| const mxnet::TShape& oshape = outputs[0].shape_; |
| const mxnet::TShape& ishape = inputs[1].shape_; |
| dim_t M = ishape[0]; |
| dim_t N = ishape.Size() / M; |
| dim_t K = oshape.ProdShape(M, oshape.ndim()); |
| mshadow::Shape<10> strides; |
| for (dim_t i = M - 1, stride = K; i >= 0; stride *= oshape[i], --i) |
| strides[i] = stride; |
| if (kWriteTo == req[0]) { |
| Fill<true>(s, outputs[0], req[0], 0); |
| } |
| MXNET_NO_INT8_TYPE_SWITCH(inputs[0].type_flag_, DType, { // output data type switch |
| MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // indices data type switch |
| GatherNDBackwardImpl(N, |
| M, |
| K, |
| strides, |
| outputs[0].dptr<DType>(), |
| inputs[0].dptr<DType>(), |
| inputs[1].dptr<IType>(), |
| s); |
| }); |
| }); |
| } |
| |
| /*! |
| * This is for internal use only. |
| * DO NOT call this function unless you have to. |
| */ |
| template <typename xpu> |
| void ScatterSetNDForward(const nnvm::NodeAttrs& attrs, |
| const OpContext& ctx, |
| const std::vector<TBlob>& inputs, |
| const std::vector<OpReqType>& req, |
| const std::vector<TBlob>& outputs) { |
| CHECK_EQ(inputs.size(), 3U); |
| CHECK_EQ(outputs.size(), 1U); |
| CHECK_EQ(inputs[0].dptr_, outputs[0].dptr_); |
| ScatterNDForward<xpu>(attrs, ctx, {inputs[1], inputs[2]}, {kWriteInplace}, outputs); |
| } |
| |
| } // namespace op |
| } // namespace mxnet |
| |
| #endif // MXNET_OPERATOR_TENSOR_INDEXING_OP_H_ |