| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file Use standard C library call. |
| */ |
| |
| #include <dlpack/dlpack.h> |
| #include <tvm/ffi/error.h> |
| #include <tvm/ffi/function.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/runtime/data_type.h> |
| |
| #include <algorithm> |
| #include <vector> |
| |
| #include "../../../../../3rdparty/compiler-rt/builtin_fp16.h" |
| |
| namespace tvm { |
| namespace contrib { |
| |
| using namespace runtime; |
| |
/*!
 * \brief Ordering predicate ranking (index, value) pairs by increasing value.
 *
 * \tparam DType Value type held in the pair; must provide operator< and operator==.
 * \tparam stable_comparison When true, pairs whose values compare equal are ordered
 *         by increasing index instead of being reported as equivalent.
 * \param lhs Left-hand (index, value) pair.
 * \param rhs Right-hand (index, value) pair.
 * \return true when lhs must be placed before rhs.
 */
template <typename DType, bool stable_comparison = false>
bool CompareAscend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
  if constexpr (stable_comparison) {
    // Only consult the index when the two values tie.
    return (lhs.second == rhs.second) ? (lhs.first < rhs.first) : (lhs.second < rhs.second);
  } else {
    return lhs.second < rhs.second;
  }
}
| |
/*!
 * \brief Ordering predicate ranking (index, value) pairs by decreasing value.
 *
 * \tparam DType Value type held in the pair; must provide operator> and operator==.
 * \tparam stable_comparison When true, pairs whose values compare equal are ordered
 *         by increasing index instead of being reported as equivalent.
 * \param lhs Left-hand (index, value) pair.
 * \param rhs Right-hand (index, value) pair.
 * \return true when lhs must be placed before rhs.
 */
template <typename DType, bool stable_comparison = false>
bool CompareDescend(const std::pair<int64_t, DType>& lhs, const std::pair<int64_t, DType>& rhs) {
  if constexpr (stable_comparison) {
    // Ties are still broken by the smaller index, even though values descend.
    return (lhs.second == rhs.second) ? (lhs.first < rhs.first) : (lhs.second > rhs.second);
  } else {
    return lhs.second > rhs.second;
  }
}
| |
| struct float16 { |
| uint16_t bits; |
| float to_float() const { |
| return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(bits); |
| } |
| |
| inline bool operator==(const float16& rhs) const { return to_float() == rhs.to_float(); } |
| inline bool operator!=(const float16& rhs) const { return to_float() != rhs.to_float(); } |
| inline bool operator<(const float16& rhs) const { return to_float() < rhs.to_float(); } |
| inline bool operator>(const float16& rhs) const { return to_float() > rhs.to_float(); } |
| inline bool operator<=(const float16& rhs) const { return to_float() <= rhs.to_float(); } |
| inline bool operator>=(const float16& rhs) const { return to_float() >= rhs.to_float(); } |
| }; |
| |
| // Argsort implemented C library sort for nms. |
| // Return indices of sorted tensor. |
| // By default, the last axis will be used to sort. |
| // sort_num specify the number of elements to be sorted. |
| // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) |
| // and sort axis is dk. sort_num should have dimension of |
| // (d1, d2, ..., d(k-1), d(k+1), ..., dn). |
| void RegisterArgsortNMS() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def_packed( |
| "tvm.contrib.sort.argsort_nms", [](ffi::PackedArgs args, ffi::Any* ret) { |
| auto input = args[0].cast<DLTensor*>(); |
| auto sort_num = args[1].cast<DLTensor*>(); |
| auto output = args[2].cast<DLTensor*>(); |
| int32_t axis = args[3].cast<int>(); |
| bool is_ascend = args[4].cast<bool>(); |
| |
| auto dtype = input->dtype; |
| auto data_ptr = static_cast<float*>(input->data); |
| auto sort_num_ptr = static_cast<int32_t*>(sort_num->data); |
| std::vector<std::pair<int32_t, float>> sorter; |
| int64_t axis_mul_before = 1; |
| int64_t axis_mul_after = 1; |
| |
| if (axis < 0) { |
| axis = input->ndim + axis; |
| } |
| |
| // Currently only supports input dtype to be float32. |
| TVM_FFI_ICHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " |
| "to be float."; |
| #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1) |
| TVM_FFI_ICHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " |
| "to be float32."; |
| #endif |
| TVM_FFI_ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " |
| "input ndim " |
| << input->ndim; |
| |
| for (int i = 0; i < input->ndim; ++i) { |
| if (i < axis) { |
| axis_mul_before *= input->shape[i]; |
| } else if (i > axis) { |
| axis_mul_after *= input->shape[i]; |
| } |
| } |
| |
| for (int64_t i = 0; i < axis_mul_before; ++i) { |
| for (int64_t j = 0; j < axis_mul_after; ++j) { |
| sorter.clear(); |
| int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j); |
| int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; |
| for (int64_t k = 0; k < current_sort_num; ++k) { |
| int64_t full_idx = base_idx + k * axis_mul_after; |
| sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); |
| } |
| if (is_ascend) { |
| #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) |
| if (dtype.bits == 16) { |
| std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>); |
| } else { |
| #endif |
| std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>); |
| #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) |
| } |
| #endif |
| } else { |
| #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) |
| if (dtype.bits == 16) { |
| std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>); |
| } else { |
| #endif |
| std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>); |
| #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) |
| } |
| #endif |
| } |
| for (int32_t k = 0; k < input->shape[axis]; ++k) { |
| *(static_cast<int32_t*>(output->data) + base_idx + k * axis_mul_after) = |
| k < static_cast<int32_t>(sorter.size()) ? sorter[k].first : k; |
| } |
| } |
| } |
| }); |
| } |
| |
| template <typename DataType, typename OutType> |
| void sort_impl( |
| DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend, |
| std::function<void(OutType*, size_t, const std::pair<int64_t, DataType>&)> epilogue) { |
| auto data_ptr = static_cast<DataType*>(input->data); |
| auto out_ptr = static_cast<OutType*>(output->data); |
| std::vector<std::pair<int64_t, DataType>> sorter; |
| |
| int axis_mul_before = 1; |
| int axis_mul_after = 1; |
| for (int i = 0; i < input->ndim; ++i) { |
| if (i < axis) { |
| axis_mul_before *= input->shape[i]; |
| } else if (i > axis) { |
| axis_mul_after *= input->shape[i]; |
| } |
| } |
| |
| for (int i = 0; i < axis_mul_before; ++i) { |
| for (int j = 0; j < axis_mul_after; ++j) { |
| sorter.clear(); |
| int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; |
| for (int64_t k = 0; k < input->shape[axis]; ++k) { |
| int64_t full_idx = base_idx + k * axis_mul_after; |
| sorter.emplace_back(std::make_pair(k, data_ptr[full_idx])); |
| } |
| if (is_ascend) { |
| std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<DataType>); |
| } else { |
| std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<DataType>); |
| } |
| for (int64_t k = 0; k < input->shape[axis]; ++k) { |
| epilogue(out_ptr, base_idx + k * axis_mul_after, sorter[k]); |
| } |
| } |
| } |
| } |
| |
| template <typename DataType, typename OutType> |
| void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { |
| return sort_impl<DataType, OutType>( |
| input, output, axis, is_ascend, |
| [](OutType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) { |
| out_ptr[index] = static_cast<OutType>(sort_pair.first); |
| }); |
| } |
| |
| template <typename DataType> |
| void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { |
| return sort_impl<DataType, DataType>( |
| input, output, axis, is_ascend, |
| [](DataType* out_ptr, size_t index, const std::pair<int64_t, DataType>& sort_pair) { |
| out_ptr[index] = sort_pair.second; |
| }); |
| } |
| |
| // Argsort implemented C library sort. |
| // Return indices of sorted tensor. |
| // By default, the last axis will be used to sort. |
| // sort_num specify the number of elements to be sorted. |
| // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) |
| // and sort axis is dk. sort_num should have dimension of |
| // (d1, d2, ..., d(k-1), d(k+1), ..., dn). |
| void RegisterArgsort() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def_packed("tvm.contrib.sort.argsort", [](ffi::PackedArgs args, ffi::Any* ret) { |
| auto input = args[0].cast<DLTensor*>(); |
| auto output = args[1].cast<DLTensor*>(); |
| int32_t axis = args[2].cast<int32_t>(); |
| bool is_ascend = args[3].cast<bool>(); |
| if (axis < 0) { |
| axis = input->ndim + axis; |
| } |
| TVM_FFI_ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " |
| "input ndim " |
| << input->ndim; |
| |
| auto data_dtype = ffi::DLDataTypeToString(input->dtype); |
| auto out_dtype = ffi::DLDataTypeToString(output->dtype); |
| |
| if (data_dtype == "float32") { |
| if (out_dtype == "int32") { |
| argsort<float, int32_t>(input, output, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| argsort<float, int64_t>(input, output, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| argsort<float, float>(input, output, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| argsort<float, double>(input, output, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else if (data_dtype == "float64") { |
| if (out_dtype == "int32") { |
| argsort<double, int32_t>(input, output, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| argsort<double, int64_t>(input, output, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| argsort<double, float>(input, output, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| argsort<double, double>(input, output, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) |
| } else if (data_dtype == "float16") { |
| if (out_dtype == "float16") { |
| argsort<__fp16, __fp16>(input, output, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| #endif |
| } else if (data_dtype == "int32") { |
| if (out_dtype == "int32") { |
| argsort<int32_t, int32_t>(input, output, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| argsort<int32_t, int64_t>(input, output, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| argsort<int32_t, float>(input, output, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| argsort<int32_t, double>(input, output, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else if (data_dtype == "int64") { |
| if (out_dtype == "int32") { |
| argsort<int64_t, int32_t>(input, output, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| argsort<int64_t, int64_t>(input, output, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| argsort<int64_t, float>(input, output, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| argsort<int64_t, double>(input, output, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else if (data_dtype == "float16") { |
| if (out_dtype == "int32") { |
| argsort<float16, int32_t>(input, output, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| argsort<float16, int64_t>(input, output, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| argsort<float16, float>(input, output, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| argsort<float16, double>(input, output, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported input dtype: " << data_dtype; |
| } |
| }); |
| } |
| |
| // Sort implemented C library sort. |
| // Return sorted tensor. |
| // By default, the last axis will be used to sort. |
| // sort_num specify the number of elements to be sorted. |
| // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) |
| // and sort axis is dk. sort_num should have dimension of |
| // (d1, d2, ..., d(k-1), d(k+1), ..., dn). |
| void RegisterSort() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def_packed("tvm.contrib.sort.sort", [](ffi::PackedArgs args, ffi::Any* ret) { |
| auto input = args[0].cast<DLTensor*>(); |
| auto output = args[1].cast<DLTensor*>(); |
| int32_t axis = args[2].cast<int32_t>(); |
| bool is_ascend = args[3].cast<bool>(); |
| if (axis < 0) { |
| axis = input->ndim + axis; |
| } |
| TVM_FFI_ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " |
| "input ndim " |
| << input->ndim; |
| |
| auto data_dtype = ffi::DLDataTypeToString(input->dtype); |
| auto out_dtype = ffi::DLDataTypeToString(output->dtype); |
| |
| TVM_FFI_ICHECK_EQ(data_dtype, out_dtype); |
| |
| if (data_dtype == "float32") { |
| sort<float>(input, output, axis, is_ascend); |
| } else if (data_dtype == "float64") { |
| sort<double>(input, output, axis, is_ascend); |
| #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) |
| } else if (data_dtype == "float16") { |
| sort<__fp16>(input, output, axis, is_ascend); |
| #endif |
| } else if (data_dtype == "int32") { |
| sort<int32_t>(input, output, axis, is_ascend); |
| } else if (data_dtype == "int64") { |
| sort<int64_t>(input, output, axis, is_ascend); |
| } else if (data_dtype == "float16") { |
| sort<float16>(input, output, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported input dtype: " << data_dtype; |
| } |
| }); |
| } |
| |
| template <typename DataType, typename IndicesType> |
| void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis, |
| bool is_ascend) { |
| DataType* data_ptr = static_cast<DataType*>(input->data); |
| DataType* values_ptr = |
| (out_values == nullptr) ? nullptr : static_cast<DataType*>(out_values->data); |
| IndicesType* indices_ptr = |
| (out_indices == nullptr) ? nullptr : static_cast<IndicesType*>(out_indices->data); |
| |
| // Maintain a min/max containing the top-k elements |
| std::vector<std::pair<int64_t, DataType>> running_heap; |
| |
| // Need +1 when inserting new element before maintaining heap invariant |
| running_heap.reserve(k + 1); |
| |
| int axis_mul_before = 1; |
| int axis_mul_after = 1; |
| for (int i = 0; i < input->ndim; ++i) { |
| if (i < axis) { |
| axis_mul_before *= input->shape[i]; |
| } else if (i > axis) { |
| axis_mul_after *= input->shape[i]; |
| } |
| } |
| if (k < 1) { |
| k = input->shape[axis]; |
| } |
| |
| for (int i = 0; i < axis_mul_before; ++i) { |
| for (int j = 0; j < axis_mul_after; ++j) { |
| running_heap.clear(); |
| int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; |
| int64_t dst_base_idx = i * k * axis_mul_after + j; |
| |
| // Start by creating min/max heap with fixed-k elements |
| int cur_axis_index = 0; |
| for (; cur_axis_index < k && cur_axis_index < input->shape[axis]; cur_axis_index++) { |
| int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; |
| running_heap.emplace_back(std::make_pair(cur_axis_index, data_ptr[full_idx])); |
| } |
| if (!is_ascend) { |
| std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>); |
| } else { |
| std::make_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>); |
| } |
| |
| // Iterate through all elements, adding to heap along the way |
| for (; cur_axis_index < input->shape[axis]; cur_axis_index++) { |
| int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; |
| std::pair<int64_t, DataType> cur_val = {cur_axis_index, data_ptr[full_idx]}; |
| |
| // Eq. to cur_val.second > running_heap.second |
| if (!is_ascend && CompareDescend<DataType, true>(cur_val, running_heap[0])) { |
| running_heap.push_back(cur_val); |
| std::push_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>); |
| std::pop_heap(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>); |
| running_heap.pop_back(); |
| } else if (is_ascend && CompareAscend<DataType, true>(cur_val, running_heap[0])) { |
| running_heap.push_back(cur_val); |
| std::push_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>); |
| std::pop_heap(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>); |
| running_heap.pop_back(); |
| } |
| } |
| |
| // finally sort heap and deliver results |
| if (is_ascend) { |
| std::stable_sort(running_heap.begin(), running_heap.end(), CompareAscend<DataType, true>); |
| } else { |
| std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend<DataType, true>); |
| } |
| |
| for (uint32_t kk = 0; kk < running_heap.size(); ++kk) { |
| if (indices_ptr != nullptr) { |
| indices_ptr[dst_base_idx + kk * axis_mul_after] = |
| static_cast<IndicesType>(running_heap[kk].first); |
| } |
| if (values_ptr != nullptr) { |
| values_ptr[dst_base_idx + kk * axis_mul_after] = |
| static_cast<DataType>(running_heap[kk].second); |
| } |
| } |
| } |
| } |
| } |
| |
| // Argsort implemented C library sort. |
| // Return indices of sorted tensor. |
| // By default, the last axis will be used to sort. |
| // sort_num specify the number of elements to be sorted. |
| // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) |
| // and sort axis is dk. sort_num should have dimension of |
| // (d1, d2, ..., d(k-1), d(k+1), ..., dn). |
| void RegisterTopk() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def_packed("tvm.contrib.sort.topk", [](ffi::PackedArgs args, ffi::Any* ret) { |
| auto input = args[0].cast<DLTensor*>(); |
| DLTensor* values_out = nullptr; |
| DLTensor* indices_out = nullptr; |
| int k = args[args.size() - 4].cast<int>(); |
| int axis = args[args.size() - 3].cast<int>(); |
| std::string ret_type = args[args.size() - 2].cast<std::string>(); |
| bool is_ascend = args[args.size() - 1].cast<bool>(); |
| if (ret_type == "both") { |
| values_out = args[1].cast<DLTensor*>(); |
| indices_out = args[2].cast<DLTensor*>(); |
| } else if (ret_type == "values") { |
| values_out = args[1].cast<DLTensor*>(); |
| } else if (ret_type == "indices") { |
| indices_out = args[1].cast<DLTensor*>(); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported ret type: " << ret_type; |
| } |
| if (axis < 0) { |
| axis = input->ndim + axis; |
| } |
| TVM_FFI_ICHECK(axis >= 0 && axis < input->ndim) |
| << "Axis out of boundary for input ndim " << input->ndim; |
| |
| auto data_dtype = ffi::DLDataTypeToString(input->dtype); |
| auto out_dtype = |
| (indices_out == nullptr) ? "int64" : ffi::DLDataTypeToString(indices_out->dtype); |
| |
| if (data_dtype == "float32") { |
| if (out_dtype == "int32") { |
| topk<float, int32_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| topk<float, int64_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| topk<float, float>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| topk<float, double>(input, values_out, indices_out, k, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else if (data_dtype == "float64") { |
| if (out_dtype == "int32") { |
| topk<double, int32_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| topk<double, int64_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| topk<double, float>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| topk<double, double>(input, values_out, indices_out, k, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else if (data_dtype == "uint8") { |
| if (out_dtype == "uint8") { |
| topk<uint8_t, uint8_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "int32") { |
| topk<uint8_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| topk<uint8_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| topk<uint8_t, float>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| topk<uint8_t, double>(input, values_out, indices_out, k, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else if (data_dtype == "int8") { |
| if (out_dtype == "int8") { |
| topk<int8_t, int8_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "int32") { |
| topk<int8_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| topk<int8_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| topk<int8_t, float>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| topk<int8_t, double>(input, values_out, indices_out, k, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else if (data_dtype == "int32") { |
| if (out_dtype == "int32") { |
| topk<int32_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| topk<int32_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| topk<int32_t, float>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| topk<int32_t, double>(input, values_out, indices_out, k, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else if (data_dtype == "int64") { |
| if (out_dtype == "int32") { |
| topk<int64_t, int32_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| topk<int64_t, int64_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| topk<int64_t, float>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| topk<int64_t, double>(input, values_out, indices_out, k, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else if (data_dtype == "float16") { |
| if (out_dtype == "int32") { |
| topk<float16, int32_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "int64") { |
| topk<float16, int64_t>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float32") { |
| topk<float16, float>(input, values_out, indices_out, k, axis, is_ascend); |
| } else if (out_dtype == "float64") { |
| topk<float16, double>(input, values_out, indices_out, k, axis, is_ascend); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; |
| } |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported input dtype: " << data_dtype; |
| } |
| }); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| RegisterArgsortNMS(); |
| RegisterArgsort(); |
| RegisterSort(); |
| RegisterTopk(); |
| } |
| |
| } // namespace contrib |
| } // namespace tvm |