blob: 16217432dc98588b26ae6d08eabdd7b70c5f5dc2 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file Use external Thrust library call
*/
#include <dlpack/dlpack.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/gather.h>
#include <thrust/mr/device_memory_resource.h>
#include <thrust/mr/disjoint_tls_pool.h>
#include <thrust/mr/memory_resource.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <algorithm>
#include <functional>
#include <memory>
#include <vector>
#include "../../cuda/cuda_common.h"
namespace tvm {
namespace contrib {
using namespace runtime;
/*!
 * \brief Thrust memory resource that serves requests either from a caller-provided
 *        workspace tensor or from Thrust's thread-local disjoint pool.
 *
 * With a workspace tensor, requests are carved sequentially out of that buffer: each
 * allocation aligns the cursor, hands out the aligned address and advances the cursor by
 * the requested size. Deallocation is a no-op in this mode, so the space is never reused
 * during the lifetime of this object.
 *
 * Without a workspace tensor, every request is forwarded to the thread-local pool.
 */
class WorkspaceMemoryResource : public thrust::mr::memory_resource<void*> {
 public:
  /*!
   * \brief Construct the resource.
   * \param workspace Optional scratch tensor. When non-null it must be 1-D with dtype uint8
   *        (checked); when null the thread-local Thrust pool is used instead.
   */
  explicit WorkspaceMemoryResource(DLTensor* workspace) {
    if (workspace != nullptr) {
      // Validate the tensor before any of its fields are consumed.
      TVM_FFI_ICHECK(workspace->ndim == 1 && workspace->dtype.code == kDLUInt &&
                     workspace->dtype.bits == 8);
      this->workspace = workspace->data;
      this->workspace_size = workspace->shape[0];
    } else {
      // Fallback to thrust TLS caching allocator if workspace is not provided.
      thrust_pool_ = &thrust::mr::tls_disjoint_pool(
          thrust::mr::get_global_resource<thrust::device_memory_resource>(),
          thrust::mr::get_global_resource<thrust::mr::new_delete_resource>());
    }
  }

  /*!
   * \brief Allocate `bytes` bytes aligned to `alignment`.
   *
   * The backend is selected by whether the pool was installed by the constructor, not by
   * the workspace pointer: a workspace tensor whose data pointer is null must report an
   * allocation failure instead of dereferencing a pool that was never set up.
   *
   * \return Pointer to the allocated region. In workspace mode the check fails when the
   *         remaining space cannot hold the aligned request.
   */
  void* do_allocate(size_t bytes, size_t alignment) override {
    if (thrust_pool_ != nullptr) {
      return thrust_pool_->do_allocate(bytes, alignment).get();
    }
    // std::align moves `workspace` up to the next aligned address and shrinks
    // `workspace_size` by the padding it skipped; it returns nullptr if the request does
    // not fit.
    void* result = std::align(alignment, bytes, workspace, workspace_size);
    TVM_FFI_ICHECK(result) << "Failed to allocate " << bytes << " bytes with alignment "
                           << alignment << " bytes.";
    // Consume the block that was just handed out.
    workspace = static_cast<char*>(workspace) + bytes;
    workspace_size -= bytes;
    return result;
  }

  /*!
   * \brief Release a region previously returned by do_allocate.
   *
   * Forwarded to the pool in pool mode; a no-op in workspace mode.
   */
  void do_deallocate(void* p, size_t bytes, size_t alignment) override {
    if (thrust_pool_ != nullptr) {
      thrust_pool_->do_deallocate(thrust::device_memory_resource::pointer(p), bytes, alignment);
    }
  }

