blob: 28eca41b2ae47e618a5ac52e2463133a5b07991a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file indexing_op.cc
* \brief CPU implementation of indexing operator
* \author Siyi Li, Chi Zhang
*/
#include "./indexing_op.h"
namespace mxnet {
namespace op {
template <bool clip = true>
struct TakeZeroAxisCPU {
// assume that idx have been flattened to a 1-D tensor (N,)
// assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
// M is the number of columns of in_data and out_data
// K is the number of rows of in_data
// i is the index of out_data
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i,
DType* out_data,
const DType* in_data,
const IType* idx,
const size_t M,
const int64_t K) {
int64_t j = static_cast<int64_t>(idx[i]);
if (clip) {
if (j <= 0)
j = 0;
else if (j >= K)
j = K - 1;
} else {
j = j % K;
j += (j < 0) ? K : 0;
}
#pragma GCC diagnostic push
#if __GNUC__ >= 8
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
std::memcpy(out_data + i * M, in_data + j * M, M * sizeof(DType));
#pragma GCC diagnostic pop
}
};
template <bool clip = true>
struct TakeNonzeroAxisCPU {
/*!
* \brief Map function for take operator
* \param i global thread id
* \param out_data ptr to output buffer
* \param in_data ptr to input buffer
* \param indices ptr to indices buffer
* \param outer_dim_stride stride of dimension before axis
* \param axis_dim_stride stride of axis dimension
* \param idx_size size of the indices tensor
* \param axis_dim dim size of the axis dimension
* \param axis axis id
*/
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i,
DType* out_data,
const DType* in_data,
const IType* indices,
const index_t outer_dim_stride,
const index_t axis_dim_stride,
const int idx_size,
const int axis_dim,
const int axis) {
for (index_t j = 0; j < static_cast<index_t>(idx_size); ++j) {
int index = indices[j];
if (clip) {
index = std::max(index, 0);
index = std::min(axis_dim - 1, index);
} else {
index %= axis_dim;
index += (index < 0) ? axis_dim : 0;
}
size_t in_offset = i * outer_dim_stride + index * axis_dim_stride;
size_t out_offset = (i * idx_size + j) * axis_dim_stride;
#pragma GCC diagnostic push
#if __GNUC__ >= 8
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
std::memcpy(out_data + out_offset, in_data + in_offset, axis_dim_stride * sizeof(DType));
#pragma GCC diagnostic pop
}
}
};
/*
* \brief returns true if all indices are between [min, max]
* \param data_ptr the indices to check
* \param data_size the number of indices to examine
* \param min the expected min value for indices
* \param max the expected max value for indices
*/
template <typename DType>
bool CheckIndexOutOfBound(const DType* data_ptr,
size_t data_size,
const DType min,
const DType max) {
bool is_valid = true;
for (size_t i = 0; i < data_size; i++) {
if (data_ptr[i] > max || data_ptr[i] < min) {
is_valid = false;
break;
}
}
return is_valid;
}
// Embedding forward implementation with dense weight
template <>
void EmbeddingOpForwardDnsImpl<cpu>(mshadow::Stream<cpu>* s,
const TBlob& data,
const TBlob& weight,
const OpReqType req,
const TBlob& output) {
using namespace mxnet_op;
const mxnet::TShape& ishape = data.shape_;
const mxnet::TShape& oshape = output.shape_;
MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
Tensor<cpu, 1, IType> idx =
data.get_with_shape<cpu, 1, IType>(Shape1(ishape.ProdShape(0, ishape.ndim())), s);
Tensor<cpu, 2, DType> wmat = weight.get<cpu, 2, DType>(s);
Tensor<cpu, 2, DType> out = output.get_with_shape<cpu, 2, DType>(
Shape2(oshape.ProdShape(0, oshape.ndim() - 1), oshape[oshape.ndim() - 1]), s);
Kernel<TakeZeroAxisCPU<true>, cpu>::Launch(s,
oshape.Size() / wmat.shape_[1],
out.dptr_,
wmat.dptr_,
idx.dptr_,
wmat.shape_[1],
wmat.shape_[0]);
});
});
}
template <>
void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
const TBlob& data,
const NDArray& weight,
const OpReqType req,
const TBlob& output) {
if (req == kNullOp)
return;
using namespace rowsparse;
using namespace mxnet_op;
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
// zeros weight
if (req == kWriteTo && !weight.storage_initialized()) {
size_t out_size = output.shape_.Size();
MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
Fill<false>(
s, TBlob(output.dptr<DType>(), mshadow::Shape1(out_size), cpu::kDevMask), kWriteTo, 0);
})
return;
}
// check out-of-bound indices
MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
DType min = 0;
DType max = static_cast<DType>(weight.shape()[0] - 1);
// check with single thread is faster since data is small
DType* data_ptr = data.dptr<DType>();
size_t data_size = data.shape_.Size();
bool is_valid = CheckIndexOutOfBound(data_ptr, data_size, min, max);
CHECK(is_valid) << "SparseEmbedding input contains data out of bound";
})
// the weight is actually dense
if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) {
EmbeddingOpForwardDnsImpl<cpu>(s, data, weight.data(), req, output);
} else {
EmbeddingOpForwardRspImpl<cpu>(s, data, weight, req, output);
}
}
template <bool clip>
struct CsrTakeDataKernel {
/*!
* \brief Map function for general case of take grad
* \param tid global thread id
* \param out_idx ptr to out idx
* \param out_data ptr to out data
* \param out_indptr ptr to out indptr
* \param src_data ptr to original csr data
* \param src_idx ptr to original csr idx
* \param idx_ptr ptr to indices
* \param num_rows maximum number of rows in src array
*/
template <typename IType, typename DType, typename RType>
MSHADOW_XINLINE static void Map(int tid,
RType* out_idx,
DType* out_data,
const RType* out_indptr,
const RType* src_idx,
const DType* src_data,
const RType* src_indptr,
const IType* idx_ptr,
const nnvm::dim_t num_rows) {
nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid]);
// clip mode
if (clip) {
if (idx < 0)
idx = 0;
if (idx >= num_rows)
idx = num_rows - 1;
} else {
// wrap mode
idx = idx % num_rows;
idx += (idx < 0) ? num_rows : 0;
}
int row_nnz = src_indptr[idx + 1] - src_indptr[idx];
for (int i = 0; i < row_nnz; i++) {
out_data[out_indptr[tid] + i] = src_data[src_indptr[idx] + i];
out_idx[out_indptr[tid] + i] = src_idx[src_indptr[idx] + i];
}
}
};
template <bool clip>
struct CsrTakeRowCountKernel {
/*!
* \brief Map function for general case of take grad
* \param tid global thread id
* \param out_indptr ptr to out indptr
* \param src_indptr ptr to original csr indptr
* \param idx_ptr ptr to indices
* \param num_rows maximum number of rows in src array
*/
template <typename IType, typename RType>
MSHADOW_XINLINE static void Map(int tid,
RType* out_indptr,
const RType* src_indptr,
const IType* idx_ptr,
const nnvm::dim_t num_rows) {
if (tid == 0) {
out_indptr[0] = 0;
return;
}
nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid - 1]);
// clip mode
if (clip) {
if (idx < 0)
idx = 0;
if (idx >= num_rows)
idx = num_rows - 1;
} else {
// wrap mode
idx = idx % num_rows;
idx += (idx < 0) ? num_rows : 0;
}
out_indptr[tid] = src_indptr[idx + 1] - src_indptr[idx];
}
};
template <>
void TakeOpForwardCsrImpl<cpu>(const TakeParam& params,
const OpContext& ctx,
const TBlob& idx,
const NDArray& arr,
OpReqType req,
const NDArray& out) {
using namespace csr;
using namespace mxnet_op;
using nnvm::dim_t;
Stream<cpu>* s = ctx.get_stream<cpu>();
if (req == kNullOp)
return;
if (!arr.storage_initialized()) {
FillZerosCsrImpl(s, out);
return;
}
CHECK_EQ(idx.shape_.ndim(), 1U) << "Take with CSR array only supports one-dimensional indices. "
<< idx.shape_.ndim() << " dimensional input is given instead";
CHECK_EQ(req, kWriteTo) << "req = " << req << " is not supported for take(csr)";
auto axis = params.axis;
CHECK_EQ(axis, 0) << "axis = " << axis << " is not supported for take(csr)";
CHECK(params.mode == take_::kClip || params.mode == take_::kWrap)
<< "mode = " << params.mode << " is not supported";
const dim_t num_rows = out.shape()[0];
const dim_t max_num_rows = arr.shape()[0];
out.CheckAndAllocAuxData(kIndPtr, {Shape1(num_rows + 1)});
MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
MSHADOW_TYPE_SWITCH(arr.dtype(), DType, {
MSHADOW_IDX_TYPE_SWITCH(out.aux_type(kIdx), RType, {
RType* out_indptr = out.aux_data(kIndPtr).dptr<RType>();
const RType* src_indptr = arr.aux_data(kIndPtr).dptr<RType>();
const IType* idx_ptr = idx.dptr<IType>();
// gather per row nnz information for output
bool clip = params.mode == take_::kClip;
if (clip) {
Kernel<CsrTakeRowCountKernel<true>, cpu>::Launch(
s, num_rows + 1, out_indptr, src_indptr, idx_ptr, max_num_rows);
} else {
Kernel<CsrTakeRowCountKernel<false>, cpu>::Launch(
s, num_rows + 1, out_indptr, src_indptr, idx_ptr, max_num_rows);
}
// calculate prefix sum with single thread
for (dim_t i = 0; i < num_rows; i++) {
out_indptr[i + 1] += out_indptr[i];
}
// total number of non-zero rows
const dim_t nnz = out_indptr[num_rows];
if (nnz == 0) {
FillZerosCsrImpl(s, out);
return;
}
out.CheckAndAllocAuxData(kIdx, {Shape1(nnz)});
out.CheckAndAllocData(Shape1(nnz));
RType* out_idx = out.aux_data(kIdx).dptr<RType>();
DType* out_data = out.data().dptr<DType>();
const RType* src_idx = arr.aux_data(kIdx).dptr<RType>();
const DType* src_data = arr.data().dptr<DType>();
// copy indices and data for output
if (clip) {
Kernel<CsrTakeDataKernel<true>, cpu>::Launch(s,
num_rows,
out_idx,
out_data,
out_indptr,
src_idx,
src_data,
src_indptr,
idx_ptr,
max_num_rows);
} else {
Kernel<CsrTakeDataKernel<false>, cpu>::Launch(s,
num_rows,
out_idx,
out_data,
out_indptr,
src_idx,
src_data,
src_indptr,
idx_ptr,
max_num_rows);
}
});
});
});
}
template <>
void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (req[take_::kOut] == kNullOp)
return;
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
const mxnet::TShape& idxshape = inputs[take_::kIdx].shape_;
const mxnet::TShape& arrshape = inputs[take_::kArr].shape_;
const mxnet::TShape& oshape = outputs[take_::kOut].shape_;
if (idxshape.Size() == 0) {
return;
}
Stream<cpu>* s = ctx.get_stream<cpu>();
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[take_::kOut].type_flag_, DType, { // output data type
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[take_::kIdx].type_flag_, IType, { // index data type
if (param.mode == take_::kRaise) {
IType min = 0;
IType max = static_cast<IType>(arrshape[actual_axis] - 1);
// check with single thread is faster since data is small
IType* idx_ptr = inputs[take_::kIdx].dptr<IType>();
size_t idx_size = idxshape.Size();
bool is_valid = CheckIndexOutOfBound(idx_ptr, idx_size, min, max);
CHECK(is_valid) << "take operator contains indices out of bound";
}
if (actual_axis == 0) {
if (param.mode == take_::kClip) {
Kernel<TakeZeroAxisCPU<true>, cpu>::Launch(s,
idxshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
oshape.Size() / idxshape.Size(),
arrshape[0]);
} else {
Kernel<TakeZeroAxisCPU<false>, cpu>::Launch(s,
idxshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
oshape.Size() / idxshape.Size(),
arrshape[0]);
}
} else {
mshadow::Shape<10> in_strides;
index_t stride = 1;
for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
in_strides[i] = stride;
}
int outer_dimensions = 1;
for (int i = 0; i < actual_axis; i++) {
outer_dimensions *= oshape[i];
}
if (param.mode == take_::kClip) {
Kernel<TakeNonzeroAxisCPU<true>, cpu>::Launch(s,
outer_dimensions,
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
in_strides[actual_axis - 1],
in_strides[actual_axis],
idxshape.Size(),
arrshape[actual_axis],
actual_axis);
} else {
Kernel<TakeNonzeroAxisCPU<false>, cpu>::Launch(s,
outer_dimensions,
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
in_strides[actual_axis - 1],
in_strides[actual_axis],
idxshape.Size(),
arrshape[actual_axis],
actual_axis);
}
}
});
});
}
template <>
inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const bool deterministic,
const OpContext& ctx,
const TBlob& ograd,
const TBlob& data,
const OpReqType req,
const NDArray& output) {
using namespace mshadow;
using namespace mxnet_op;
using namespace mshadow::expr;
using namespace rowsparse;
using nnvm::dim_t;
if (req == kNullOp)
return;
CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
<< "weight gradient calculation with req != write";
// Request temporary storage for marking non-zero rows and prefix sum
Stream<cpu>* s = ctx.get_stream<cpu>();
dim_t num_rows = output.shape()[0];
dim_t row_length = output.shape()[1];
size_t workspace_size = num_rows * sizeof(dim_t);
Tensor<cpu, 1, char> workspace =
ctx.requested[embedding::kTempSpace].get_space_typed<cpu, 1, char>(Shape1(workspace_size), s);
dim_t* row_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
// prefix sum array re-uses the row_flg array temp space
dim_t* prefix_sum = row_flg;
dim_t data_size = static_cast<dim_t>(data.shape_.Size());
MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
MSHADOW_SGL_DBL_TYPE_SWITCH(ograd.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
// check out of bound indices
{
IType min = 0;
IType max = static_cast<IType>(output.shape()[0] - 1);
// check with single thread is faster since data is small
IType* data_ptr = data.dptr<IType>();
bool is_valid = CheckIndexOutOfBound(data_ptr, data.shape_.Size(), min, max);
CHECK(is_valid) << "Embedding input contains data out of bound";
}
// mark row flags
Fill<false>(s, TBlob(row_flg, Shape1(num_rows), cpu::kDevMask), kWriteTo, 0);
Kernel<MarkRowFlgKernel, cpu>::Launch(s, data_size, row_flg, data.dptr<IType>());
// calculate inclusive prefix sum
// TODO(haibin) ideally this is should be done in parallel
prefix_sum[0] = row_flg[0];
for (dim_t i = 1; i < num_rows; i++) {
prefix_sum[i] = prefix_sum[i - 1] + row_flg[i];
}
// total number of non-zero rows
dim_t nnr = prefix_sum[num_rows - 1];
if (nnr == 0) {
FillZerosRspImpl(s, output);
return;
}
output.CheckAndAlloc({Shape1(nnr)});
RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
// fill row_idx array of output matrix, using the row_flg values
Kernel<FillRspRowIdxKernel, cpu>::Launch(s, num_rows, grad_row_idx, prefix_sum, num_rows);
// prefill with zeros
DType* grad_data = output.data().dptr<DType>();
Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length), cpu::kDevMask), kWriteTo, 0);
// add the final gradients
const int num_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
dim_t segment_len = (nnr + num_threads - 1) / num_threads;
Kernel<AddTakeGradRspKernel, cpu>::Launch(s,
num_threads,
grad_data,
prefix_sum,
ograd.dptr<DType>(),
row_length,
data.dptr<IType>(),
data_size,
segment_len,
num_rows);
});
});
});
}
/*
* \brief check if any of the indices is out of bound
* \param s the stream
* \param idx_ptr the indices on the stream
* \param N the number of indices in an axis
* \param M the number of axises to exmaine
* \param mshape the array that stores shape for each dimension
* \param is_valid_dim_ptr the temparary workspace that contains out-of-bound indices
*/
template <typename DType>
void GatherNDCheckBoundCPU(mshadow::Stream<cpu>* s,
const DType* idx_ptr,
index_t N,
index_t M,
const mshadow::Shape<10> mshape,
DType* is_valid_dim_ptr) {
using namespace mxnet_op;
Kernel<set_zero, cpu>::Launch(s, M, is_valid_dim_ptr);
Kernel<is_valid_check_gather_nd, cpu>::Launch(s, M, is_valid_dim_ptr, idx_ptr, N, mshape);
for (index_t m = 0; m < M; m++) {
if (is_valid_dim_ptr[m] > mshape[m] - 1 || is_valid_dim_ptr[m] < -mshape[m]) {
LOG(FATAL) << "IndexError: index " << is_valid_dim_ptr[m] << " is out of bounds for axis "
<< m << " with size " << mshape[m];
}
}
}
void GatherNDForwardCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
if (req[0] == kNullOp)
return;
mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
const mxnet::TShape& dshape = inputs[0].shape_;
const mxnet::TShape& ishape = inputs[1].shape_;
index_t M = ishape[0];
index_t N = ishape.Size() / M;
index_t K = dshape.ProdShape(M, dshape.ndim());
mshadow::Shape<10> strides;
mshadow::Shape<10> mshape;
for (index_t i = M - 1, stride = K; i >= 0; stride *= dshape[i], --i) {
strides[i] = stride;
mshape[i] = dshape[i];
}
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { // output data type switch
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // indices data type switch
// check whether indices are out of bound
IType* idx_ptr = inputs[1].dptr<IType>();
Tensor<cpu, 1, IType> workspace =
ctx.requested[0].get_space_typed<cpu, 1, IType>(Shape1(M), s);
IType* is_valid_dim_ptr = reinterpret_cast<IType*>(workspace.dptr_);
GatherNDCheckBoundCPU(s, idx_ptr, N, M, mshape, is_valid_dim_ptr);
Kernel<gather_nd, cpu>::Launch(s,
N,
req[0],
N,
M,
K,
strides,
mshape,
outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(),
inputs[1].dptr<IType>());
});
});
}
template <typename DType, typename IType>
inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
GatherNDBackwardImpl(index_t N,
index_t M,
index_t K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu>* s) {
#pragma omp parallel for
for (index_t i = 0; i < N; i++) {
index_t offset = 0;
for (index_t j = 0; j < M; ++j) {
offset += strides[j] * static_cast<index_t>(indices[j * N + i]);
}
for (index_t j = 0; j < K; ++j) {
#pragma omp atomic
out[offset + j] += data[i * K + j];
}
}
}
template <typename DType, typename IType>
inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
GatherNDBackwardImpl(index_t N,
index_t M,
index_t K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu>* s) {
for (index_t i = 0; i < N; i++) {
index_t offset = 0;
for (index_t j = 0; j < M; ++j) {
offset += strides[j] * static_cast<index_t>(indices[j * N + i]);
}
for (index_t j = 0; j < K; ++j) {
out[offset + j] += data[i * K + j];
}
}
}
DMLC_REGISTER_PARAMETER(EmbeddingParam);
DMLC_REGISTER_PARAMETER(TakeParam);
DMLC_REGISTER_PARAMETER(OneHotParam);
DMLC_REGISTER_PARAMETER(ScatterNDParam);
NNVM_REGISTER_OP(Embedding)
MXNET_ADD_SPARSE_OP_ALIAS(Embedding)
.add_alias("_npx_embedding")
.describe(R"code(Maps integer indices to vector representations (embeddings).
This operator maps words to real-valued vectors in a high-dimensional space,
called word embeddings. These embeddings can capture semantic and syntactic properties of the words.
For example, it has been noted that in the learned embedding spaces, similar words tend
to be close to each other and dissimilar words far apart.
For an input array of shape (d1, ..., dK),
the shape of an output array is (d1, ..., dK, output_dim).
All the input values should be integers in the range [0, input_dim).
If the input_dim is ip0 and output_dim is op0, then shape of the embedding weight matrix must be
(ip0, op0).
When "sparse_grad" is False, if any index mentioned is too large, it is replaced by the index that
addresses the last vector in an embedding matrix.
When "sparse_grad" is True, an error will be raised if invalid indices are found.
Examples::
input_dim = 4
output_dim = 5
// Each row in weight matrix y represents a word. So, y = (w0,w1,w2,w3)
y = [[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., 18., 19.]]
// Input array x represents n-grams(2-gram). So, x = [(w1,w3), (w0,w2)]
x = [[ 1., 3.],
[ 0., 2.]]
// Mapped input x to its vector representation y.
Embedding(x, y, 4, 5) = [[[ 5., 6., 7., 8., 9.],
[ 15., 16., 17., 18., 19.]],
[[ 0., 1., 2., 3., 4.],
[ 10., 11., 12., 13., 14.]]]
The storage type of weight can be either row_sparse or default.
.. Note::
If "sparse_grad" is set to True, the storage type of gradient w.r.t weights will be
"row_sparse". Only a subset of optimizers support sparse gradients, including SGD, AdaGrad
and Adam. Note that by default lazy updates is turned on, which may perform differently
from standard updates. For more details, please check the Optimization API at:
https://mxnet.apache.org/versions/master/api/python/docs/api/optimizer/index.html
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<EmbeddingParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"};
})
.set_attr<mxnet::FInferShape>("FInferShape", EmbeddingOpShape<EmbeddingParam>)
.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType<EmbeddingParam>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpForward<cpu>)
.set_attr<nnvm::FGradient>(
"FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
return MakeNonlossGradNode(
"_backward_Embedding", n, ograds, {n->inputs[0]}, n->attrs.dict);
})
.add_argument("data", "NDArray-or-Symbol", "The input array to the embedding operator.")
.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.")
.add_arguments(EmbeddingParam::__FIELDS__());
NNVM_REGISTER_OP(_backward_Embedding)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr_parser(ParamParser<EmbeddingParam>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FInferStorageType>("FInferStorageType", EmbeddingOpBackwardStorageType)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpBackward<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", EmbeddingOpBackwardEx<cpu>);
NNVM_REGISTER_OP(take)
.add_alias("_npi_take")
.describe(R"code(Takes elements from an input array along the given axis.
This function slices the input array along a particular axis with the provided indices.
Given data tensor of rank r >= 1, and indices tensor of rank q, gather entries of the axis
dimension of data (by default outer-most one as axis=0) indexed by indices, and concatenates them
in an output tensor of rank q + (r - 1).
Examples::
x = [4. 5. 6.]
// Trivial case, take the second element along the first axis.
take(x, [1]) = [ 5. ]
// The other trivial case, axis=-1, take the third element along the first axis
take(x, [3], axis=-1, mode='clip') = [ 6. ]
x = [[ 1., 2.],
[ 3., 4.],
[ 5., 6.]]
// In this case we will get rows 0 and 1, then 1 and 2. Along axis 0
take(x, [[0,1],[1,2]]) = [[[ 1., 2.],
[ 3., 4.]],
[[ 3., 4.],
[ 5., 6.]]]
// In this case we will get rows 0 and 1, then 1 and 2 (calculated by wrapping around).
// Along axis 1
take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1. 2.]
[ 2. 1.]]
[[ 3. 4.]
[ 4. 3.]]
[[ 5. 6.]
[ 6. 5.]]]
The storage type of ``take`` output depends upon the input storage type:
- take(default, default) = default
- take(csr, default, axis=0) = csr
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<TakeParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a", "indices"};
})
.set_attr<mxnet::FInferShape>("FInferShape", TakeOpShape)
.set_attr<nnvm::FInferType>("FInferType", TakeOpType)
.set_attr<FInferStorageType>("FInferStorageType", TakeOpForwardStorageType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<FCompute>("FCompute<cpu>", TakeOpForward<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", TakeOpForwardEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
return MakeNonlossGradNode(
"_backward_take", n, ograds, {n->inputs[1]}, n->attrs.dict);
})
.add_argument("a", "NDArray-or-Symbol", "The input array.")
.add_argument("indices", "NDArray-or-Symbol", "The indices of the values to be extracted.")
.add_arguments(TakeParam::__FIELDS__());
NNVM_REGISTER_OP(_backward_take)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr_parser(ParamParser<TakeParam>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", TakeOpBackward<cpu>);
NNVM_REGISTER_OP(batch_take)
.describe(R"code(Takes elements from a data batch.
.. note::
`batch_take` is deprecated. Use `pick` instead.
Given an input array of shape ``(d0, d1)`` and indices of shape ``(i0,)``, the result will be
an output array of shape ``(i0,)`` with::
output[i] = input[i, indices[i]]
Examples::
x = [[ 1., 2.],
[ 3., 4.],
[ 5., 6.]]
// takes elements with specified indices
batch_take(x, [0,1,0]) = [ 1. 4. 5.]
)code" ADD_FILELINE)
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a", "indices"};
})
.set_attr<mxnet::FInferShape>("FInferShape", BatchTakeOpShape)
.set_attr<nnvm::FInferType>("FInferType", BatchTakeOpType)
.set_attr<FCompute>("FCompute<cpu>", BatchTakeOpForward<cpu>)
.add_argument("a", "NDArray-or-Symbol", "The input array")
.add_argument("indices", "NDArray-or-Symbol", "The index array");
NNVM_REGISTER_OP(one_hot)
.add_alias("_npx_one_hot")
.describe(R"code(Returns a one-hot array.
The locations represented by `indices` take value `on_value`, while all
other locations take value `off_value`.
`one_hot` operation with `indices` of shape ``(i0, i1)`` and `depth` of ``d`` would result
in an output array of shape ``(i0, i1, d)`` with::
output[i,j,:] = off_value
output[i,j,indices[i,j]] = on_value
Examples::
one_hot([1,0,2,0], 3) = [[ 0. 1. 0.]
[ 1. 0. 0.]
[ 0. 0. 1.]
[ 1. 0. 0.]]
one_hot([1,0,2,0], 3, on_value=8, off_value=1,
dtype='int32') = [[1 8 1]
[8 1 1]
[1 1 8]
[8 1 1]]
one_hot([[1,0],[1,0],[2,0]], 3) = [[[ 0. 1. 0.]
[ 1. 0. 0.]]
[[ 0. 1. 0.]
[ 1. 0. 0.]]
[[ 0. 0. 1.]
[ 1. 0. 0.]]]
)code" ADD_FILELINE)
.set_num_outputs(1)
.set_num_inputs(1)
.set_attr_parser(ParamParser<OneHotParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"indices"};
})
.set_attr<mxnet::FInferShape>("FInferShape", OneHotOpShape)
.set_attr<nnvm::FInferType>("FInferType", OneHotOpType)
.set_attr<FCompute>("FCompute<cpu>", OneHotOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("indices", "NDArray-or-Symbol", "array of locations where to set on_value")
.add_arguments(OneHotParam::__FIELDS__());
NNVM_REGISTER_OP(gather_nd)
.add_alias("_npi_gather_nd")
.add_alias("_npx_gather_nd")
.describe(R"code(Gather elements or slices from `data` and store to a tensor whose
shape is defined by `indices`.
Given `data` with shape `(X_0, X_1, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})`,
where `M <= N`. If `M == N`, output shape will simply be `(Y_0, ..., Y_{K-1})`.
The elements in output is defined as follows::
output[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] = data[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M, ..., x_{N-1}]
Examples::
data = [[0, 1], [2, 3]]
indices = [[1, 1, 0], [0, 1, 0]]
gather_nd(data, indices) = [2, 3, 0]
data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[0, 1], [1, 0]]
gather_nd(data, indices) = [[3, 4], [5, 6]]
)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<mxnet::FInferShape>("FInferShape", GatherNDShape)
.set_attr<nnvm::FInferType>("FInferType", GatherNDType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", GatherNDForwardCPU)
.set_attr<nnvm::FGradient>(
"FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("_backward_gather_nd");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
auto zero = MakeNode(
"zeros_like", n->attrs.name + "_backward_indices", {n->inputs[1]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(p);
ret.emplace_back(zero);
return ret;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices");
NNVM_REGISTER_OP(scatter_nd)
.describe(R"code(Scatters data into a new tensor according to indices.
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.
The elements in output is defined as follows::
output[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M, ..., x_{N-1}] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
all other entries in output are 0.
.. warning::
If the indices have duplicates, the result will be non-deterministic and
the gradient of `scatter_nd` will not be correct!!
Examples::
data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
shape = (2, 2)
scatter_nd(data, indices, shape) = [[0, 0], [2, 3]]
data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[0, 1], [1, 1]]
shape = (2, 2, 2, 2)
scatter_nd(data, indices, shape) = [[[[0, 0],
[0, 0]],
[[1, 2],
[3, 4]]],
[[[0, 0],
[0, 0]],
[[5, 6],
[7, 8]]]]
)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<mxnet::FInferShape>("FInferShape", ScatterNDShape)
.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
.set_attr<FCompute>("FCompute<cpu>", ScatterNDForward<cpu>)
.set_attr<nnvm::FGradient>(
"FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("gather_nd");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
auto zero = MakeNode(
"zeros_like", n->attrs.name + "_backward_indices", {n->inputs[1]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(p);
ret.emplace_back(zero);
return ret;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());
NNVM_REGISTER_OP(_backward_gather_nd)
.describe(R"code(Accumulates data according to indices and get the result. It's the backward of
`gather_nd`.
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.
The elements in output is defined as follows::
output[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M, ..., x_{N-1}] += data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
all other entries in output are 0 or the original value if AddTo is triggered.
Examples::
data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
shape = (2, 2)
_backward_gather_nd(data, indices, shape) = [[0, 0], [2, 3]] # Same as scatter_nd
# The difference between scatter_nd and scatter_nd_acc is the latter will accumulate
# the values that point to the same index.
data = [2, 3, 0]
indices = [[1, 1, 0], [1, 1, 0]]
shape = (2, 2)
_backward_gather_nd(data, indices, shape) = [[0, 0], [0, 5]]
)code")
.set_num_outputs(1)
.set_num_inputs(2)
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "indices"};
})
.set_attr<mxnet::FInferShape>("FInferShape", ScatterNDShape)
.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
.set_attr<FCompute>("FCompute<cpu>", GatherNDBackward<cpu>)
.set_attr<nnvm::FGradient>(
"FGradient",
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("gather_nd");
p->attrs.name = n->attrs.name + "_backward";
p->inputs.push_back(ograds[0]);
p->inputs.push_back(n->inputs[1]);
p->control_deps.emplace_back(n);
auto zero = MakeNode(
"zeros_like", n->attrs.name + "_backward_indices", {n->inputs[1]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(p);
ret.emplace_back(zero);
return ret;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.add_argument("data", "NDArray-or-Symbol", "data")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());
NNVM_REGISTER_OP(_scatter_set_nd)
.add_alias("_npi_scatter_set_nd")
.describe(R"code(This operator has the same functionality as scatter_nd
except that it does not reset the elements not indexed by the input
index `NDArray` in the input data `NDArray`. output should be explicitly
given and be the same as lhs.
.. note:: This operator is for internal use only.
Examples::
data = [2, 3, 0]
indices = [[1, 1, 0], [0, 1, 0]]
out = [[1, 1], [1, 1]]
_scatter_set_nd(lhs=out, rhs=data, indices=indices, out=out)
out = [[0, 1], [2, 3]]
)code")
.set_num_outputs(1)
.set_num_inputs(3)
.set_attr_parser(ParamParser<ScatterNDParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"lhs", "rhs", "indices"};
})
.set_attr<mxnet::FInferShape>(
"FInferShape",
[](const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
mxnet::ShapeVector tmp_in_attrs = {in_attrs->at(1), in_attrs->at(2)};
if (!ScatterNDShape(attrs, &tmp_in_attrs, out_attrs)) {
return false;
}
SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_in_attrs[0]);
SHAPE_ASSIGN_CHECK(*in_attrs, 2, tmp_in_attrs[1]);
SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
return true;
})
.set_attr<nnvm::FInferType>(
"FInferType",
[](const nnvm::NodeAttrs& attrs, std::vector<int>* in_attrs, std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
std::vector<int> tmp_in_attrs = {in_attrs->at(1), in_attrs->at(2)};
if (!ScatterNDType(attrs, &tmp_in_attrs, out_attrs)) {
return false;
}
TYPE_ASSIGN_CHECK(*in_attrs, 1, tmp_in_attrs[0]);
TYPE_ASSIGN_CHECK(*in_attrs, 2, tmp_in_attrs[1]);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
return true;
})
.set_attr<FCompute>("FCompute<cpu>", ScatterSetNDForward<cpu>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
[](const NodeAttrs& attrs) {
return std::vector<bool>{true};
})
.add_argument("lhs", "NDArray-or-Symbol", "source input")
.add_argument("rhs", "NDArray-or-Symbol", "value to assign")
.add_argument("indices", "NDArray-or-Symbol", "indices")
.add_arguments(ScatterNDParam::__FIELDS__());
} // namespace op
} // namespace mxnet