blob: 42ae43cc7584ba182859096beb0b9689986fee12 [file] [log] [blame]
/*!
* Copyright (c) 2017 by Contributors
* \file sort_op.h
* \brief SortByKey function
*/
#ifndef MXNET_OPERATOR_TENSOR_SORT_OP_H_
#define MXNET_OPERATOR_TENSOR_SORT_OP_H_
#include <dmlc/logging.h>
#include <mshadow/tensor.h>
#include <algorithm>
#include <vector>
#include <type_traits>
namespace mxnet {
namespace op {
/*!
 * \brief CPU: Stable sort of key-value pairs that are stored in two separate 1-D tensors.
 *
 * Both tensors are rewritten in place: `keys` ends up ordered and every entry of
 * `values` moves together with the key it was originally paired with. Entries
 * whose keys compare equal keep their original relative order (std::stable_sort).
 *
 * \param keys the keys to sort; must be contiguous
 * \param values the values reordered w.r.t. the keys; must be contiguous and have
 *        the same length as keys
 * \param is_ascend true sorts keys in ascending order, false in descending order
 * \param workspace not used by this CPU overload
 * \param begin_bit not used by this CPU overload
 * \param end_bit not used by this CPU overload
 */
template<typename KDType, typename VDType>
inline void SortByKey(mshadow::Tensor<cpu, 1, KDType> keys, mshadow::Tensor<cpu, 1, VDType> values,
                      bool is_ascend = true, mshadow::Tensor<cpu, 1, char>* workspace = NULL,
                      const int begin_bit = 0, const int end_bit = sizeof(KDType)*8) {
  CHECK_EQ(keys.CheckContiguous(), true);
  CHECK_EQ(values.CheckContiguous(), true);
  CHECK_EQ(keys.size(0), values.size(0))
    << "The sizes of key/value are not equal! keys_size: " << keys.size(0)
    << ", values_size: " << values.size(0);
  // Use the tensor's own index type for the element count so the loops below
  // never compare a signed int against the tensor size.
  const index_t num_items = keys.size(0);
  // Snapshot both tensors: the sort only permutes `idx`, and the snapshots are
  // then gathered back into the tensors through that permutation.
  std::vector<size_t> idx(num_items);
  std::vector<KDType> keys_vec(num_items);
  std::vector<VDType> values_vec(num_items);
  for (index_t i = 0; i < num_items; ++i) {
    idx[i] = i;
    keys_vec[i] = keys[i];
    values_vec[i] = values[i];
  }
  if (is_ascend) {
    std::stable_sort(idx.begin(), idx.end(),
                     [&keys_vec](size_t i1, size_t i2)
                       { return keys_vec[i1] < keys_vec[i2]; });
  } else {
    std::stable_sort(idx.begin(), idx.end(),
                     [&keys_vec](size_t i1, size_t i2)
                       { return keys_vec[i1] > keys_vec[i2]; });
  }
  // Write the permuted snapshots back into the caller's tensors.
  for (index_t i = 0; i < num_items; ++i) {
    keys[i] = keys_vec[idx[i]];
    values[i] = values_vec[idx[i]];
  }
}
/*!
 * \brief CPU: Number of bytes of temporary storage a caller must provide to SortByKey.
 *
 * This overload is selected when xpu is cpu and always reports zero bytes.
 * \param num_keys number of keys to sort (does not influence the result here)
 * \return always 0
 */
template <typename KDType, typename VDType, typename xpu>
inline typename std::enable_if<std::is_same<xpu, cpu>::value, size_t>::type
SortByKeyWorkspaceSize(const size_t num_keys) {
  const size_t required_bytes = 0;
  return required_bytes;
}
/*!
 * \brief GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!)
 *
 * Declaration only; no definition is visible in this header. NOTE(review): the
 * definition is presumably supplied by ./sort_op-inl.cuh, which this header
 * includes when __CUDACC__ is defined - confirm.
 * \param keys the keys to sort
 * \param values the values that sorts w.r.t the key
 * \param is_ascend whether to sort key in ascending order
 * \param workspace optional temporary storage, defaults to NULL.
 *        NOTE(review): presumably sized with SortByKeyWorkspaceSize - confirm against the definition
 * \param begin_bit defaults to 0.
 *        NOTE(review): presumably the first key bit taken into account by the sort - confirm
 * \param end_bit defaults to sizeof(KDType)*8, i.e. the bit width of KDType.
 *        NOTE(review): presumably one past the last key bit taken into account - confirm
 */
template<typename KDType, typename VDType>
inline void SortByKey(mshadow::Tensor<gpu, 1, KDType> keys, mshadow::Tensor<gpu, 1, VDType> values,
bool is_ascend = true, mshadow::Tensor<gpu, 1, char>* workspace = NULL,
const int begin_bit = 0, const int end_bit = sizeof(KDType)*8);
/*!
 * \brief GPU: Return the amount of temporary storage in bytes required for SortByKey
 *
 * Declaration only; this overload participates in overload resolution when xpu is gpu.
 * NOTE(review): the definition is presumably supplied by ./sort_op-inl.cuh, which
 * this header includes when __CUDACC__ is defined - confirm.
 * \param num_keys number of keys to sort
 * \return required temporary storage in bytes
 */
template <typename KDType, typename VDType, typename xpu>
inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type
SortByKeyWorkspaceSize(const size_t num_keys);
} // namespace op
} // namespace mxnet
#ifdef __CUDACC__
#include "./sort_op-inl.cuh"
#endif
#endif // MXNET_OPERATOR_TENSOR_SORT_OP_H_