  /*! \brief Thread-local pool used when no workspace tensor was supplied; null otherwise. */
  thrust::mr::disjoint_unsynchronized_pool_resource<thrust::device_memory_resource,
                                                    thrust::mr::new_delete_resource>*
      thrust_pool_ = nullptr;
  /*! \brief Cursor into the unused part of the workspace buffer. */
  void* workspace = nullptr;
  /*! \brief Number of bytes still available behind `workspace`. */
  size_t workspace_size = 0;
};
/*!
 * \brief Build the Thrust execution policy shared by the algorithms in this file.
 *
 * The policy is `thrust::cuda::par_nosync` bound to the given memory resource and to the
 * stream that TVMFFIEnvGetStream reports for the current CUDA device.
 *
 * \param memory_resouce Resource handed to the policy.
 * \return The stream-bound execution policy.
 */
auto get_thrust_exec_policy(WorkspaceMemoryResource* memory_resouce) {
  int active_device;
  CUDA_CALL(cudaGetDevice(&active_device));
  auto env_stream = TVMFFIEnvGetStream(kDLCUDA, active_device);
  return thrust::cuda::par_nosync(memory_resouce).on(static_cast<cudaStream_t>(env_stream));
}
/*!
 * \brief Owns a typed scratch buffer obtained from a WorkspaceMemoryResource and hands it
 *        back to the resource when it goes out of scope.
 */
template <typename T>
class ThrustScratchBuffer {
 public:
  /*!
   * \param mr Resource to allocate from; it must outlive this object.
   * \param count Number of elements of type T to reserve.
   */
  ThrustScratchBuffer(WorkspaceMemoryResource* mr, size_t count)
      : mr_(mr),
        bytes_(sizeof(T) * count),
        data_(static_cast<T*>(mr->do_allocate(bytes_, sizeof(T)))) {}

  // Return the buffer with exactly the size/alignment it was requested with.
  ~ThrustScratchBuffer() { mr_->do_deallocate(data_, bytes_, sizeof(T)); }

  ThrustScratchBuffer(const ThrustScratchBuffer&) = delete;
  ThrustScratchBuffer& operator=(const ThrustScratchBuffer&) = delete;

  /*! \brief Device pointer to the first element of the buffer. */
  thrust::device_ptr<T> begin() const { return thrust::device_ptr<T>(data_); }

 private:
  // Declaration order matters: bytes_ is used to initialize data_.
  WorkspaceMemoryResource* mr_;
  size_t bytes_;
  T* data_;
};

/*!
 * \brief Sort along the last axis and produce both the sorted values and their indices.
 *
 * \tparam DataType Element type of `input` / `out_values`.
 * \tparam IndicesType Element type of `out_indices`.
 * \param input Tensor to sort; it is read, never modified.
 * \param out_values Receives the sorted values.
 * \param out_indices Receives, for each sorted value, its position along the sort axis.
 * \param is_ascend Sort in ascending order when true, descending otherwise.
 * \param n_values Length of the sort axis.
 * \param workspace Optional scratch tensor forwarded to WorkspaceMemoryResource.
 *
 * Every Thrust call below is issued with the same execution policy so that all steps are
 * ordered with respect to each other.
 */
template <typename DataType, typename IndicesType>
void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, bool is_ascend,
                 int n_values, DLTensor* workspace) {
  thrust::device_ptr<DataType> data_ptr(static_cast<DataType*>(input->data));
  thrust::device_ptr<DataType> values_ptr(static_cast<DataType*>(out_values->data));
  thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType*>(out_indices->data));

  WorkspaceMemoryResource mr(workspace);
  auto policy = get_thrust_exec_policy(&mr);

  size_t size = 1;
  for (int i = 0; i < input->ndim; ++i) {
    size *= input->shape[i];
  }
  // Nothing to sort: avoid zero-sized scratch requests and empty launches.
  if (size == 0) return;

  thrust::copy(policy, data_ptr, data_ptr + size, values_ptr);

  if (size == static_cast<size_t>(input->shape[input->ndim - 1])) {
    // A fast path for single segment case
    thrust::sequence(policy, indices_ptr, indices_ptr + n_values);
    if (is_ascend) {
      thrust::sort_by_key(policy, values_ptr, values_ptr + n_values, indices_ptr);
    } else {
      thrust::sort_by_key(policy, values_ptr, values_ptr + n_values, indices_ptr,
                          thrust::greater<DataType>());
    }
  } else {
    // segmented sort by key
    // Follow the back-to-back stable_sort_by_key strategy explained below
    // https://groups.google.com/g/thrust-users/c/BoLsxO6b4FY
    ThrustScratchBuffer<int64_t> argsort_buffer(&mr, size);
    thrust::device_ptr<int64_t> argsort_order = argsort_buffer.begin();
    thrust::sequence(policy, argsort_order, argsort_order + size);

    // First, sort values and store the sorted order in argsort_order.
    if (is_ascend) {
      thrust::stable_sort_by_key(policy, values_ptr, values_ptr + size, argsort_order);
    } else {
      thrust::stable_sort_by_key(policy, values_ptr, values_ptr + size, argsort_order,
                                 thrust::greater<DataType>());
    }

    // The following is to create the indices array 0, 1, 2, 0, 1, 2 ... 0, 1, 2
    // without materializing it
    auto counting_iter = thrust::counting_iterator<int64_t>(0);
    auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) {
      return i % n_values;
    };  // NOLINT(*)
    auto init_indices_iter =
        thrust::make_transform_iterator(counting_iter, linear_index_to_sort_axis_index);

