/*!
* Copyright (c) 2017 by Contributors
* \file indexing_op.h
* \brief Function definitions of indexing operators (Embedding, take, batch_take, one_hot)
* \author Bing Xu, Siyi Li, Chi Zhang
*/
#ifndef MXNET_OPERATOR_TENSOR_INDEXING_OP_H_
#define MXNET_OPERATOR_TENSOR_INDEXING_OP_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/operator_util.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include <type_traits>
#include "../operator_common.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "../mxnet_op.h"
#include "./sort_op.h"
namespace mxnet {
namespace op {
namespace embedding {
enum EmbeddingOpInputs {kData, kWeight};
enum EmbeddingOpOutputs {kOut};
enum EmbeddingOpResource {kTempSpace};
} // namespace embedding
struct EmbeddingParam: public dmlc::Parameter<EmbeddingParam> {
int input_dim;
int output_dim;
DMLC_DECLARE_PARAMETER(EmbeddingParam) {
DMLC_DECLARE_FIELD(input_dim).set_lower_bound(1)
.describe("vocabulary size of the input indices.");
DMLC_DECLARE_FIELD(output_dim).set_lower_bound(1)
.describe("dimension of the embedding vectors.");
}
};
/*!
 * \brief CPU: Return the amount of temporary storage in bytes required by
 *        AddTakeGradLargeBatch. The CPU implementation needs no extra workspace.
 * \param num_keys number of keys
 */
template <typename IndexType, typename xpu>
inline typename std::enable_if<std::is_same<xpu, cpu>::value, size_t>::type
AddTakeGradLargeBatchWorkspaceSize(size_t num_keys) {
return 0;
}
/*!
 * \brief GPU: Return the amount of temporary storage in bytes required by
 *        AddTakeGradLargeBatch.
 * \param num_keys number of keys
 */
template <typename IndexType, typename xpu>
inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type
AddTakeGradLargeBatchWorkspaceSize(size_t num_keys);
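// Note: the GPU overloads of AddTakeGradLargeBatchWorkspaceSize (above) and
// AddTakeGradLargeBatch (below) are only declared in this header; their
// definitions live in ./indexing_op-inl.cuh, which is included at the end of
// this file when compiling with nvcc (__CUDACC__).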
/*!
 * \brief CPU: Gradient accumulation for the embedding matrix:
 *        dst[sorted[i]] += src[index[i]].
 *        Used when the batch size of src is larger than the feature dimension.
 * \param dst destination gradient matrix
 * \param sorted the sorted indices
 * \param index original positions of the sorted indices
 * \param src source gradients
 * \param workspace (optional) temporary storage
 */
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(mshadow::Tensor<cpu, 2, DType> dst,
const mshadow::Tensor<cpu, 1, IndexType>& sorted,
const mshadow::Tensor<cpu, 1, IndexType>& index,
const mshadow::Tensor<cpu, 2, DType> &src,
mshadow::Tensor<cpu, 1, char>* workspace = NULL) {
for (index_t y = 0; y < sorted.size(0); ++y) {
dst[sorted[y]] += src[index[y]];
}
}
/*!
 * \brief GPU: Gradient accumulation for the embedding matrix:
 *        dst[sorted[i]] += src[index[i]].
 *        Used when the batch size of src is larger than the feature dimension.
 * \param dst destination gradient matrix
 * \param sorted the sorted indices
 * \param index original positions of the sorted indices
 * \param src source gradients
 * \param workspace (optional) temporary storage
 */
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
const mshadow::Tensor<gpu, 1, IndexType>& sorted,
const mshadow::Tensor<gpu, 1, IndexType>& index,
const mshadow::Tensor<gpu, 2, DType> &src,
mshadow::Tensor<gpu, 1, char>* workspace = NULL);
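// Shape inference for Embedding: given data of shape (d0, ..., dk) and the
// declared (input_dim, output_dim), the weight is fixed to
// (input_dim, output_dim) and the output becomes (d0, ..., dk, output_dim).
// Example: data (4, 5) with input_dim=1000, output_dim=16 infers
// weight (1000, 16) and output (4, 5, 16).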
inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
using namespace mshadow;
const TShape &dshape = (*in_attrs)[embedding::kData];
if (dshape.ndim() == 0) return false;
const EmbeddingParam& param = nnvm::get<EmbeddingParam>(attrs.parsed);
SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, Shape2(param.input_dim,
param.output_dim));
out_attrs->clear();
TShape oshape(dshape.ndim()+1);
for (size_t i = 0; i < dshape.ndim(); ++i) {
oshape[i] = dshape[i];
}
oshape[dshape.ndim()] = param.output_dim;
out_attrs->push_back(oshape);
return true;
}
inline bool EmbeddingOpType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type,
std::vector<int> *out_type) {
CHECK_GE(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
for (index_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
CHECK_EQ((*in_type)[i], dtype) << "This layer requires uniform type. "
<< "Expected " << dtype << " v.s. given "
<< (*in_type)[i];
}
}
out_type->clear();
out_type->push_back(dtype);
return true;
}
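// Forward pass of Embedding: the data is flattened to 1-D indices and rows of
// the weight matrix are gathered, i.e. out[i] = weight[data[i]] for every
// flattened index. Example: data = [2, 0, 1] with weight of shape (1000, 16)
// gathers rows 2, 0 and 1, giving an output of shape (3, 16).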
template<typename xpu>
void EmbeddingOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(req[embedding::kOut], kWriteTo);
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(inputs[embedding::kWeight].ndim(), 2U)
<< "Embedding layer expects its weight to be two-dimensional. "
<< inputs[embedding::kWeight].ndim()
<< " dimensional input is given instead";
const TShape& ishape = inputs[embedding::kData].shape_;
const TShape& oshape = outputs[embedding::kOut].shape_;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> data = inputs[embedding::kData].get_with_shape<xpu, 1, DType>(
Shape1(ishape.ProdShape(0, ishape.ndim())), s);
Tensor<xpu, 2, DType> wmat = inputs[embedding::kWeight].get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> out = outputs[embedding::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
out = take(data, wmat);
});
}
// Returns the number of bits needed to represent a positive integer a,
// i.e. floor(log2(a)) + 1 (e.g. ilog2(7) == 3, ilog2(8) == 4)
inline int ilog2(unsigned int a) {
int k = 1;
while (a >>= 1) k++;
return k;
}
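// Sorts the indices (so that duplicates become contiguous) and then
// accumulates the gradients with AddTakeGradLargeBatch. The requested
// workspace is laid out as
//   [ sorted_data : n*sizeof(int) | original_index : n*sizeof(int) | temp_storage ]
// where n = index.shape_.Size() and temp_storage is shared by SortByKey and
// AddTakeGradLargeBatch.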
template<typename xpu, typename IndexType, typename DType>
void AddTakeGradLargeBatchCaller(const OpContext& ctx, mshadow::Tensor<xpu, 2, DType> dst,
const mshadow::Tensor<xpu, 1, IndexType>& index,
const mshadow::Tensor<xpu, 2, DType> &src) {
using namespace mshadow;
using namespace mshadow::expr;
Stream<xpu> *s = ctx.get_stream<xpu>();
// Calculate amount of temporary storage
size_t sort_workspace_size = mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>
(index.shape_.Size());
size_t addtake_workspace_size = mxnet::op::AddTakeGradLargeBatchWorkspaceSize<int, xpu>
(index.shape_.Size());
size_t temp_storage_size = std::max(sort_workspace_size, addtake_workspace_size);
size_t workspace_size = 2*(index.shape_.Size()*sizeof(int)) + temp_storage_size;
// Request temporary storage
Tensor<xpu, 1, char> workspace =
ctx.requested[embedding::kTempSpace].get_space_typed<xpu, 1, char>(
Shape1(workspace_size), s);
// Create tensors
size_t pos = 0;
Tensor<xpu, 1, int> sorted_data(reinterpret_cast<int*>(&workspace[pos]),
Shape1(index.shape_.Size()), s);
pos += index.shape_.Size()*sizeof(int);
Tensor<xpu, 1, int> original_index(reinterpret_cast<int*>(&workspace[pos]),
Shape1(index.shape_.Size()), s);
pos += index.shape_.Size()*sizeof(int);
Tensor<xpu, 1, char> temp_storage(&workspace[pos], Shape1(temp_storage_size), s);
sorted_data = tcast<int>(index);
original_index = range<int>(0, index.shape_.Size());
int num_bits = ilog2((dst.shape_[0] - 1));
mxnet::op::SortByKey(sorted_data, original_index, true, &temp_storage, 0, num_bits);
mxnet::op::AddTakeGradLargeBatch(dst, sorted_data, original_index, src, &temp_storage);
}
template<typename xpu>
void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 2U);
CHECK_EQ(req[embedding::kData], kNullOp)
<< "Embedding layer doesn't support calculate data gradient";
const TShape& ishape = inputs[1].shape_;
const TShape& oshape = inputs[0].shape_;
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> data = inputs[1].get_with_shape<xpu, 1, DType>(
Shape1(ishape.ProdShape(0, ishape.ndim())), s);
Tensor<xpu, 2, DType> grad_out = inputs[0].get_with_shape<xpu, 2, DType>(
Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
Tensor<xpu, 2, DType> grad_in = outputs[1].get<xpu, 2, DType>(s);
if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) {
if (req[embedding::kWeight] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
// shape_out_prod ~= the number of elements loaded in AddTakeGrad
// shape_in_prod ~= the number of elements stored in AddTakeGrad
// When the number of elements processed is low, use AddTakeGrad.
// The approximate cut-off value 16384 was found experimentally on Titan X Pascal
uint64_t shape_in_prod =
static_cast<uint64_t>(grad_in.shape_[0])*
static_cast<uint64_t>(grad_in.shape_[1]);
uint64_t shape_out_prod =
static_cast<uint64_t>(grad_out.shape_[0])*
static_cast<uint64_t>(grad_out.shape_[1]);
if (shape_out_prod < (uint64_t)16384 && shape_in_prod < (uint64_t)16384) {
AddTakeGrad(grad_in, data, grad_out);
} else {
AddTakeGradLargeBatchCaller(ctx, grad_in, data, grad_out);
}
} else {
LOG(FATAL) << "wrong req";
}
});
}
namespace take_ { // to avoid name conflict
enum TakeOpInputs {kArr, kIdx};
enum TakeOpOutputs {kOut};
enum TakeOpResource {kTempSpace};
enum TakeOpMode {kRaise, kWrap, kClip};
} // namespace take_
// TODO(somebody): behaviors specified by params
struct TakeParam: public dmlc::Parameter<TakeParam> {
int axis;
int mode;
DMLC_DECLARE_PARAMETER(TakeParam) {
DMLC_DECLARE_FIELD(axis)
.set_lower_bound(0)
.set_default(0)
.describe("the axis of data tensor to be taken.");
DMLC_DECLARE_FIELD(mode)
.add_enum("raise", take_::kRaise)
.add_enum("wrap", take_::kWrap)
.add_enum("clip", take_::kClip)
.set_default(take_::kRaise)
.describe("specify how out-of-bound indices bahave.");
}
};
template<typename PType>
inline void TakeParamParser(nnvm::NodeAttrs *attrs) {
PType param;
param.Init(attrs->dict);
if (param.axis != 0) {
LOG(FATAL) << "Axis other than 0 currently not supported.";
}
if (param.mode != take_::kRaise) {
LOG(FATAL) << "Mode other than raise currently not supported.";
}
}
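// Shape inference for take along axis 0: the output shape is the index shape
// followed by the trailing dimensions of the array, i.e.
// oshape = idxshape + arrshape[1:].
// Example: arr (5, 4) and idx (2, 3) -> out (2, 3, 4).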
inline bool TakeOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
using namespace mshadow;
const TShape &arrshape = (*in_attrs)[take_::kArr];
const TShape &idxshape = (*in_attrs)[take_::kIdx];
if (idxshape.ndim() == 0) return false;
out_attrs->clear();
TShape oshape(idxshape.ndim() + arrshape.ndim() - 1);
for (size_t i = 0; i < idxshape.ndim(); ++i) {
oshape[i] = idxshape[i];
}
for (size_t i = 0; i < arrshape.ndim() - 1; i++) {
oshape[i + idxshape.ndim()] = arrshape[i + 1];
}
out_attrs->push_back(oshape);
return true;
}
inline bool TakeOpType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type,
std::vector<int> *out_type) {
// all inputs and the output share a single dtype, inferred from the index input
CHECK_GE(in_type->size(), 2U);
int dtype = (*in_type)[1];
CHECK_NE(dtype, -1) << "idx must have specified type";
for (index_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
CHECK_EQ((*in_type)[i], dtype) << "This layer requires uniform type. "
<< "Expected " << dtype << " v.s. given "
<< (*in_type)[i];
}
}
out_type->clear();
out_type->push_back(dtype);
return true;
}
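// Forward pass of take (axis 0 only): the index tensor is flattened to 1-D
// and the array is viewed as a 2-D matrix of shape
// (arrshape[0], prod(arrshape[1:])), so that out[i] = data[idx[i]] row-wise.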
template<typename xpu>
void TakeOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(req[take_::kOut], kWriteTo);
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_GE(inputs[take_::kArr].ndim(), 2U)
<< "take layer expects its array's size to be at least 2. "
<< inputs[take_::kArr].ndim()
<< " dimensional input is given instead";
const TShape& idxshape = inputs[take_::kIdx].shape_;
const TShape& arrshape = inputs[take_::kArr].shape_;
const TShape& oshape = outputs[take_::kOut].shape_;
int idxndim = idxshape.ndim();
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> idx = inputs[take_::kIdx].get_with_shape<xpu, 1, DType>(
Shape1(idxshape.ProdShape(0, idxndim)), s);
Tensor<xpu, 2, DType> data = inputs[take_::kArr].get_with_shape<xpu, 2, DType>(
Shape2(arrshape[0], arrshape.ProdShape(1, arrshape.ndim())), s);
Tensor<xpu, 2, DType> out = outputs[take_::kOut].get_with_shape<xpu, 2, DType>(
Shape2(oshape.ProdShape(0, idxndim), oshape.ProdShape(idxndim, oshape.ndim())), s);
out = take(idx, data);
});
}
template<typename xpu>
void TakeOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 2U);
CHECK_EQ(req[take_::kIdx], kNullOp)
<< "take layer doesn't support gradient into index";
// inputs are specified in the .cc file, which are the gradients from
// the upper layer and the input index
// outputs are the gradients of inputs in the feed-forward pass
const TShape& idxshape = inputs[1].shape_;
const TShape& arrshape = outputs[0].shape_;
const TShape& oshape = inputs[0].shape_;
int idxndim = idxshape.ndim();
// grad_out is the gradient of the outputs in the feed-forward
// grad_in is the gradient of the inputs in the feed-forward
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> idx = inputs[1].get_with_shape<xpu, 1, DType>(
Shape1(idxshape.ProdShape(0, idxndim)), s);
Tensor<xpu, 2, DType> grad_out = inputs[0].get_with_shape<xpu, 2, DType>(
Shape2(oshape.ProdShape(0, idxndim), oshape.ProdShape(idxndim, oshape.ndim())), s);
Tensor<xpu, 2, DType> grad_in = outputs[0].get_with_shape<xpu, 2, DType>(
Shape2(arrshape[0], arrshape.ProdShape(1, arrshape.ndim())), s);
if (req[take_::kArr] == kWriteTo || req[take_::kArr] == kAddTo) {
if (req[take_::kArr] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
// shape_out_prod ~= the number of elements loaded in AddTakeGrad
// shape_in_prod ~= the number of elements stored in AddTakeGrad
// When the number of elements processed is low, use AddTakeGrad.
// The approximate cut-off value 16384 was found experimentally on Titan X Pascal
uint64_t shape_in_prod =
static_cast<uint64_t>(grad_in.shape_[0])*
static_cast<uint64_t>(grad_in.shape_[1]);
uint64_t shape_out_prod =
static_cast<uint64_t>(grad_out.shape_[0])*
static_cast<uint64_t>(grad_out.shape_[1]);
if (shape_out_prod < (uint64_t)16384 && shape_in_prod < (uint64_t)16384) {
AddTakeGrad(grad_in, idx, grad_out);
} else {
AddTakeGradLargeBatchCaller(ctx, grad_in, idx, grad_out);
}
} else {
LOG(FATAL) << "wrong req";
}
});
}
inline bool BatchTakeOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U) << "BatchTake op requires two inputs";
if ((*in_attrs)[1].ndim() != 0) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[1]);
} else if ((*out_attrs)[0].ndim() != 0) {
SHAPE_ASSIGN_CHECK(*in_attrs, 1, (*out_attrs)[0]);
}
if ((*in_attrs)[0].ndim() == 0) return false;
CHECK_GE((*in_attrs)[0].ndim(), 2U) << "Data array must be at least 2-dimensional";
if ((*out_attrs)[0].ndim() == 0) return false;
CHECK_EQ((*in_attrs)[0].Size()/(*in_attrs)[0][(*in_attrs)[0].ndim()-1],
(*out_attrs)[0].Size())
<< "Index array's size must be the same as data array's size excluding the first dimension";
return true;
}
inline bool BatchTakeOpType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
if ((*in_attrs)[0] != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
} else if ((*out_attrs)[0] != -1) {
TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
}
TYPE_ASSIGN_CHECK(*in_attrs, 1, mshadow::kInt32);
return true;
}
/*! \brief take one scalar per row of a 2-D data array: out[i] = a[i, idx[i]], with idx[i] clipped to [0, M) */
template<int req>
struct batch_take {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
const int *idx, int M) {
int j = idx[i];
if (j < 0) j = 0;
else if (j >= M) j = M-1;
KERNEL_ASSIGN(out[i], req, a[i*M+j]);
}
};
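// Forward pass of batch_take: the data array is viewed as (N, M) with
// M = inputs[0].Size() / inputs[0].shape_[0]; each output element is
// out[i] = data[i, idx[i]], with idx[i] clipped to the valid range [0, M).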
template<typename xpu>
void BatchTakeOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
using namespace mxnet_op;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<batch_take<req_type>, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<int>(),
inputs[0].Size()/inputs[0].shape_[0]);
});
});
}
/*!
* \brief The parameters of the one_hot operator.
*/
struct OneHotParam : public dmlc::Parameter<OneHotParam> {
int depth;
double on_value;
double off_value;
int axis;
int dtype;
DMLC_DECLARE_PARAMETER(OneHotParam) {
DMLC_DECLARE_FIELD(depth)
.describe("The dimension size at dim = axis.");
DMLC_DECLARE_FIELD(on_value)
.set_default(1.0f)
.describe("The value assigned to the locations represented by indices.");
DMLC_DECLARE_FIELD(off_value)
.set_default(0.0f)
.describe("The value assigned to the locations not represented by indices.");
DMLC_DECLARE_FIELD(dtype)
.set_default(mshadow::kFloat32)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.describe("DType of the output");
}
};
inline void GetOneHotParams(const OneHotParam& param, int* depth, double* on_value,
double* off_value, int* dtype) {
*depth = param.depth;
CHECK_GE(*depth, 0) << "Dimension size, depth, must be a non-negative integer";
*on_value = param.on_value;
*off_value = param.off_value;
*dtype = param.dtype;
}
inline bool OneHotOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const OneHotParam& param = nnvm::get<OneHotParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
// The shape of indices
const TShape& ishape = (*in_attrs)[0];
int depth = 0;
double on_value = 1.0;
double off_value = 0.0;
int dtype = mshadow::kFloat32;
GetOneHotParams(param, &depth, &on_value, &off_value, &dtype);
TShape oshape(ishape.ndim() + 1);
for (index_t i = 0; i < ishape.ndim(); ++i) {
oshape[i] = ishape[i];
}
oshape[oshape.ndim()-1] = depth;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return true;
}
inline bool OneHotOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
int depth = 0;
double on_value = 1.0;
double off_value = 0.0;
int dtype = mshadow::kFloat32;
const OneHotParam& param = nnvm::get<OneHotParam>(attrs.parsed);
GetOneHotParams(param, &depth, &on_value, &off_value, &dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kInt32);
TYPE_ASSIGN_CHECK(*out_attrs, 0, dtype);
return true;
}
template<int req>
struct one_hot {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out, const int* indices,
int depth, DType on_value) {
int offset = i * depth;
int j = indices[i];
if (j >= 0 && j < depth) {
KERNEL_ASSIGN(out[offset+j], req, on_value);
}
}
};
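// Forward pass of one_hot: the output is first filled with off_value, then the
// kernel writes on_value at position (i, indices[i]) for every in-range index,
// leaving rows with out-of-range indices entirely off.
// Example: indices = [1, 0, 3], depth = 4, on_value = 1, off_value = 0 ->
//   [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]].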
template<typename xpu>
void OneHotOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
// The following line is needed to guard the situation when
// an output array is empty on GPU. In that case, out.dptr() = 0x0
if (outputs[0].Size() == 0) return;
int depth = 0;
double on_value = 1.0;
double off_value = 0.0;
int dtype = mshadow::kFloat32;
const OneHotParam& param = nnvm::get<OneHotParam>(attrs.parsed);
GetOneHotParams(param, &depth, &on_value, &off_value, &dtype);
using namespace mxnet_op;
using namespace mshadow::expr;
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
mshadow::Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
ASSIGN_DISPATCH(out, req[0], static_cast<DType>(off_value));
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<one_hot<req_type>, xpu>::Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<int>(), depth,
static_cast<DType>(on_value));
});
});
}
} // namespace op
} // namespace mxnet
#ifdef __CUDACC__
#include "./indexing_op-inl.cuh"
#endif
#endif // MXNET_OPERATOR_TENSOR_INDEXING_OP_H_