blob: ae9771641b23287e8e0f76c0b7535f38f8bbc05f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file pack_args.h
* \brief Utility to pack TVMArgs into other type-erased function calling conventions.
*
* Two type erased function signatures are supported.
* - cuda_style(void** args, int num_args);
* - Pack everything by address
* - metal_style(void** buffers, int num_buffers,
* union_32bit args[N], int num_args);
* - Pack buffer by address, pack rest parameter into 32bit union buffer.
*/
#ifndef TVM_RUNTIME_PACK_ARGS_H_
#define TVM_RUNTIME_PACK_ARGS_H_
#include <tvm/runtime/c_runtime_api.h>
#include <cstring>
#include <vector>
namespace tvm {
namespace runtime {
/*!
 * \brief Argument union type of 32 bits.
 *
 * Holds a single scalar argument after it has been narrowed to 32 bits.
 * 32 bits was chosen because most GPU APIs do not work well with 64 bit.
 */
union ArgUnion {
// Value viewed as a signed 32-bit integer.
int32_t v_int32;
// Value viewed as an unsigned 32-bit integer.
uint32_t v_uint32;
// Value viewed as a single-precision float.
float v_float32;
};
/*!
 * \brief Create a packed function that passes every argument by address.
 *
 * \param f The wrapped callable. It is invoked as f(args, rv, void_args), where
 *          void_args is a void** array holding one address per argument.
 * \param arg_types The argument type information, one entry per argument.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types);
/*!
 * \brief Create a packed function that packs only the non-buffer arguments.
 *
 * The leading opaque-handle (buffer) arguments are left in the original TVMArgs;
 * the remaining arguments are narrowed to 32 bits and stored in an ArgUnion array.
 *
 * \param f The wrapped callable, with signature
 *          (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args).
 * \param arg_types The argument type information. All handle arguments must
 *                  come before any non-handle argument.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DLDataType>& arg_types);
/*!
 * \brief Create a packed function that packs all arguments into one contiguous buffer.
 *
 * Arguments are written back to back: handles take sizeof(void*) bytes, 64-bit
 * scalars take 8 bytes, and 32-bit scalars take 4 bytes.
 *
 * \param f The wrapped callable, with signature
 *          (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes),
 *          where nbytes is the number of bytes written to pack_args.
 * \param arg_types The argument type information, one entry per argument.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_types);
/*!
 * \brief Count the leading buffer (opaque handle) arguments in the argument types.
 *
 * A check fails if a handle argument appears after the first non-handle argument.
 *
 * \param arg_types The argument types.
 * \return Number of buffer arguments, i.e. the index of the first non-handle
 *         argument, or arg_types.size() when every argument is a handle.
 */
inline size_t NumBufferArgs(const std::vector<DLDataType>& arg_types);
// implementation details
namespace detail {
/*!
 * \brief Scratch array with inline storage for exactly kSize elements.
 *
 * The length passed to the constructor is ignored; callers must not use
 * more than kSize slots. Elements are left default-initialized.
 *
 * \tparam T The element type.
 * \tparam kSize The fixed capacity; 0 selects the heap-backed specialization.
 */
template <typename T, int kSize>
class TempArray {
 public:
  /*! \brief Construct the array; the requested length is not used. */
  explicit TempArray(int /*size*/) {}
  /*! \return Pointer to the first element of the inline storage. */
  T* data() { return &data_[0]; }

 private:
  /*! \brief Inline storage of kSize elements. */
  T data_[kSize];
};

/*!
 * \brief Heap-backed variant, selected when kSize == 0.
 *
 * Storage holds exactly the number of elements requested at construction,
 * value-initialized by std::vector.
 *
 * \tparam T The element type.
 */
template <typename T>
class TempArray<T, 0> {
 public:
  /*! \brief Allocate storage for the given number of elements. */
  explicit TempArray(int size) : data_(size) {}
  /*! \return Pointer to the first element of the heap storage. */
  T* data() { return data_.data(); }

