blob: 110060356d16c2c24b0f588fa27a9f879646068d [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ffi/function.h
* \brief A managed function in the TVM FFI.
*/
#ifndef TVM_FFI_FUNCTION_H_
#define TVM_FFI_FUNCTION_H_
/*!
* \brief Controls whether DLL exports should include metadata.
*
* When set to 1, exported functions will include additional metadata.
* When set to 0 (default), exports are minimal without metadata.
*/
#ifndef TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
#define TVM_FFI_DLL_EXPORT_INCLUDE_METADATA 0
#endif
#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
#include <sstream>
#endif // TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
#include <tvm/ffi/any.h>
#include <tvm/ffi/base_details.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/expected.h>
#include <tvm/ffi/function_details.h>
#include <functional>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
namespace tvm {
namespace ffi {
/**
* Helper macro to construct a safe call
*
* \brief Marks the beginning of the safe call that catches exception explicitly
* \sa TVM_FFI_SAFE_CALL_END
*
* \code{.cpp}
* int TVMFFICStyleFunction() {
* TVM_FFI_SAFE_CALL_BEGIN();
* // c++ code region here
* TVM_FFI_SAFE_CALL_END();
* }
* \endcode
*/
#define TVM_FFI_SAFE_CALL_BEGIN() \
try { \
(void)0
/*!
* \brief Marks the end of safe call.
*/
#define TVM_FFI_SAFE_CALL_END() \
return 0; \
} \
catch (const ::tvm::ffi::Error& err) { \
::tvm::ffi::details::SetSafeCallRaised(err); \
return -1; \
} \
catch (const std::exception& ex) { \
::tvm::ffi::details::SetSafeCallRaised(::tvm::ffi::Error("InternalError", ex.what(), "")); \
return -1; \
} \
TVM_FFI_UNREACHABLE()
/*!
* \brief Macro to check a call to TVMFFISafeCallType and raise exception if error happens.
* \param func The function to check.
*
* \code{.cpp}
* // calls TVMFFIFunctionCall and raises exception if error happens
* TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index));
* \endcode
*/
#define TVM_FFI_CHECK_SAFE_CALL(func) \
{ \
int ret_code = (func); \
if (TVM_FFI_PREDICT_FALSE(ret_code != 0)) { \
throw ::tvm::ffi::details::MoveFromSafeCallRaised(); \
} \
}
/*!
* \brief Object container class that backs ffi::Function
* \note Do not use this class directly, use ffi::Function
*/
class FunctionObj : public Object, public TVMFFIFunctionCell {
public:
/*! \brief Typedef for C++ style calling signature that comes with exception propagation */
using FCall = void (*)(const FunctionObj*, const AnyView*, int32_t, Any*);
using TVMFFIFunctionCell::cpp_call;
using TVMFFIFunctionCell::safe_call;
/*!
* \brief Call the function in packed format.
* \param args The arguments
* \param num_args The number of arguments
* \param result The return value.
*/
TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const {
// if cpp_call is set, use it to call the function, otherwise, redirect to safe_call
// use conditional expression here so the select is branchless
FCall call_ptr =
this->cpp_call ? reinterpret_cast<FCall>(this->cpp_call) : CppCallDedirectToSafeCall;
(*call_ptr)(this, args, num_args, result);
}
/// \cond Doxygen_Suppress
static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction;
TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIFunction, FunctionObj, Object);
/// \endcond
protected:
/*! \brief Make default constructor protected. */
FunctionObj() {}
friend class Function;
private:
static void CppCallDedirectToSafeCall(const FunctionObj* func, const AnyView* args,
int32_t num_args, Any* rv) {
FunctionObj* self = static_cast<FunctionObj*>(const_cast<FunctionObj*>(func));
TVM_FFI_CHECK_SAFE_CALL(self->safe_call(self, reinterpret_cast<const TVMFFIAny*>(args),
num_args, reinterpret_cast<TVMFFIAny*>(rv)));
}
};
namespace details {
/*!
* \brief Derived object class for constructing FunctionObj backed by a TCallable
*
* This is a helper class that implements the function call interface.
* Invariance: TCallable cannot be const or reference type.
*/
template <typename TCallable>
class FunctionObjImpl : public FunctionObj {
public:
static_assert(std::is_same_v<TCallable, std::remove_cv_t<std::remove_reference_t<TCallable>>>,
"TCallable of FunctionObjImpl cannot be const or reference type");
/*! \brief The type of derived object class */
using TSelf = FunctionObjImpl<TCallable>;
/*!
* \brief Derived object class for constructing ffi::FunctionObj.
* \param args The arguments to construct TCallable
*/
template <typename... Args>
explicit FunctionObjImpl(Args&&... args) : callable_(std::forward<Args>(args)...) {
this->safe_call = SafeCall;
this->cpp_call = reinterpret_cast<void*>(CppCall);
}
FunctionObjImpl(const FunctionObjImpl&) = delete;
FunctionObjImpl& operator=(const FunctionObjImpl&) = delete;
TCallable* GetCallable() { return &callable_; }
private:
// implementation of call
static void CppCall(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* result) {
(static_cast<const TSelf*>(func))->callable_(args, num_args, result);
}
/// \cond Doxygen_Suppress
// Implementing safe call style
static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) {
TVM_FFI_SAFE_CALL_BEGIN();
TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin);
FunctionObj* self = static_cast<FunctionObj*>(func);
reinterpret_cast<FCall>(self->cpp_call)(self, reinterpret_cast<const AnyView*>(args), num_args,
reinterpret_cast<Any*>(result));
TVM_FFI_SAFE_CALL_END();
}
/// \endcond
/*! \brief Type-erased filed for storing callable object*/
mutable TCallable callable_;
};
/*!
* \brief FunctionObj specialization for raw C style callback where handle and deleter are null.
*/
class ExternCFunctionObjNullHandleImpl : public FunctionObj {
public:
explicit ExternCFunctionObjNullHandleImpl(TVMFFISafeCallType safe_call) {
this->safe_call = safe_call;
this->cpp_call = nullptr;
}
};
/*!
* \brief FunctionObj specialization that leverages C-style callback definitions.
*/
class ExternCFunctionObjImpl : public FunctionObj {
public:
ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self))
: self_(self), safe_call_(safe_call), deleter_(deleter) {
this->safe_call = SafeCall;
this->cpp_call = nullptr;
}
~ExternCFunctionObjImpl() {
if (deleter_) deleter_(self_);
}
private:
static int32_t SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* rv) {
ExternCFunctionObjImpl* self = reinterpret_cast<ExternCFunctionObjImpl*>(func);
return self->safe_call_(self->self_, args, num_args, rv);
}
void* self_;
TVMFFISafeCallType safe_call_;
void (*deleter_)(void* self);
};
// Helper class to set packed arguments
class PackedArgsSetter {
public:
explicit PackedArgsSetter(AnyView* args) : args_(args) {}
// NOTE: setter needs to be very carefully designed
// such that we do not have temp variable conversion(eg. convert from lvalue to rvalue)
// that is why we need T&& and std::forward here
template <typename T>
TVM_FFI_INLINE void operator()(size_t i, T&& value) const {
args_[i].operator=(std::forward<T>(value));
}
private:
AnyView* args_;
};
} // namespace details
/*!
* \brief Represents arguments packed in AnyView array
* \note This class represent packed arguments to ffi::Function
*/
class PackedArgs {
public:
/*!
* \brief Constructor
* \param data The arguments
* \param size The number of arguments
*/
PackedArgs(const AnyView* data, int32_t size) : data_(data), size_(size) {}
/*! \return size of the arguments */
int size() const { return size_; }
/*! \return The arguments */
const AnyView* data() const { return data_; }
/*!
* \brief Slice the arguments
* \param begin The begin index
* \param end The end index
* \return The sliced arguments
*/
PackedArgs Slice(int begin, int end = -1) const {
if (end == -1) {
end = size_;
}
return PackedArgs(data_ + begin, end - begin);
}
/*!
* \brief Get i-th argument
* \param i the index.
* \return the ith argument.
*/
AnyView operator[](int i) const { return data_[i]; }
/*!
* \brief Fill the arguments into the AnyView array
* \param data The AnyView array to store the packed arguments
* \param args The arguments to be packed
* \note Caller must ensure all args are alive during lifetime of data.
* A common pitfall is to pass in local variables that are immediately
* destroyed after calling Fill.
*/
template <typename... Args>
TVM_FFI_INLINE static void Fill(AnyView* data, Args&&... args) {
details::for_each(details::PackedArgsSetter(data), std::forward<Args>(args)...);
}
private:
/*! \brief The arguments */
const AnyView* data_;
/*! \brief The number of arguments */
int32_t size_;
};
/*!
* \brief ffi::Function is a type-erased function.
* The arguments are passed by "packed format" via AnyView
*/
class Function : public ObjectRef {
public:
/*! \brief Constructor from null */
Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*)
/*!
* \brief Constructing a packed function from a callable type
* whose signature is consistent with `ffi::Function`
* \param packed_call The packed function signature
* \note legacy purpose, should change to Function::FromPacked for mostfuture use.
*/
template <typename TCallable,
typename = std::enable_if_t<!std::is_same_v<std::decay_t<TCallable>, Function>>>
explicit Function(TCallable&& packed_call) {
*this = FromPacked(std::forward<TCallable>(packed_call));
}
/*!
* \brief Constructing a packed function from a callable type
* whose signature is consistent with `ffi::Function`
* \param packed_call The packed function signature
*/
template <typename TCallable>
static Function FromPacked(TCallable&& packed_call) {
static_assert(
std::is_convertible_v<TCallable, std::function<void(const AnyView*, int32_t, Any*)>> ||
std::is_convertible_v<TCallable, std::function<void(PackedArgs args, Any*)>>,
"tvm::ffi::Function::FromPacked requires input function signature to match packed func "
"format");
if constexpr (std::is_convertible_v<TCallable, std::function<void(PackedArgs args, Any*)>>) {
return FromPackedInternal(
[packed_call = std::forward<TCallable>(packed_call)](
const AnyView* args, int32_t num_args, Any* rv) mutable -> void {
packed_call(PackedArgs{args, num_args}, rv);
});
} else {
return FromPackedInternal(std::forward<TCallable>(packed_call));
}
}
/*!
* \brief Constructing a packed function from a callable type
* whose signature is consistent with `ffi::Function`.
* It will create the Callable object with the given arguments,
* and return the inplace constructed Function along with
* the pointer to the callable object. The lifetime of the callable
* object is managed by the returned Function.
* \param args The arguments to construct TCallable
* \return A tuple of (Function, TCallable*)
*/
template <typename TCallable, typename... Args>
static auto FromPackedInplace(Args&&... args) {
// We must ensure TCallable is a value type (decay_t) that can hold the callable object
static_assert(std::is_same_v<TCallable, std::decay_t<TCallable>>);
static_assert(std::is_invocable_v<TCallable, const AnyView*, int32_t, Any*>);
using ObjType = details::FunctionObjImpl<TCallable>;
Function func;
auto obj_ptr = make_object<ObjType>(std::forward<Args>(args)...);
auto* call_ptr = obj_ptr->GetCallable();
func.data_ = std::move(obj_ptr);
return std::make_tuple(std::move(func), call_ptr);
}
/*!
* \brief Create ffi::Function from a C style callbacks.
*
* self and deleter can be nullptr if the function do not need closure support
* and corresponds to a raw function pointer.
*
* \param self Resource handle to the function
* \param safe_call The safe_call definition in C.
* \param deleter The deleter to release the resource of self.
* \return The created function.
*/
static Function FromExternC(void* self, TVMFFISafeCallType safe_call,
void (*deleter)(void* self)) {
// the other function coems from a different library
Function func;
if (self == nullptr && deleter == nullptr) {
func.data_ = make_object<details::ExternCFunctionObjNullHandleImpl>(safe_call);
} else {
func.data_ = make_object<details::ExternCFunctionObjImpl>(self, safe_call, deleter);
}
return func;
}
/*!
* \brief Get global function by name
* \param name The function name
* \return The global function.
* \note This function will return std::nullopt if the function is not found.
*/
static std::optional<Function> GetGlobal(std::string_view name) {
TVMFFIObjectHandle handle;
TVMFFIByteArray name_arr{name.data(), name.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle));
if (handle != nullptr) {
return Function(
details::ObjectUnsafe::ObjectPtrFromOwned<FunctionObj>(static_cast<Object*>(handle)));
} else {
return std::nullopt;
}
}
/*!
* \brief Get global function by name
* \param name The name of the function
* \return The global function
* \note This function will return std::nullopt if the function is not found.
*/
static std::optional<Function> GetGlobal(const std::string& name) {
return GetGlobal(std::string_view(name.data(), name.length()));
}
/*!
* \brief Get global function by name
* \param name The name of the function
* \return The global function
* \note This function will return std::nullopt if the function is not found.
*/
static std::optional<Function> GetGlobal(const String& name) {
return GetGlobal(std::string_view(name.data(), name.length()));
}
/*!
* \brief Get global function by name
* \param name The name of the function
* \return The global function
* \note This function will return std::nullopt if the function is not found.
*/
static std::optional<Function> GetGlobal(const char* name) {
return GetGlobal(std::string_view(name));
}
/*!
* \brief Get global function by name and throw an error if it is not found.
* \param name The name of the function
* \return The global function
* \note This function will throw an error if the function is not found.
*/
static Function GetGlobalRequired(std::string_view name) {
std::optional<Function> res = GetGlobal(name);
if (!res.has_value()) {
TVM_FFI_THROW(ValueError) << "Function " << name << " not found";
}
return *res;
}
/*!
* \brief Get global function by name
* \param name The name of the function
* \return The global function
* \note This function will throw an error if the function is not found.
*/
static Function GetGlobalRequired(const std::string& name) {
return GetGlobalRequired(std::string_view(name.data(), name.length()));
}
/*!
* \brief Get global function by name
* \param name The name of the function
* \return The global function
* \note This function will throw an error if the function is not found.
*/
static Function GetGlobalRequired(const String& name) {
return GetGlobalRequired(std::string_view(name.data(), name.length()));
}
/*!
* \brief Get global function by name
* \param name The name of the function
* \return The global function
* \note This function will throw an error if the function is not found.
*/
static Function GetGlobalRequired(const char* name) {
return GetGlobalRequired(std::string_view(name));
}
/*!
* \brief Set global function by name
* \param name The name of the function
* \param func The function
* \param override Whether to override when there is duplication.
*/
static void SetGlobal(std::string_view name,
Function func, // NOLINT(performance-unnecessary-value-param)
bool override = false) {
TVMFFIByteArray name_arr{name.data(), name.size()};
TVM_FFI_CHECK_SAFE_CALL(
TVMFFIFunctionSetGlobal(&name_arr, details::ObjectUnsafe::GetHeader(func.get()), override));
}
/*!
* \brief List all global names
* \return A vector of all global names
* \note This function do not depend on Array so core do not have container dep.
*/
static std::vector<String> ListGlobalNames() {
Function fname_functor =
GetGlobalRequired("ffi.FunctionListGlobalNamesFunctor")().cast<Function>();
std::vector<String> names;
int len = fname_functor(-1).cast<int>();
names.reserve(len);
for (int i = 0; i < len; ++i) {
names.push_back(fname_functor(i).cast<String>());
}
return names;
}
/**
* \brief Remove a global function by name
* \param name The name of the function
*/
static void RemoveGlobal(const String& name) {
static Function fremove = GetGlobalRequired("ffi.FunctionRemoveGlobal");
fremove(name);
}
/*!
* \brief Constructing a packed function from a normal function.
*
* \param callable the internal container of packed function.
*/
template <typename TCallable>
static Function FromTyped(TCallable&& callable) {
using FuncInfo = details::FunctionInfo<std::decay_t<TCallable>>;
// Callable is always captured by value here to avoid possible dangling reference
auto call_packed = [callable = std::forward<TCallable>(callable)](
const AnyView* args, int32_t num_args, Any* rv) mutable -> void {
details::unpack_call<typename FuncInfo::RetType>(
std::make_index_sequence<FuncInfo::num_args>{}, nullptr, callable, args, num_args, rv);
};
return FromPackedInternal(std::move(call_packed));
}
/*!
* \brief Constructing a packed function from a normal function.
*
* \param callable the internal container of packed function.
* \param name optional name attacked to the function.
*/
template <typename TCallable>
static Function FromTyped(TCallable&& callable, std::string name) {
using FuncInfo = details::FunctionInfo<std::decay_t<TCallable>>;
// Callable is always captured by value here to avoid possible dangling reference
auto call_packed = [callable = std::forward<TCallable>(callable), name = std::move(name)](
const AnyView* args, int32_t num_args, Any* rv) mutable -> void {
details::unpack_call<typename FuncInfo::RetType>(
std::make_index_sequence<FuncInfo::num_args>{}, &name, callable, args, num_args, rv);
};
return FromPackedInternal(std::move(call_packed));
}
/*!
* \brief Directly invoke an extern "C" function that follows the TVM FFI SafeCall convention.
*
* This function can be useful to turn an existing exported symbol into a typed function.
*
* \code{.cpp}
* // An extern "C" function, matching TVMFFISafeCallType
* extern "C" int __tvm_ffi_add(
* void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny*result
* );
* // redirect an existing symbol into a typed function
* inline int add(int a, int b) {
* return tvm::ffi::Function::InvokeExternC(nullptr, __tvm_ffi_add, a, b).cast<int>();
* }
* \endcode
*
* \tparam Args The types of the arguments to the extern function.
* \param handle The handle argument, for exported symbols this is usually nullptr.
* \param safe_call The function pointer to the extern "C" function.
* \param args The arguments to pass to the function.
* \return The return value, wrapped in a tvm::ffi::Any.
*/
template <typename... Args>
TVM_FFI_INLINE static Any InvokeExternC(void* handle, TVMFFISafeCallType safe_call,
Args&&... args) {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
AnyView args_pack[kArraySize];
PackedArgs::Fill(args_pack, std::forward<Args>(args)...);
Any result;
TVM_FFI_CHECK_SAFE_CALL(safe_call(handle, reinterpret_cast<const TVMFFIAny*>(args_pack),
kNumArgs, reinterpret_cast<TVMFFIAny*>(&result)));
return result;
}
/*!
* \brief Call function by directly passing in unpacked arguments.
*
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
*
* \code{.cpp}
* // Example code on how to call packed function
* void CallFFIFunction(tvm::ffi::Function f) {
* // call like normal functions by pass in arguments
* // return value is automatically converted back
* int rvalue = f(1, 2.0);
* }
* \endcode
*/
template <typename... Args>
TVM_FFI_INLINE Any operator()(Args&&... args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
AnyView args_pack[kArraySize];
PackedArgs::Fill(args_pack, std::forward<Args>(args)...);
Any result;
static_cast<FunctionObj*>(data_.get())->CallPacked(args_pack, kNumArgs, &result);
return result;
}
/*!
* \brief Call the function in packed format.
* \param args The arguments
* \param num_args The number of arguments
* \param result The return value.
*/
TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const {
static_cast<FunctionObj*>(data_.get())->CallPacked(args, num_args, result);
}
/*!
* \brief Call the function in packed format.
* \param args The arguments
* \param result The return value.
*/
TVM_FFI_INLINE void CallPacked(PackedArgs args, Any* result) const {
static_cast<FunctionObj*>(data_.get())->CallPacked(args.data(), args.size(), result);
}
/*!
* \brief Call the function and return Expected<T> for exception-free error handling.
* \tparam T The expected return type (default: Any).
* \param args The arguments to pass to the function.
* \return Expected<T> containing either the result or an error.
*
* This method provides exception-free calling by catching all exceptions
* and returning them as Error values in the Expected type.
*
* \code
* Function func = Function::GetGlobal("risky_function");
* Expected<int> result = func.CallExpected<int>(arg1, arg2);
* if (result.is_ok()) {
* int value = result.value();
* } else {
* Error err = result.error();
* }
* \endcode
*/
template <typename T = Any, typename... Args>
TVM_FFI_INLINE Expected<T> CallExpected(Args&&... args) const {
constexpr size_t kNumArgs = sizeof...(Args);
AnyView args_pack[kNumArgs > 0 ? kNumArgs : 1];
PackedArgs::Fill(args_pack, std::forward<Args>(args)...);
Any result;
FunctionObj* func_obj = static_cast<FunctionObj*>(data_.get());
// Use safe_call path to catch exceptions
int ret_code = func_obj->safe_call(func_obj, reinterpret_cast<const TVMFFIAny*>(args_pack),
kNumArgs, reinterpret_cast<TVMFFIAny*>(&result));
if (ret_code == 0) {
if constexpr (std::is_same_v<T, Any>) {
return std::move(result);
} else {
// Try T first (fast path), then Error
if (auto val = result.template try_cast<T>()) {
return *std::move(val);
}
if (auto err = result.template try_cast<Error>()) {
return Unexpected(std::move(*err));
}
return Unexpected(Error("TypeError",
"CallExpected: result type mismatch, expected " +
TypeTraits<T>::TypeStr() + ", but got " + result.GetTypeKey(),
""));
}
} else {
return Unexpected(details::MoveFromSafeCallRaised());
}
}
/*! \return Whether the packed function is nullptr */
TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
TVM_FFI_INLINE bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
/// \cond Doxygen_Suppress
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, ObjectRef, FunctionObj);
/// \endcond
class Registry;
private:
/*!
* \brief Constructing a packed function from a callable type
* whose signature is consistent with `ffi::Function`
* \param packed_call The packed function signature
*/
template <typename TCallable>
static Function FromPackedInternal(TCallable&& packed_call) {
// We must make TCallable a value type (decay_t) that can hold the callable object
using ObjType = typename details::FunctionObjImpl<std::decay_t<TCallable>>;
Function func;
func.data_ = make_object<ObjType>(std::forward<TCallable>(packed_call));
return func;
}
};
/*!
* \brief Please refer to \ref TypedFunctionAnchor "TypedFunction<R(Args..)>"
*/
template <typename FType>
class TypedFunction;
/*!
* \anchor TypedFunctionAnchor
* \brief A ffi::Function wrapper to provide typed function signature.
* It is backed by a ffi::Function internally.
*
* TypedFunction enables compile time type checking.
* TypedFunction works with the runtime system:
* - It can be passed as an argument of ffi::Function.
* - It can be assigned to ffi::Any.
* - It can be directly converted to a type-erased ffi::Function.
*
* Developers should prefer TypedFunction over ffi::Function in C++ code
* as it enables compile time checking.
* We can construct a TypedFunction from a lambda function
* with the same signature.
*
* \code{.cpp}
* // user defined lambda function.
* auto addone = [](int x)->int { return x + 1; };
* // We can directly convert
* // lambda function to TypedFunction
* TypedFunction<int(int)> ftyped(addone);
* // invoke the function.
* int y = ftyped(1);
* // Can be directly converted to ffi::Function
* ffi::Function packed = ftype;
* \endcode
* \tparam R The return value of the function.
* \tparam Args The argument signature of the function.
*/
template <typename R, typename... Args>
class TypedFunction<R(Args...)> {
public:
/*! \brief short hand for this function type */
using TSelf = TypedFunction<R(Args...)>;
/*! \brief default constructor */
TypedFunction() = default;
/*! \brief constructor from null */
TypedFunction(std::nullptr_t null) {} // NOLINT(*)
/*!
* \brief constructor from a function
* \param packed The function
*/
TypedFunction(Function packed) : packed_(std::move(packed)) {} // NOLINT(*)
/*!
* \brief construct from a lambda function with the same signature.
*
* Example usage:
* \code{.cpp}
* auto typed_lambda = [](int x)->int { return x + 1; }
* // construct from packed function
* TypedFunction<int(int)> ftyped(typed_lambda, "add_one");
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \param name the name of the lambda function.
* \tparam FLambda the type of the lambda function.
*/
template <typename FLambda,
typename = std::enable_if_t<std::is_convertible_v<FLambda, std::function<R(Args...)>>>>
TypedFunction(FLambda&& typed_lambda, std::string name) {
packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda), std::move(name));
}
/*!
* \brief construct from a lambda function with the same signature.
*
* This version does not take a name. It is highly recommend you use the
* version that takes a name for the lambda.
*
* Example usage:
* \code{.cpp}
* auto typed_lambda = [](int x)->int { return x + 1; }
* // construct from packed function
* TypedFunction<int(int)> ftyped(typed_lambda);
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
*/
template <typename FLambda,
typename = std::enable_if_t<std::is_convertible_v<FLambda, std::function<R(Args...)>> &&
!std::is_same_v<std::decay_t<FLambda>, TSelf>>>
TypedFunction(FLambda&& typed_lambda) { // NOLINT(google-explicit-constructor)
packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda));
}
/*!
* \brief copy assignment operator from typed lambda
*
* Example usage:
* \code{.cpp}
* // construct from packed function
* TypedFunction<int(int)> ftyped;
* ftyped = [](int x) { return x + 1; }
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
* \returns reference to self.
*/
template <typename FLambda,
typename = std::enable_if_t<std::is_convertible_v<FLambda, std::function<R(Args...)>> &&
!std::is_same_v<std::decay_t<FLambda>, TSelf>>>
TSelf& operator=(FLambda&& typed_lambda) {
packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda));
return *this;
}
/*!
* \brief copy assignment operator from ffi::Function.
* \param packed The packed function.
* \returns reference to self.
*/
TSelf& operator=(Function packed) {
packed_ = std::move(packed);
return *this;
}
/*!
* \brief Invoke the operator.
* \param args The arguments
* \returns The return value.
*/
TVM_FFI_INLINE R operator()(Args... args) const { // NOLINT(performance-unnecessary-value-param)
if constexpr (std::is_same_v<R, void>) {
packed_(std::forward<Args>(args)...);
} else {
Any res = packed_(std::forward<Args>(args)...);
if constexpr (std::is_same_v<R, Any>) {
return res;
} else {
return std::move(res).cast<R>();
}
}
}
/*!
* \brief convert to ffi::Function
* \return the internal ffi::Function
*/
operator Function() const { return packed(); } // NOLINT(google-explicit-constructor)
/*!
* \return reference the internal ffi::Function
*/
const Function& packed() const& { return packed_; }
/*!
* \return r-value reference the internal ffi::Function
*/
constexpr Function&& packed() && { return std::move(packed_); }
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const { return packed_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; }
/*!
* \brief Get the type schema of `TypedFunction<R(Args...)>` in json format.
* \return The type schema of the function in json format.
*/
static std::string TypeSchema() { return details::FuncFunctorImpl<R, Args...>::TypeSchema(); }
private:
/*! \brief The internal packed function */
Function packed_;
};
template <typename FType>
inline constexpr bool use_default_type_traits_v<TypedFunction<FType>> = false;
template <typename FType>
struct TypeTraits<TypedFunction<FType>> : public TypeTraitsBase {
static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction;
TVM_FFI_INLINE static void CopyToAnyView(const TypedFunction<FType>& src, TVMFFIAny* result) {
TypeTraits<Function>::CopyToAnyView(src.packed(), result);
}
TVM_FFI_INLINE static void MoveToAny(TypedFunction<FType> src, TVMFFIAny* result) {
// Move from rvalue to trigger TypedFunction rvalue packed() overload
TypeTraits<Function>::MoveToAny(std::move(src).packed(), result);
}
TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
return src->type_index == TypeIndex::kTVMFFIFunction;
}
TVM_FFI_INLINE static TypedFunction<FType> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
return TypedFunction<FType>(TypeTraits<Function>::CopyFromAnyViewAfterCheck(src));
}
TVM_FFI_INLINE static std::optional<TypedFunction<FType>> TryCastFromAnyView(
const TVMFFIAny* src) {
std::optional<Function> opt = TypeTraits<Function>::TryCastFromAnyView(src);
if (opt.has_value()) {
return TypedFunction<FType>(*std::move(opt));
} else {
return std::nullopt;
}
}
TVM_FFI_INLINE static std::string TypeStr() { return details::FunctionInfo<FType>::Sig(); }
TVM_FFI_INLINE static std::string TypeSchema() { return TypedFunction<FType>::TypeSchema(); }
};
/*!
* \brief helper function to get type index from key
*/
inline int32_t TypeKeyToIndex(std::string_view type_key) {
int32_t type_index;
TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
return type_index;
}
/// \cond Doxygen_Suppress
// Internal implementation macros used by TVM_FFI_DLL_EXPORT_TYPED_FUNC and related macros.
// These should not be used directly; use the public macros instead.
// Internal implementation macro that generates the C ABI wrapper function
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function) \
extern "C" { \
TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void* self, const TVMFFIAny* args, \
int32_t num_args, TVMFFIAny* result) { \
TVM_FFI_SAFE_CALL_BEGIN(); \
using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \
static std::string name = #ExportName; \
::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>( \
std::make_index_sequence<FuncInfo::num_args>{}, &name, Function, \
reinterpret_cast<const ::tvm::ffi::AnyView*>(args), num_args, \
reinterpret_cast<::tvm::ffi::Any*>(result)); \
TVM_FFI_SAFE_CALL_END(); \
} \
}
/// \endcond
/*!
* \brief Export typed function as a SafeCallType symbol that follows the FFI ABI.
*
* This macro exports the function and automatically exports metadata when
* TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is defined.
*
* \param ExportName The symbol name to be exported.
* \param Function The typed function.
*
* \sa ffi::TypedFunction, TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC
*
* \code{.cpp}
* int AddOne_(int x) {
* return x + 1;
* }
* // Expose the function as "AddOne"
* TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_);
* // Expose the function as "SubOne"
* TVM_FFI_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) {
* return x - 1;
* });
* \endcode
*
* \note The final symbol names are:
* - `__tvm_ffi_<ExportName>` (function)
* - `__tvm_ffi__metadata_<ExportName>` (metadata - only when
* TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is defined)
*/
#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
// Implementation note: we specifically use TVMFFIStringFromByteArray
// so the returned string metadata is allocated in the libtvm_ffi and long lived.
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function) \
extern "C" { \
TVM_FFI_DLL_EXPORT int __tvm_ffi__metadata_##ExportName(void* self, const TVMFFIAny* args, \
int32_t num_args, TVMFFIAny* result) { \
TVM_FFI_SAFE_CALL_BEGIN(); \
using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \
std::ostringstream os; \
os << R"({"type_schema":)" \
<< ::tvm::ffi::EscapeStringJSON(::tvm::ffi::String(FuncInfo::TypeSchema())) << R"(})"; \
std::string data = os.str(); \
TVMFFIByteArray data_array{data.data(), data.size()}; \
return TVMFFIStringFromByteArray(&data_array, result); \
TVM_FFI_SAFE_CALL_END(); \
} \
}
#else
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function)
#endif
/*!
* \brief Export documentation string for a typed function.
*
* This macro exports a documentation string associated with a function export name.
* The docstring can be used by stub generators and documentation tools.
* This macro only exports the docstring; it does not export the function itself.
*
* \param ExportName The symbol name that the docstring is associated with.
* \param DocString The documentation string (C string literal).
*
* \sa ffi::TypedFunction, TVM_FFI_DLL_EXPORT_TYPED_FUNC
*
* \code{.cpp}
* int Add(int a, int b) {
* return a + b;
* }
*
* TVM_FFI_DLL_EXPORT_TYPED_FUNC(add, Add);
* TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(
* add,
* R"(Add two integers and return the sum.
*
* Parameters
* ----------
* a : int
* First integer
* b : int
* Second integer
*
* Returns
* -------
* result : int
* Sum of a and b)");
* \endcode
*
* \note The exported symbol name is `__tvm_ffi__doc_<ExportName>` (docstring getter function).
* This symbol is only exported when TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is defined.
*/
#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
// Implementation note: we specifically use TVMFFIStringFromByteArray
// so the returned string metadata is allocated in the libtvm_ffi and long lived.
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(ExportName, DocString) \
extern "C" { \
TVM_FFI_DLL_EXPORT int __tvm_ffi__doc_##ExportName(void* self, const TVMFFIAny* args, \
int32_t num_args, TVMFFIAny* result) { \
TVM_FFI_SAFE_CALL_BEGIN(); \
std::string_view data(DocString); \
TVMFFIByteArray data_array{data.data(), data.size()}; \
return TVMFFIStringFromByteArray(&data_array, result); \
TVM_FFI_SAFE_CALL_END(); \
} \
}
#else
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(ExportName, DocString)
#endif
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_FUNCTION_H_