/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2016 by Contributors
* \file ordering_op-inl.h
* \brief Function definition of ordering operators
*/
#ifndef MXNET_OPERATOR_TENSOR_ORDERING_OP_INL_H_
#define MXNET_OPERATOR_TENSOR_ORDERING_OP_INL_H_
#include <mxnet/operator_util.h>
#include <dmlc/optional.h>
#include <mshadow/tensor.h>
#include <algorithm>
#include <vector>
#include <type_traits>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "./sort_op.h"
#include "./indexing_op.h"
namespace mshadow {
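/*!
* \brief Reinterpret the buffer of a contiguous tensor under a new shape without copying.
* The caller is expected to pass a target_shape that covers the same number of elements.
*/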
template<typename xpu, int src_dim, typename DType, int dst_dim>
inline Tensor<xpu, dst_dim, DType> inplace_reshape(Tensor<xpu, src_dim, DType> src,
Shape<dst_dim> target_shape) {
CHECK_EQ(src.CheckContiguous(), true);
return Tensor<xpu, dst_dim, DType>(src.dptr_, target_shape, src.stream_);
}
}  // namespace mshadow
namespace mxnet {
namespace op {
// These enums are only visible within this header
namespace topk_enum {
enum TopKReturnType {kReturnValue, kReturnIndices, kReturnMask, kReturnBoth};
}  // namespace topk_enum
struct TopKParam : public dmlc::Parameter<TopKParam> {
dmlc::optional<int> axis;
int k;
int ret_typ;
bool is_ascend;
int dtype;
DMLC_DECLARE_PARAMETER(TopKParam) {
DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
.describe("Axis along which to choose the top k indices."
" If not given, the flattened array is used. Default is -1.");
DMLC_DECLARE_FIELD(k).set_default(1)
.describe("Number of top elements to select;"
" it should always be smaller than or equal to the number of elements along the given axis."
" A global sort is performed if k < 1.");
DMLC_DECLARE_FIELD(ret_typ).set_default(topk_enum::kReturnIndices)
.add_enum("value", topk_enum::kReturnValue)
.add_enum("indices", topk_enum::kReturnIndices)
.add_enum("mask", topk_enum::kReturnMask)
.add_enum("both", topk_enum::kReturnBoth)
.describe("The return type.\n"
" \"value\" means to return the top k values,"
" \"indices\" means to return the indices of the top k values,"
" \"mask\" means to return a mask array of 0s and 1s in which 1 marks the top k values,"
" \"both\" means to return a list of both values and indices of top k elements.");
DMLC_DECLARE_FIELD(is_ascend).set_default(false)
.describe("Whether to choose the k largest or k smallest elements."
" The k largest elements are chosen if set to false.");
DMLC_DECLARE_FIELD(dtype)
// TODO(srivrohi): remove support for real data type in mxnet-2.0
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("int64", mshadow::kInt64)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.set_default(mshadow::kFloat32)
.describe("DType of the output indices when ret_typ is \"indices\" or \"both\". "
"An error will be raised if the selected data type cannot precisely represent the "
"indices.");
}
};
struct SortParam : public dmlc::Parameter<SortParam> {
dmlc::optional<int> axis;
bool is_ascend;
DMLC_DECLARE_PARAMETER(SortParam) {
DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
.describe("Axis along which to sort the input tensor."
" If not given, the flattened array is used. Default is -1.");
DMLC_DECLARE_FIELD(is_ascend).set_default(true)
.describe("Whether to sort in ascending or descending order.");
}
};
struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
dmlc::optional<int> axis;
bool is_ascend;
int dtype;
DMLC_DECLARE_PARAMETER(ArgSortParam) {
DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<int>(-1))
.describe("Axis along which to sort the input tensor."
" If not given, the flattened array is used. Default is -1.");
DMLC_DECLARE_FIELD(is_ascend).set_default(true)
.describe("Whether to sort in ascending or descending order.");
DMLC_DECLARE_FIELD(dtype)
// TODO(srivrohi): remove support for real data type in mxnet-2.0
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("int64", mshadow::kInt64)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.set_default(mshadow::kFloat32)
.describe("DType of the output indices. An error will be raised if the selected data type "
"cannot precisely represent the indices.");
}
};
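/*!
* \brief Expand a TopKParam into the quantities used by the top-k kernels.
* Example: src_shape = (2, 3, 4) with axis = 1 and k = 2 gives batch_size = 8, element_num = 3,
* do_transpose = true (axis is not the last one) and, unless a mask is returned,
* target_shape = (2, 2, 4).
*/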
inline void ParseTopKParam(const TShape& src_shape,
const TopKParam& param,
TShape *target_shape,
size_t *batch_size,
index_t *element_num,
int *axis,
index_t *k,
bool *do_transpose,
bool *is_ascend) {
*do_transpose = false;
*k = param.k;
*is_ascend = param.is_ascend;
// get batch_size, axis and element_num
if (!static_cast<bool>(param.axis)) { // No axis given
*axis = 0;
*batch_size = 1;
*element_num = src_shape.Size();
} else {
*axis = param.axis.value();
if (*axis < 0) {
*axis += src_shape.ndim();
}
CHECK(*axis >= 0 && *axis < static_cast<int>(src_shape.ndim()))
<< "Invalid axis! axis should be in the range [0, "
<< src_shape.ndim() << "), found axis=" << *axis;
*batch_size = src_shape.Size() / src_shape[*axis];
*element_num = src_shape[*axis];
if (*axis != src_shape.ndim() - 1) {
*do_transpose = true;
}
}
// get k
if (param.k <= 0) {
*k = *element_num;
}
// get target_shape
if (!static_cast<bool>(param.axis)) {
if (param.ret_typ != topk_enum::kReturnMask) {
*target_shape = mshadow::Shape1(*k);
} else {
*target_shape = src_shape;
}
} else {
*target_shape = src_shape;
if (param.ret_typ != topk_enum::kReturnMask) {
(*target_shape)[*axis] = *k;
}
}
CHECK(*k >= 1 && *k <= *element_num) << "k must be within the range [1, "
<< *element_num << "], got k = " << *k;
}
using namespace mshadow;
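// Kernel functors used to scatter the top-k results back by index: fill_ind_to_one writes 1 at
// the selected positions of `out` (mask output), while fill_ind writes `val[i]` to
// `out[indices[i]]` honoring the write request type.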
struct fill_ind_to_one {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, const index_t* indices, DType* out) {
out[indices[i]] = static_cast<DType>(1);
}
};
struct fill_ind {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, const index_t* indices, const DType* val,
int req, DType* out) {
KERNEL_ASSIGN(out[indices[i]], req, val[i]);
}
};
template<typename DType>
MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
const Tensor<cpu, 1, index_t>& ind,
const Tensor<cpu, 1, char>& work,
index_t K, index_t N, bool is_ascend,
Stream<cpu> *s) {
// Use full sort when K is relatively large.
const bool full_sort(K*8 > N);
// Batch size.
const index_t M(work.size(0)/(sizeof(DType)*N));
const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
#pragma omp parallel for num_threads(omp_threads)
for (index_t i = 0; i < M; ++i) {
// Tensor `work` stores the flattened source data, while `dat` stores the sorted result.
DType *vals = reinterpret_cast<DType*>(work.dptr_);
DType *sorted_vals = dat.dptr_+i*N;
index_t *indices = ind.dptr_+i*N;
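// `ind` is pre-filled by the caller with 0..M*N-1, so this row's indices already address the
// global `vals` buffer; sorting them by the values they reference sorts the row.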
if (is_ascend) {
if (full_sort) {
std::sort(indices, indices+N,
[&](const index_t& i1, const index_t& i2){
return vals[i1] < vals[i2]; });
} else {
std::partial_sort(indices, indices+K, indices+N,
[&](const index_t& i1, const index_t& i2){
return vals[i1] < vals[i2]; });
}
} else {
if (full_sort) {
std::sort(indices, indices+N,
[&](const index_t& i1, const index_t& i2){
return vals[i1] > vals[i2]; });
} else {
std::partial_sort(indices, indices+K, indices+N,
[&](const index_t& i1, const index_t& i2){
return vals[i1] > vals[i2]; });
}
}
for (index_t j = 0; j < K; ++j) {
sorted_vals[j] = vals[indices[j]];
}
}
}
#ifdef __CUDACC__
template<typename DType>
MSHADOW_XINLINE bool TopKCompare(DType val1, index_t ind1, DType val2, index_t ind2,
bool is_ascend) {
// Returns true if (val1, ind1) ranks before (val2, ind2) in the top-k order. Negative indices
// denote undefined values, which always rank last.
return (ind2 < 0) || (ind1 >= 0 && ((is_ascend && val1 < val2) || (!is_ascend && val1 > val2)));
}
template<typename DType>
MSHADOW_XINLINE void MergeTopK(index_t K, DType *val1, index_t *ind1, DType *val2, index_t *ind2,
bool is_ascend) {
// In-place merge of two sorted top-K lists into val1/ind1. First determine the intervals
// [0,..,i1], [0,..i2] of the two lists that will be part of the merged list.
index_t i1(K-1), i2(K-1);
for (index_t i = 0; i < K; ++i) {
if (TopKCompare(val1[i1], ind1[i1], val2[i2], ind2[i2], is_ascend)) {
--i2;
} else {
--i1;
}
}
// Now merge the lists from back to front.
for (index_t i = K; i--;) {
if (i2 < 0 || (i1 >= 0 && TopKCompare(val2[i2], ind2[i2], val1[i1], ind1[i1], is_ascend))) {
val1[i] = val1[i1];
ind1[i] = ind1[i1];
--i1;
} else {
val1[i] = val2[i2];
ind1[i] = ind2[i2];
--i2;
}
}
}
template<typename DType>
__global__ void PartialSortSmallK(index_t K, index_t N, DType *val, index_t *ind, bool is_ascend) {
// Buffer for pairwise reduction.
extern __shared__ index_t buff[];
// Start of buffer sections associated with this thread.
const index_t offset(threadIdx.x*K);
index_t *ind_buff = &buff[offset];
DType *val_buff = reinterpret_cast<DType*>(&buff[blockDim.x*K])+offset;
// Initialize top-K values for this thread.
for (index_t i = 0; i < K; ++i) {
ind_buff[i] = -1;
}
// Range of values this thread cares about. Each thread block processes
// a different batch item (i.e. a different set of ind/val where we
// have to select the top-K elements). All threads within the same
// block work on the same batch item.
const index_t first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N);
// Select top-K from this range and store it sorted in the buffer.
// We assume a small K, so linear insertion is o.k.
for (index_t i = first; i < last; i += blockDim.x) {
DType cur_val(val[i]);
index_t cur_ind(ind[i]);
for (index_t j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j],
ind_buff[j], is_ascend); ) {
if (j+1 < K) {
val_buff[j+1] = val_buff[j];
ind_buff[j+1] = ind_buff[j];
}
val_buff[j] = cur_val;
ind_buff[j] = cur_ind;
}
}
// Recursive merge of the per-thread sorted lists within this thread block. Note that
// blockDim.x is not necessarily a power of two, hence the additional check against last_s.
for (index_t s = (blockDim.x+1)/2, last_s = blockDim.x;
last_s > 1; last_s = s, s = (s+1)/2) {
__syncthreads();
if (threadIdx.x < s && threadIdx.x+s < last_s) {
MergeTopK(K, val_buff, ind_buff, val_buff+s*K, ind_buff+s*K, is_ascend);
}
}
// Final updates on master thread.
if (threadIdx.x == 0) {
for (index_t i = 0; i < K; ++i) {
ind[blockIdx.x*N+i] = ind_buff[i];
val[blockIdx.x*N+i] = val_buff[i];
}
}
}
template<typename DType>
MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
const Tensor<gpu, 1, index_t>& ind,
const Tensor<gpu, 1, char>& work,
index_t K, index_t N, bool is_ascend,
Stream<gpu> *s) {
// Use full sort for all but very small K for which we
// can do a partial sort entirely within shared memory.
const bool full_sort(K > 5);
// Batch size.
const index_t M(dat.size(0)/N);
if (full_sort) {
// Divide workspace into two parts. The first one is needed to store batch ids.
size_t alignment = std::max(sizeof(DType), sizeof(index_t));
size_t id_size = PadBytes(sizeof(index_t) * ind.size(0), alignment);
Tensor<gpu, 1, index_t> batch_id(reinterpret_cast<index_t*>(work.dptr_),
Shape1(ind.size(0)), s);
Tensor<gpu, 1, char> sort_work(work.dptr_+id_size, Shape1(work.size(0)-id_size), s);
mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work);
if (M > 1) {
// Back to back sorting. Note that mxnet::op::SortByKey is a stable sort.
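// The global sort above ordered all values across batches (carrying their indices along).
// The two stable sorts by batch id below regroup the elements batch by batch while keeping
// the per-batch value order; batch_id is recomputed before the second pass so that it is
// aligned with `ind` again.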
batch_id = ind / N;
mxnet::op::SortByKey(batch_id, dat, true, &sort_work);
batch_id = ind / N;
mxnet::op::SortByKey(batch_id, ind, true, &sort_work);
}
} else {
const int nthreads(mshadow::cuda::kBaseThreadNum);
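// One thread block per batch item; each thread keeps its own sorted list of K (index, value)
// pairs in dynamically allocated shared memory, and the per-thread lists are then merged.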
PartialSortSmallK<<<M, nthreads, nthreads*K*(sizeof(index_t)+sizeof(DType)),
mshadow::Stream<gpu>::GetStream(s)>>>
(K, N, dat.dptr_, ind.dptr_, is_ascend);
}
}
#endif
/*!
* \brief Implementation of the TopK operation
*
* \param ctx the running context
* \param resource temporary resource handler
* \param req the operation request types
* \param src the source blob
* \param ret the destination blobs
* \param param the topk parameters
* \tparam xpu the device type.
* \tparam DType type of the output value/mask.
* \tparam IDType type of the output indices.
*/
template<typename xpu, typename DType, typename IDType>
void TopKImpl(const RunContext &ctx,
const Resource &resource,
const std::vector<OpReqType>& req,
const TBlob& src,
const std::vector<TBlob>& ret,
const TopKParam& param) {
using namespace mshadow;
using namespace mshadow::expr;
// 1. Parse and initialize information
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 1, char> workspace;
Tensor<xpu, 1, char> temp_workspace;
Tensor<xpu, 1, DType> sorted_dat;
Tensor<xpu, 1, index_t> indices, sel_indices;
size_t batch_size = 0;
index_t element_num = 0; // number of elements in each batch
int axis = 0;
bool do_transpose = false;
bool is_ascend = false;
index_t k = 0;
size_t alignment = std::max(sizeof(DType), sizeof(index_t));
mxnet::TShape target_shape;
ParseTopKParam(src.shape_, param,
&target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
<< "'index_t' does not have a sufficient precision to represent "
<< "the indices of the input array. The total element_num is "
<< element_num << ", but the selected index_t can only represent "
<< mxnet::common::MaxIntegerValue<index_t>() << " elements";
Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
// Temp space needed by the full sorts.
size_t temp_size = std::max(
mxnet::op::SortByKeyWorkspaceSize<index_t, DType, xpu>(src.Size()),
mxnet::op::SortByKeyWorkspaceSize<DType, index_t, xpu>(src.Size()));
temp_size = std::max(temp_size,
mxnet::op::SortByKeyWorkspaceSize<index_t, index_t, xpu>(src.Size()));
// Additional temp space for gpu full sorts for batch ids.
temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment);
// Temp space for cpu sorts.
temp_size = std::max(temp_size, sizeof(DType) * src.Size());
size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment)
+ PadBytes(sizeof(index_t) * src.Size(), alignment);
if (param.ret_typ == topk_enum::kReturnMask) {
workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment);
}
workspace = resource.get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
char* workspace_curr_ptr = workspace.dptr_;
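// Workspace layout: sorted_dat (DType), indices (index_t), sel_indices (index_t, mask output
// only), followed by temporary space used by the sort when needed.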
sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
Shape1(src.Size()), s); // holds the sorted data
workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
Shape1(src.Size()), s); // indices in the original matrix
workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment);
if (param.ret_typ == topk_enum::kReturnMask) {
sel_indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
Shape1(batch_size * k), s);
workspace_curr_ptr += PadBytes(sizeof(index_t) * batch_size * k, alignment);
CHECK_EQ(sel_indices.CheckContiguous(), true);
}
if (std::is_same<xpu, cpu>::value) {
Tensor<xpu, 1, DType> flattened_data;
if (do_transpose) {
flattened_data = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
Shape1(src.Size()), s);
workspace_curr_ptr += sizeof(DType) * src.Size();
flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
CHECK_EQ(flattened_data.CheckContiguous(), true);
} else {
flattened_data = src.FlatTo1D<xpu, DType>(s);
}
// `temp_workspace` stores the flattened data
temp_workspace = Tensor<xpu, 1, char>(reinterpret_cast<char*>(flattened_data.dptr_),
Shape1(sizeof(DType)*src.Size()), s);
CHECK_EQ(temp_workspace.CheckContiguous(), true);
} else {
if (do_transpose) {
sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
} else {
sorted_dat = reshape(dat, Shape1(src.Size()));
}
CHECK_EQ(sorted_dat.CheckContiguous(), true);
temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s); // temp space
workspace_curr_ptr += temp_size;
}
mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, index_t{0}, index_t{1},
kWriteTo, indices.dptr_);
CHECK_EQ(indices.CheckContiguous(), true);
// 2. Perform the inplace batch sort.
// After sorting, each batch in `sorted_dat` will be sorted in the requested order up to the
// k-th element, and `indices` will hold, for each sorted element, its position in the
// flattened source data.
// `temp_workspace` stores the flattened source data on CPU and serves as a temporary buffer
// for the sort on GPU.
TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);
// 3. Assign the results to the ret blobs.
// When returning indices, only the required elements are updated (via modulo) instead of all
// elements, to avoid redundant computation.
// Casting `ret_indices` from an integer type to real_t can introduce conversion errors once
// element_num is large enough.
if (param.ret_typ == topk_enum::kReturnMask) {
Tensor<xpu, 1, DType> ret_mask = ret[0].FlatTo1D<xpu, DType>(s);
ret_mask = scalar<DType>(0);
sel_indices = reshape(slice<1>(
inplace_reshape(indices,
Shape2(batch_size,
element_num)), 0, k),
Shape1(batch_size * k));
if (do_transpose) {
mxnet::TShape src_shape = src.shape_.FlatTo3D(axis);
CHECK_EQ(sel_indices.CheckContiguous(), true);
sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
Shape3(0, 2, 1));
}
if (req[0] == kNullOp) {
return;
} else if (req[0] == kWriteTo) {
mxnet_op::Kernel<fill_ind_to_one, xpu>::Launch(s, batch_size * k,
sel_indices.dptr_, ret_mask.dptr_);
} else {
LOG(FATAL) << "req=" << req[0] << " is not supported yet.";
}
} else if (param.ret_typ == topk_enum::kReturnIndices) {
if (do_transpose) {
Tensor<xpu, 3, IDType> ret_indices = ret[0].FlatTo3D<xpu, IDType>(axis, axis, s);
ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(transpose(
slice<2>(inplace_reshape(indices,
Shape3(ret_indices.shape_[0],
ret_indices.shape_[2],
element_num)),
0, k),
Shape3(0, 2, 1)), element_num)));
} else {
Tensor<xpu, 2, IDType> ret_indices =
ret[0].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k),
element_num)));
}
} else {
if (do_transpose) {
Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu, DType>(axis, axis, s);
Tensor<xpu, 3, IDType> ret_indices = ret[1].FlatTo3D<xpu, IDType>(axis, axis, s);
ASSIGN_DISPATCH(ret_value, req[0], transpose(
slice<2>(inplace_reshape(sorted_dat,
Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
0, k), Shape3(0, 2, 1)));
ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(transpose(
slice<2>(inplace_reshape(indices,
Shape3(ret_indices.shape_[0],
ret_indices.shape_[2],
element_num)),
0, k), Shape3(0, 2, 1)), element_num)));
} else {
Tensor<xpu, 2, DType> ret_value =
ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
Tensor<xpu, 2, IDType> ret_indices =
ret[1].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
ASSIGN_DISPATCH(ret_value, req[0],
slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k));
ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k), element_num)));
}
}
}
template<typename xpu>
void TopK(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
})
});
} else {
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
});
}
}
template<typename xpu>
void Sort(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const SortParam& param = nnvm::get<SortParam>(attrs.parsed);
TopKParam topk_param;
topk_param.axis = param.axis;
topk_param.is_ascend = param.is_ascend;
topk_param.k = 0;
topk_param.ret_typ = topk_enum::kReturnValue;
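// Sort is a full top-k: k = 0 makes ParseTopKParam select every element along the axis, and
// kReturnValue yields the sorted values.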
MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, inputs[0],
outputs, topk_param);
});
}
template<typename xpu>
void ArgSort(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
TopKParam topk_param;
topk_param.axis = param.axis;
topk_param.is_ascend = param.is_ascend;
topk_param.k = 0;
topk_param.dtype = param.dtype;
topk_param.ret_typ = topk_enum::kReturnIndices;
MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
TopKImpl<xpu, DType, IDType>(ctx.run_ctx,
ctx.requested[0], req, inputs[0], outputs, topk_param);
});
});
}
template<typename xpu, typename DType, typename IDType>
void TopKBackwardImpl(const OpContext &ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const TopKParam& param) {
CHECK_NE(req[0], kWriteInplace);
using namespace mshadow;
using namespace mshadow::expr;
Stream<xpu> *s = ctx.run_ctx.get_stream<xpu>();
CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth);
size_t batch_size = 0;
index_t element_num = 0; // number of elements in each batch
int axis = 0;
bool do_transpose = false;
bool is_ascend = false;
index_t k = 0;
mxnet::TShape target_shape;
ParseTopKParam(outputs[0].shape_, param,
&target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
<< "'IDType' does not have a sufficient precision to represent "
<< "the indices of the input array. The total element_num is " << element_num
<< ", but the selected IDType can only represent "
<< mxnet::common::MaxIntegerValue<IDType>() << " elements";
Tensor<xpu, 1, index_t> workspace =
ctx.requested[0].get_space_typed<xpu, 1, index_t>(Shape1(batch_size * k + batch_size), s);
Tensor<xpu, 1, index_t> sel_indices =
Tensor<xpu, 1, index_t>(workspace.dptr_, Shape1(batch_size * k), s);
Tensor<xpu, 1, index_t> batch_shift =
Tensor<xpu, 1, index_t>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
Tensor<xpu, 2, DType> out_grad =
inputs[0].get_with_shape<xpu, 2, DType>(Shape2(inputs[0].shape_.Size(), 1), s);
Tensor<xpu, 2, DType> in_grad =
outputs[0].get_with_shape<xpu, 2, DType>(Shape2(outputs[0].shape_.Size(), 1), s);
mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, index_t{0}, element_num, kWriteTo,
batch_shift.dptr_);
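// batch_shift holds [0, element_num, 2*element_num, ...]; adding it to the per-batch top-k
// indices converts them into flat positions within the gradient arrays.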
if (do_transpose) {
Tensor<xpu, 1, IDType> indices = inputs[2].FlatTo1D<xpu, IDType>(s);
mxnet::TShape src_shape = outputs[0].shape_.FlatTo3D(axis);
sel_indices = reshape(transpose(
broadcast_to(inplace_reshape(batch_shift,
Shape3(src_shape[0], src_shape[2], 1)),
mxnet::TShape(Shape3(src_shape[0], src_shape[2], k))),
Shape3(0, 2, 1)),
Shape1(batch_size * k));
sel_indices += tcast<index_t>(indices);
sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
Shape3(0, 2, 1));
} else {
Tensor<xpu, 2, IDType> indices =
inputs[2].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
sel_indices = reshape(tcast<index_t>(indices) +
broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)),
mxnet::TShape(Shape2(batch_size, k))),
Shape1(batch_size * k));
}
CHECK_EQ(sel_indices.CheckContiguous(), true);
if (kWriteTo == req[0] || kAddTo == req[0]) {
if (kWriteTo == req[0]) {
in_grad = scalar<DType>(0);
}
mxnet_op::Kernel<fill_ind, xpu>::Launch(s, batch_size * k,
sel_indices.dptr_,
out_grad.dptr_,
req[0],
in_grad.dptr_);
} else {
LOG(FATAL) << "Not Implemented!";
}
}
template<typename xpu>
void TopKBackward_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
if (param.ret_typ == topk_enum::kReturnBoth) {
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
});
});
} else if (param.ret_typ == topk_enum::kReturnValue) {
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
TopKBackwardImpl<xpu, DType, index_t>(ctx, inputs, req, outputs, param);
});
} else {
LOG(FATAL) << "Not Implemented";
}
}
inline uint32_t TopKNumOutputs(const NodeAttrs& attrs) {
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
if (param.ret_typ == topk_enum::kReturnIndices ||
param.ret_typ == topk_enum::kReturnMask) {
return static_cast<uint32_t>(1);
} else {
return static_cast<uint32_t>(2);
}
}
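// Note: kReturnValue still allocates an index output (TopKNumOutputs returns 2); it is hidden
// from the user and consumed by the backward pass (see TopKBackwardImpl).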
inline uint32_t TopKNumVisibleOutputs(const NodeAttrs& attrs) {
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
if (param.ret_typ == topk_enum::kReturnBoth) {
return static_cast<uint32_t>(2);
} else {
return static_cast<uint32_t>(1);
}
}
inline bool TopKType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
CHECK_EQ(in_size, 1);
CHECK(out_size == 1 || out_size == 2);
// out_attr[0] -> stores value
// out_attr[1] -> stores indices
if (out_size > 1) {
if (param.ret_typ == topk_enum::kReturnValue) {
#if MXNET_USE_INT64_TENSOR_SIZE == 1
CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
#else
CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
#endif
<< "Failed to set the type of ret_indices.";
} else {
CHECK(type_assign(&(*out_attrs)[1], param.dtype))
<< "Failed to set the type of ret_indices.";
}
}
if (param.ret_typ == topk_enum::kReturnIndices) {
CHECK(type_assign(&(*out_attrs)[0], param.dtype))
<< "Failed to set the type of ret_indices.";
} else {
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
return out_attrs->at(0) != -1;
}
return true;
}
inline bool TopKShapeImpl(const TopKParam& param,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
if (param.ret_typ == topk_enum::kReturnIndices ||
param.ret_typ == topk_enum::kReturnMask) {
CHECK_EQ(out_attrs->size(), 1U);
} else {
CHECK_EQ(out_attrs->size(), 2U);
}
mxnet::TShape& in_shape = (*in_attrs)[0];
size_t batch_size = 0;
index_t element_num = 0; // number of elements in each batch
int axis = 0;
bool do_transpose = false;
bool is_ascend = false;
index_t k = 0;
mxnet::TShape target_shape;
ParseTopKParam(in_shape, param,
&target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
if (param.ret_typ == topk_enum::kReturnIndices ||
param.ret_typ == topk_enum::kReturnMask) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
} else {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
SHAPE_ASSIGN_CHECK(*out_attrs, 1, target_shape);
}
return true;
}
inline bool TopKShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
return TopKShapeImpl(param, in_attrs, out_attrs);
}
inline bool SortType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
int data_type = -1;
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
CHECK_EQ(in_size, 1);
CHECK_EQ(out_size, 2);
#if MXNET_USE_INT64_TENSOR_SIZE == 1
CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
#else
CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
#endif
<< "Failed to set the type of ret_indices";
CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
<< (*in_attrs)[0];
CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
<< (*out_attrs)[0];
CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
<< (*in_attrs)[0];
CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
<< (*out_attrs)[0];
if (data_type == -1) return false;
return true;
}
inline bool SortShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const SortParam& param = nnvm::get<SortParam>(attrs.parsed);
TopKParam topk_param;
topk_param.axis = param.axis;
topk_param.is_ascend = param.is_ascend;
topk_param.k = 0;
topk_param.ret_typ = topk_enum::kReturnValue;
return TopKShapeImpl(topk_param, in_attrs, out_attrs);
}
inline bool ArgSortType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
CHECK(type_assign(&(*out_attrs)[0], param.dtype))
<< "Failed to set the type of ret_indices.";
return true;
}
inline bool ArgSortShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
TopKParam topk_param;
topk_param.axis = param.axis;
topk_param.is_ascend = param.is_ascend;
topk_param.k = 0;
topk_param.ret_typ = topk_enum::kReturnIndices;
return TopKShapeImpl(topk_param, in_attrs, out_attrs);
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_ORDERING_OP_INL_H_