 private:
  /*! \brief Dynamically sized storage. */
  std::vector<T> data_;
};
/*!
 * \brief Conversion applied to one argument before it is handed to the device function.
 *
 * Each name reads SOURCE_TO_TARGET: the representation held in the incoming
 * argument value, followed by the representation that is passed on.
 */
enum ArgConvertCode {
  /*! \brief 64-bit integer passed through unchanged. */
  INT64_TO_INT64,
  /*! \brief 64-bit integer narrowed to a signed 32-bit integer. */
  INT64_TO_INT32,
  /*! \brief 64-bit integer narrowed to an unsigned 32-bit integer. */
  INT64_TO_UINT32,
  /*! \brief 64-bit float narrowed to a 32-bit float. */
  FLOAT64_TO_FLOAT32,
  /*! \brief 64-bit float passed through unchanged. */
  FLOAT64_TO_FLOAT64,
  /*! \brief Opaque handle passed through unchanged. */
  HANDLE_TO_HANDLE,
};
/*!
 * \brief Choose the conversion code for an argument of the given data type.
 *
 * Supported types are int32/int64, uint32, float32/float64 and opaque handles,
 * all with a single lane. Anything else is reported through LOG(FATAL).
 *
 * \param t The declared data type of the argument.
 * \return The conversion to apply when packing the argument.
 */
inline ArgConvertCode GetArgConvertCode(DLDataType t) {
  CHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now";
  const bool is_64bit = (t.bits == 64U);
  const bool is_32bit = (t.bits == 32U);
  if (t.code == kDLInt) {
    if (is_64bit) return INT64_TO_INT64;
    if (is_32bit) return INT64_TO_INT32;
  } else if (t.code == kDLUInt) {
    // Only the 32-bit unsigned width is supported.
    if (is_32bit) return INT64_TO_UINT32;
  } else if (t.code == kDLFloat) {
    if (is_64bit) return FLOAT64_TO_FLOAT64;
    if (is_32bit) return FLOAT64_TO_FLOAT32;
  } else if (t.code == kTVMOpaqueHandle) {
    return HANDLE_TO_HANDLE;
  }
  LOG(FATAL) << "Cannot handle " << t << " as device function argument";
  // Keeps every path returning a value after the fatal log above.
  return HANDLE_TO_HANDLE;
}
/*!
 * \brief Build a PackedFunc that hands f one address per argument.
 *
 * Arguments that keep their 64-bit (or handle) representation are passed by
 * the address of the incoming value itself. Arguments that must be narrowed
 * to 32 bits are converted into a temporary ArgUnion slot whose address is
 * passed instead.
 *
 * \param f The wrapped callable, invoked as f(args, rv, addr) with addr a void** array.
 * \param codes The conversion code of every argument, in order.
 * \tparam N Inline capacity of the scratch arrays; 0 means heap allocated.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <int N, typename F>
inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
  auto packed = [f, codes, num_args](TVMArgs args, TVMRetValue* rv) {
    TempArray<void*, N> addr_storage(num_args);
    TempArray<ArgUnion, N> narrow_storage(num_args);
    void** addr = addr_storage.data();
    ArgUnion* narrowed = narrow_storage.data();
    for (int idx = 0; idx < num_args; ++idx) {
      const ArgConvertCode code = codes[idx];
      if (code == INT64_TO_INT64 || code == FLOAT64_TO_FLOAT64 || code == HANDLE_TO_HANDLE) {
        // No conversion needed: point straight at the incoming value.
        addr[idx] = (void*)&(args.values[idx]);  // NOLINT(*)
        continue;
      }
      if (code == INT64_TO_INT32) {
        narrowed[idx].v_int32 = static_cast<int32_t>(args.values[idx].v_int64);
      } else if (code == INT64_TO_UINT32) {
        narrowed[idx].v_uint32 = static_cast<uint32_t>(args.values[idx].v_int64);
      } else if (code == FLOAT64_TO_FLOAT32) {
        narrowed[idx].v_float32 = static_cast<float>(args.values[idx].v_float64);
      } else {
        // Unrecognized code: leave the slot untouched.
        continue;
      }
      // Narrowed values live in the scratch slot, so pass that slot's address.
      addr[idx] = &(narrowed[idx]);
    }
    f(args, rv, addr);
  };
  return PackedFunc(packed);
}
/*!
 * \brief Build a PackedFunc that narrows the non-buffer arguments into an ArgUnion array.
 *
 * Only 32-bit targets are supported here: 64-bit scalar codes and handle
 * codes are reported through LOG(FATAL).
 *
 * \param f The wrapped callable, invoked as f(args, rv, holder) with holder an ArgUnion*.
 * \param base Index in the incoming arguments of the first non-buffer argument.
 * \param codes The conversion code of every non-buffer argument, in order.
 * \tparam N Inline capacity of the scratch array; 0 means heap allocated.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <int N, typename F>
inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
  auto packed = [f, codes, base, num_args](TVMArgs args, TVMRetValue* rv) {
    TempArray<ArgUnion, N> storage(num_args);
    ArgUnion* holder = storage.data();
    for (int k = 0; k < num_args; ++k) {
      // Position of this argument within the full incoming argument list.
      const int src = base + k;
      switch (codes[k]) {
        case INT64_TO_INT32: {
          holder[k].v_int32 = static_cast<int32_t>(args.values[src].v_int64);
          break;
        }
        case INT64_TO_UINT32: {
          holder[k].v_uint32 = static_cast<uint32_t>(args.values[src].v_int64);
          break;
        }
        case FLOAT64_TO_FLOAT32: {
          holder[k].v_float32 = static_cast<float>(args.values[src].v_float64);
          break;
        }
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64: {
          LOG(FATAL) << "Do not support 64bit argument to device function";
          break;
        }
        case HANDLE_TO_HANDLE: {
          // Handles are expected before `base`, so none should show up here.
          LOG(FATAL) << "not reached";
          break;
        }
      }
    }
    f(args, rv, holder);
  };
  return PackedFunc(packed);
}
/*!
 * \brief Build a PackedFunc that serializes all arguments into one contiguous buffer.
 *
 * Arguments are written back to back in 32-bit units: handles occupy
 * sizeof(void*) bytes, 64-bit scalars occupy 8 bytes and narrowed scalars
 * occupy 4 bytes. The scratch buffer reserves 8 bytes per argument, which
 * is the widest any single argument can be.
 *
 * Every store goes through std::memcpy so that the uint64_t-typed scratch
 * storage is never written through an lvalue of an unrelated type.
 *
 * \param f The wrapped callable, invoked as f(args, rv, pack, nbytes) where
 *          nbytes is the number of bytes actually written to pack.
 * \param codes The conversion code of every argument, in order.
 * \tparam N Inline capacity of the scratch buffer in 8-byte slots; 0 means heap allocated.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <int N, typename F>
inline PackedFunc PackFuncPackedArg_(F f, const std::vector<ArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
  auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
    TempArray<uint64_t, N> pack_(num_args);
    int32_t* pack = reinterpret_cast<int32_t*>(pack_.data());
    // Write cursor, advanced in 32-bit units.
    int32_t* ptr = pack;
    static_assert(sizeof(TVMValue) == 8, "invariant");
    static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant");
    static_assert(sizeof(float) == sizeof(int32_t), "invariant");
    for (int i = 0; i < num_args; ++i) {
      switch (codes[i]) {
        case HANDLE_TO_HANDLE: {
          std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*));
          ptr += sizeof(void*) / sizeof(int32_t);
          break;
        }
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64: {
          std::memcpy(ptr, &args.values[i], sizeof(TVMValue));
          ptr += 2;
          break;
        }
        case INT64_TO_INT32: {
          const int32_t narrowed = static_cast<int32_t>(args.values[i].v_int64);
          std::memcpy(ptr, &narrowed, sizeof(narrowed));
          ++ptr;
          break;
        }
        case INT64_TO_UINT32: {
          const uint32_t narrowed = static_cast<uint32_t>(args.values[i].v_int64);
          std::memcpy(ptr, &narrowed, sizeof(narrowed));
          ++ptr;
          break;
        }
        case FLOAT64_TO_FLOAT32: {
          const float narrowed = static_cast<float>(args.values[i].v_float64);
          std::memcpy(ptr, &narrowed, sizeof(narrowed));
          ++ptr;
          break;
        }
        default: {
          LOG(FATAL) << "not reached";
          break;
        }
      }
    }
    // Report the size in bytes rather than in 32-bit units.
    f(args, ret, pack, (ptr - pack) * sizeof(int32_t));
  };
  return PackedFunc(ret);
}
} // namespace detail
/*!
 * \brief Create a packed function that passes every argument by address.
 *
 * Translates each argument type into a conversion code, then selects an
 * implementation whose scratch arrays are inline for up to 4 or up to 8
 * arguments, and heap allocated beyond that.
 *
 * \param f The wrapped callable, invoked as f(args, rv, void_args).
 * \param arg_types The argument type information, one entry per argument.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types) {
  const size_t num_void_args = arg_types.size();
  std::vector<detail::ArgConvertCode> codes;
  codes.reserve(num_void_args);
  for (const DLDataType& type : arg_types) {
    codes.push_back(detail::GetArgConvertCode(type));
  }
  // Pick the smallest inline capacity that fits; fall back to the heap.
  if (num_void_args > 8) {
    return detail::PackFuncVoidAddr_<0>(f, codes);
  }
  if (num_void_args > 4) {
    return detail::PackFuncVoidAddr_<8>(f, codes);
  }
  return detail::PackFuncVoidAddr_<4>(f, codes);
}
/*!
 * \brief Count the leading opaque-handle (buffer) arguments.
 *
 * A check fails if any handle argument appears after the first non-handle one.
 *
 * \param arg_types The argument types.
 * \return Index of the first non-handle argument, or arg_types.size() when
 *         every argument is a handle.
 */
inline size_t NumBufferArgs(const std::vector<DLDataType>& arg_types) {
  const size_t total = arg_types.size();
  size_t base = 0;
  // Skip over the run of handle arguments at the front.
  while (base < total && arg_types[base].code == kTVMOpaqueHandle) {
    ++base;
  }
  // Everything from here on must be a non-handle argument.
  for (size_t pos = base; pos < total; ++pos) {
    CHECK(arg_types[pos].code != kTVMOpaqueHandle) << "Device function need to be organized";
  }
  return base;
}
/*!
 * \brief Create a packed function that packs only the non-buffer arguments.
 *
 * The leading handle arguments are counted with NumBufferArgs; the remaining
 * argument types are translated to conversion codes. Scratch storage is
 * inline for up to 4 non-buffer arguments and heap allocated otherwise.
 *
 * \param f The wrapped callable, with signature
 *          (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args).
 * \param arg_types The argument type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DLDataType>& arg_types) {
  const size_t num_buffer = NumBufferArgs(arg_types);
  const size_t total = arg_types.size();
  std::vector<detail::ArgConvertCode> codes;
  for (size_t pos = num_buffer; pos < total; ++pos) {
    const detail::ArgConvertCode code = detail::GetArgConvertCode(arg_types[pos]);
    codes.push_back(code);
  }
  const int base = static_cast<int>(num_buffer);
  if (codes.size() > 4) {
    return detail::PackFuncNonBufferArg_<0>(f, base, codes);
  }
  return detail::PackFuncNonBufferArg_<4>(f, base, codes);
}
/*!
 * \brief Create a packed function that packs all arguments into one buffer.
 *
 * Translates each argument type into a conversion code, then selects an
 * implementation whose scratch buffer is inline for up to 4 arguments and
 * heap allocated otherwise.
 *
 * \param f The wrapped callable, with signature
 *          (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes).
 * \param arg_types The argument type information, one entry per argument.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_types) {
  std::vector<detail::ArgConvertCode> codes;
  for (const DLDataType& type : arg_types) {
    codes.push_back(detail::GetArgConvertCode(type));
  }
  if (codes.size() > 4) {
    return detail::PackFuncPackedArg_<0>(f, codes);
  }
  return detail::PackFuncPackedArg_<4>(f, codes);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_PACK_ARGS_H_