    // This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr
    thrust::gather(policy, argsort_order, argsort_order + size, init_indices_iter, indices_ptr);

    ThrustScratchBuffer<int> segment_buffer(&mr, size);
    thrust::device_ptr<int> segment_ids = segment_buffer.begin();
    auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) {
      return i / n_values;
    };  // NOLINT(*)
    // We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr
    thrust::transform(policy, argsort_order, argsort_order + size, segment_ids,
                      linear_index_to_segment_id);

    // The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ...
    // values_ptr and indices_ptr will also be sorted in the order of segment_ids above
    // Since sorting has been done in a stable way, relative orderings of values and indices
    // in the segment do not change and hence they remain sorted.
    auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr));
    thrust::stable_sort_by_key(policy, segment_ids, segment_ids + size, key_val_zip);
    // Both scratch buffers are handed back to `mr` here, when they go out of scope.
  }
}
/*!
 * \brief Pick the index element type for a fixed value element type and run thrust_sort.
 *
 * Accepts "int32", "int64", "float32" and "float64" as `out_dtype`; anything else is fatal.
 */
template <typename DataType>
void thrust_sort_dispatch_indices(DLTensor* input, DLTensor* values_out, DLTensor* indices_out,
                                  bool is_ascend, int sort_len, const std::string& out_dtype,
                                  DLTensor* workspace) {
  if (out_dtype == "int32") {
    thrust_sort<DataType, int32_t>(input, values_out, indices_out, is_ascend, sort_len,
                                   workspace);
  } else if (out_dtype == "int64") {
    thrust_sort<DataType, int64_t>(input, values_out, indices_out, is_ascend, sort_len,
                                   workspace);
  } else if (out_dtype == "float32") {
    thrust_sort<DataType, float>(input, values_out, indices_out, is_ascend, sort_len, workspace);
  } else if (out_dtype == "float64") {
    thrust_sort<DataType, double>(input, values_out, indices_out, is_ascend, sort_len,
                                  workspace);
  } else {
    LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
  }
}

/*!
 * \brief Runtime-to-compile-time dispatch for thrust_sort.
 *
 * \param input Tensor to sort.
 * \param values_out Receives the sorted values.
 * \param indices_out Receives the sorted indices.
 * \param is_ascend Ascending order when true.
 * \param sort_len Length of the sort axis.
 * \param data_dtype Name of the value dtype: "float16", "float32", "float64", "int32" or
 *        "int64".
 * \param out_dtype Name of the index dtype: "int32", "int64", "float32" or "float64".
 * \param workspace Optional scratch tensor.
 *
 * An unsupported dtype name is reported through LOG(FATAL).
 */
void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices_out,
                        bool is_ascend, int sort_len, std::string data_dtype, std::string out_dtype,
                        DLTensor* workspace) {
  if (data_dtype == "float16") {
    thrust_sort_dispatch_indices<half>(input, values_out, indices_out, is_ascend, sort_len,
                                       out_dtype, workspace);
  } else if (data_dtype == "float32") {
    thrust_sort_dispatch_indices<float>(input, values_out, indices_out, is_ascend, sort_len,
                                        out_dtype, workspace);
  } else if (data_dtype == "float64") {
    thrust_sort_dispatch_indices<double>(input, values_out, indices_out, is_ascend, sort_len,
                                         out_dtype, workspace);
  } else if (data_dtype == "int32") {
    thrust_sort_dispatch_indices<int32_t>(input, values_out, indices_out, is_ascend, sort_len,
                                          out_dtype, workspace);
  } else if (data_dtype == "int64") {
    thrust_sort_dispatch_indices<int64_t>(input, values_out, indices_out, is_ascend, sort_len,
                                          out_dtype, workspace);
  } else {
    LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
  }
}
// Registers the packed function "tvm.contrib.thrust.sort".
// Arguments: input, values_out, indices_out, is_ascend and, when exactly five arguments are
// given, a workspace tensor. Sorting is performed along the last axis of `input`.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def_packed("tvm.contrib.thrust.sort", [](ffi::PackedArgs args, ffi::Any* ret) {
    TVM_FFI_ICHECK_GE(args.size(), 4);
    auto input = args[0].cast<DLTensor*>();
    auto values_out = args[1].cast<DLTensor*>();
    auto indices_out = args[2].cast<DLTensor*>();
    bool is_ascend = args[3].cast<bool>();
    DLTensor* workspace = nullptr;
    if (args.size() == 5) {
      workspace = args[4].cast<DLTensor*>();
    }

    // shape[ndim - 1] is read below; reject 0-dim tensors instead of indexing shape[-1].
    TVM_FFI_ICHECK_GE(input->ndim, 1)
        << "thrust sort requires an input with at least one dimension, but got ndim = "
        << input->ndim;

    auto data_dtype = ffi::DLDataTypeToString(input->dtype);
    auto out_dtype = ffi::DLDataTypeToString(indices_out->dtype);

    // The sort length is passed on as `int`; make sure the narrowing is lossless.
    const int64_t sort_axis_len = input->shape[input->ndim - 1];
    int n_values = static_cast<int>(sort_axis_len);
    TVM_FFI_ICHECK(static_cast<int64_t>(n_values) == sort_axis_len)
        << "thrust sort: length of the sort axis (" << sort_axis_len
        << ") does not fit into a 32-bit integer";

    thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype,
                       workspace);
  });
}
/*!
 * \brief Stable key/value sort over the first dimension of `keys_in`.
 *
 * The inputs are first copied into the output tensors, which are then sorted in place, so
 * `keys_in` and `values_in` are left untouched.
 *
 * \tparam KeyType Element type of the key tensors.
 * \tparam ValueType Element type of the value tensors.
 * \param keys_in Keys to sort by; `shape[0]` gives the number of elements processed.
 * \param values_in Values associated with the keys.
 * \param keys_out Receives the sorted keys.
 * \param values_out Receives the values reordered like the keys.
 * \param for_scatter When true, negative keys are shifted by the element count while they
 *        are copied into `keys_out`.
 * \param workspace Optional scratch tensor.
 */
template <typename KeyType, typename ValueType>
void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out,
                               DLTensor* values_out, bool for_scatter,
                               DLTensor* workspace = nullptr) {
  const auto num_items = keys_in->shape[0];

  thrust::device_ptr<KeyType> src_keys(static_cast<KeyType*>(keys_in->data));
  thrust::device_ptr<ValueType> src_values(static_cast<ValueType*>(values_in->data));
  thrust::device_ptr<KeyType> dst_keys(static_cast<KeyType*>(keys_out->data));
  thrust::device_ptr<ValueType> dst_values(static_cast<ValueType*>(values_out->data));

  WorkspaceMemoryResource scratch(workspace);
  auto exec = get_thrust_exec_policy(&scratch);

  if (!for_scatter) {
    thrust::copy(exec, src_keys, src_keys + num_items, dst_keys);
  } else {
    // Negative keys are moved up by the number of elements while being copied.
    thrust::transform(exec, src_keys, src_keys + num_items, dst_keys,
                      [num_items] __device__(KeyType key) {
                        return key < 0 ? key + static_cast<KeyType>(num_items) : key;
                      });
  }
  thrust::copy(exec, src_values, src_values + num_items, dst_values);

  thrust::stable_sort_by_key(exec, dst_keys, dst_keys + num_items, dst_values);
}
/*!
 * \brief Pick the value element type for a fixed key element type and run the stable sort.
 *
 * Accepts "int32", "int64" and "float32" as `value_dtype`; anything else is fatal.
 */
template <typename KeyType, typename DTypeName>
void thrust_stable_sort_dispatch_values(const DTypeName& value_dtype, DLTensor* keys_in,
                                        DLTensor* values_in, DLTensor* keys_out,
                                        DLTensor* values_out, bool for_scatter,
                                        DLTensor* workspace) {
  if (value_dtype == "int32") {
    thrust_stable_sort_by_key<KeyType, int>(keys_in, values_in, keys_out, values_out, for_scatter,
                                            workspace);
  } else if (value_dtype == "int64") {
    thrust_stable_sort_by_key<KeyType, int64_t>(keys_in, values_in, keys_out, values_out,
                                                for_scatter, workspace);
  } else if (value_dtype == "float32") {
    thrust_stable_sort_by_key<KeyType, float>(keys_in, values_in, keys_out, values_out,
                                              for_scatter, workspace);
  } else {
    LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
  }
}

// Registers the packed function "tvm.contrib.thrust.stable_sort_by_key".
// Arguments: keys_in, values_in, keys_out, values_out, for_scatter and, when exactly six
// arguments are given, a workspace tensor. Keys may be int32, int64 or float32.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def_packed(
      "tvm.contrib.thrust.stable_sort_by_key", [](ffi::PackedArgs args, ffi::Any* ret) {
        TVM_FFI_ICHECK_GE(args.size(), 5);
        DLTensor* keys_in = args[0].cast<DLTensor*>();
        DLTensor* values_in = args[1].cast<DLTensor*>();
        DLTensor* keys_out = args[2].cast<DLTensor*>();
        DLTensor* values_out = args[3].cast<DLTensor*>();
        const bool for_scatter = args[4].cast<bool>();
        DLTensor* workspace = args.size() == 6 ? args[5].cast<DLTensor*>() : nullptr;

        const auto key_dtype = ffi::DLDataTypeToString(keys_in->dtype);
        const auto value_dtype = ffi::DLDataTypeToString(values_in->dtype);

        if (key_dtype == "int32") {
          thrust_stable_sort_dispatch_values<int>(value_dtype, keys_in, values_in, keys_out,
                                                  values_out, for_scatter, workspace);
        } else if (key_dtype == "int64") {
          thrust_stable_sort_dispatch_values<int64_t>(value_dtype, keys_in, values_in, keys_out,
                                                      values_out, for_scatter, workspace);
        } else if (key_dtype == "float32") {
          thrust_stable_sort_dispatch_values<float>(value_dtype, keys_in, values_in, keys_out,
                                                    values_out, for_scatter, workspace);
        } else {
          LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
        }
      });
}
/*!
 * \brief Prefix sum along the last axis of `data`, written to `output`.
 *
 * \tparam InType Element type of `data`.
 * \tparam OutType Element type of `output`; when it differs from InType each input element
 *         is converted with static_cast before it enters the scan.
 * \param data Input tensor.
 * \param output Result tensor.
 * \param exclusive Exclusive scan when true, inclusive scan otherwise.
 * \param workspace Optional scratch tensor.
 *
 * Returns without doing anything when the last axis has length zero.
 */
template <typename InType, typename OutType>
void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* workspace) {
  WorkspaceMemoryResource scratch(workspace);
  auto exec = get_thrust_exec_policy(&scratch);

  thrust::device_ptr<InType> in_begin(static_cast<InType*>(data->data));
  thrust::device_ptr<OutType> out_begin(static_cast<OutType*>(output->data));

  const auto axis_len = data->shape[data->ndim - 1];
  if (axis_len == 0) return;

  size_t total = 1;
  for (int d = 0; d < data->ndim; ++d) total *= data->shape[d];

  const bool needs_cast = !std::is_same<InType, OutType>::value;
  // Lazily converts every input element to the output type.
  auto to_out_type = [] __host__ __device__(InType v) {
    return static_cast<OutType>(v);
  };  // NOLINT(*)
  auto cast_begin = thrust::make_transform_iterator(in_begin, to_out_type);

  const bool single_row = total == static_cast<size_t>(data->shape[data->ndim - 1]);
  if (single_row) {
    // Only one run to scan: a plain scan over the whole buffer is enough.
    if (needs_cast) {
      if (exclusive) {
        thrust::exclusive_scan(exec, cast_begin, cast_begin + axis_len, out_begin);
      } else {
        thrust::inclusive_scan(exec, cast_begin, cast_begin + axis_len, out_begin);
      }
    } else {
      if (exclusive) {
        thrust::exclusive_scan(exec, in_begin, in_begin + axis_len, out_begin);
      } else {
        thrust::inclusive_scan(exec, in_begin, in_begin + axis_len, out_begin);
      }
    }
    return;
  }

  // Several rows: use a keyed scan so that the running sum restarts at every row.
  // The key of element i is i / axis_len, i.e. 0, 0, ..., 1, 1, ..., 2, 2, ...; it is
  // produced on the fly instead of being stored in a vector.
  auto linear_index = thrust::counting_iterator<size_t>(0);
  // Without __host__ annotation, cub crashes
  auto row_of = [axis_len] __host__ __device__(size_t i) { return i / axis_len; };  // NOLINT(*)
  auto row_keys = thrust::make_transform_iterator(linear_index, row_of);

  if (needs_cast) {
    if (exclusive) {
      thrust::exclusive_scan_by_key(exec, row_keys, row_keys + total, cast_begin, out_begin);
    } else {
      thrust::inclusive_scan_by_key(exec, row_keys, row_keys + total, cast_begin, out_begin);
    }
  } else {
    if (exclusive) {
      thrust::exclusive_scan_by_key(exec, row_keys, row_keys + total, in_begin, out_begin);
    } else {
      thrust::inclusive_scan_by_key(exec, row_keys, row_keys + total, in_begin, out_begin);
    }
  }
}
/*!
 * \brief Run thrust_scan for an input type that may be scanned into int32, int64, float32
 *        or float64; any other output dtype name is fatal.
 */
template <typename InType, typename DTypeName>
void thrust_sum_scan_dispatch_wide(const DTypeName& out_dtype, DLTensor* data, DLTensor* output,
                                   bool exclusive, DLTensor* workspace) {
  if (out_dtype == "int32") {
    thrust_scan<InType, int>(data, output, exclusive, workspace);
  } else if (out_dtype == "int64") {
    thrust_scan<InType, int64_t>(data, output, exclusive, workspace);
  } else if (out_dtype == "float32") {
    thrust_scan<InType, float>(data, output, exclusive, workspace);
  } else if (out_dtype == "float64") {
    thrust_scan<InType, double>(data, output, exclusive, workspace);
  } else {
    LOG(FATAL) << "Unsupported output dtype: " << out_dtype
               << ". Supported output dtypes are int32, int64, float32, and float64";
  }
}

// Registers the packed function "tvm.contrib.thrust.sum_scan".
// Arguments: data, output, optionally `exclusive` (defaults to false) and optionally a
// workspace tensor as the fourth argument. The set of accepted output dtypes depends on the
// input dtype; unsupported combinations are reported through LOG(FATAL).
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def_packed(
      "tvm.contrib.thrust.sum_scan", [](ffi::PackedArgs args, ffi::Any* ret) {
        TVM_FFI_ICHECK(args.size() == 2 || args.size() == 3 || args.size() == 4);
        DLTensor* data = args[0].cast<DLTensor*>();
        DLTensor* output = args[1].cast<DLTensor*>();
        const bool exclusive = args.size() >= 3 ? args[2].cast<bool>() : false;
        DLTensor* workspace = args.size() == 4 ? args[3].cast<DLTensor*>() : nullptr;

        const auto in_dtype = ffi::DLDataTypeToString(data->dtype);
        const auto out_dtype = ffi::DLDataTypeToString(output->dtype);

        if (in_dtype == "bool") {
          thrust_sum_scan_dispatch_wide<bool>(out_dtype, data, output, exclusive, workspace);
        } else if (in_dtype == "int32") {
          thrust_sum_scan_dispatch_wide<int>(out_dtype, data, output, exclusive, workspace);
        } else if (in_dtype == "int64") {
          if (out_dtype == "int64") {
            thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace);
          } else if (out_dtype == "float32") {
            thrust_scan<int64_t, float>(data, output, exclusive, workspace);
          } else if (out_dtype == "float64") {
            thrust_scan<int64_t, double>(data, output, exclusive, workspace);
          } else {
            LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                       << ". Supported output dtypes are int64, float32, and float64";
          }
        } else if (in_dtype == "float32") {
          if (out_dtype == "float32") {
            thrust_scan<float, float>(data, output, exclusive, workspace);
          } else if (out_dtype == "float64") {
            thrust_scan<float, double>(data, output, exclusive, workspace);
          } else {
            LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                       << ". Supported output dtypes are float32, and float64";
          }
        } else if (in_dtype == "float64") {
          if (out_dtype == "float64") {
            thrust_scan<double, double>(data, output, exclusive, workspace);
          } else {
            LOG(FATAL) << "Unsupported output dtype: " << out_dtype
                       << ". Supported output dtype is float64";
          }
        } else {
          LOG(FATAL) << "Unsupported input dtype: " << in_dtype
                     << ". Supported input dtypes are bool, int32, int64, float32, and float64";
        }
      });
}
} // namespace contrib
} // namespace tvm