| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| /*! |
| * \file tvm/ffi/function.h |
| * \brief A managed function in the TVM FFI. |
| */ |
| #ifndef TVM_FFI_FUNCTION_H_ |
| #define TVM_FFI_FUNCTION_H_ |
| |
| /*! |
| * \brief Controls whether DLL exports should include metadata. |
| * |
| * When set to 1, exported functions will include additional metadata. |
| * When set to 0 (default), exports are minimal without metadata. |
| */ |
| #ifndef TVM_FFI_DLL_EXPORT_INCLUDE_METADATA |
| #define TVM_FFI_DLL_EXPORT_INCLUDE_METADATA 0 |
| #endif |
| |
| #if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA |
| #include <sstream> |
| #endif // TVM_FFI_DLL_EXPORT_INCLUDE_METADATA |
| |
| #include <tvm/ffi/any.h> |
| #include <tvm/ffi/base_details.h> |
| #include <tvm/ffi/c_api.h> |
| #include <tvm/ffi/error.h> |
| #include <tvm/ffi/expected.h> |
| #include <tvm/ffi/function_details.h> |
| |
| #include <functional> |
| #include <optional> |
| #include <string> |
| #include <tuple> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| namespace tvm { |
| namespace ffi { |
| |
| /*! |
| * \brief Marks the beginning of a safe-call region that catches exceptions explicitly. |
| * |
| * Expands to the opening of a `try` block. The trailing `(void)0` lets the |
| * invocation be written with a terminating semicolon. The `try` block opened |
| * here is closed by TVM_FFI_SAFE_CALL_END(), so the two macros must be used |
| * as a pair inside the same function body. |
| * \sa TVM_FFI_SAFE_CALL_END |
| * |
| * \code{.cpp} |
| * int TVMFFICStyleFunction() { |
| * TVM_FFI_SAFE_CALL_BEGIN(); |
| * // c++ code region here |
| * TVM_FFI_SAFE_CALL_END(); |
| * } |
| * \endcode |
| */ |
| #define TVM_FFI_SAFE_CALL_BEGIN() \ |
| try { \ |
| (void)0 |
| |
| /*! |
| * \brief Marks the end of a safe-call region opened by TVM_FFI_SAFE_CALL_BEGIN(). |
| * |
| * - Returns 0 from the enclosing function when the guarded region finishes |
| * without throwing. |
| * - A caught ::tvm::ffi::Error is handed to details::SetSafeCallRaised and |
| * -1 is returned. |
| * - A caught std::exception is first wrapped as |
| * Error("InternalError", ex.what(), ""), handed to details::SetSafeCallRaised, |
| * and -1 is returned. |
| * - Exceptions matching neither handler are not caught by this macro. |
| * \sa TVM_FFI_SAFE_CALL_BEGIN |
| */ |
| #define TVM_FFI_SAFE_CALL_END() \ |
| return 0; \ |
| } \ |
| catch (const ::tvm::ffi::Error& err) { \ |
| ::tvm::ffi::details::SetSafeCallRaised(err); \ |
| return -1; \ |
| } \ |
| catch (const std::exception& ex) { \ |
| ::tvm::ffi::details::SetSafeCallRaised(::tvm::ffi::Error("InternalError", ex.what(), "")); \ |
| return -1; \ |
| } \ |
| TVM_FFI_UNREACHABLE() |
| |
| /*! |
| * \brief Macro to check a call to TVMFFISafeCallType and raise exception if error happens. |
| * |
| * Evaluates \p func exactly once. When the returned code is non-zero, the error |
| * obtained from details::MoveFromSafeCallRaised() is thrown. |
| * |
| * \param func The call expression to check; it must yield an int return code. |
| * |
| * \note The temporary that holds the return code uses a macro-private name so |
| * that \p func may freely mention caller variables such as `ret_code` |
| * without being shadowed by the macro's own local. |
| * |
| * \code{.cpp} |
| * // calls TVMFFITypeKeyToIndex and raises exception if error happens |
| * TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index)); |
| * \endcode |
| */ |
| #define TVM_FFI_CHECK_SAFE_CALL(func) \ |
| { \ |
| int tvm_ffi_check_safe_call_ret_code_ = (func); \ |
| if (TVM_FFI_PREDICT_FALSE(tvm_ffi_check_safe_call_ret_code_ != 0)) { \ |
| throw ::tvm::ffi::details::MoveFromSafeCallRaised(); \ |
| } \ |
| } |
| |
| /*! |
| * \brief Object container class that backs ffi::Function |
| * \note Do not use this class directly, use ffi::Function |
| */ |
| class FunctionObj : public Object, public TVMFFIFunctionCell { |
| public: |
| /*! \brief C++ style calling signature; errors propagate to the caller as exceptions */ |
| using FCall = void (*)(const FunctionObj*, const AnyView*, int32_t, Any*); |
| using TVMFFIFunctionCell::cpp_call; |
| using TVMFFIFunctionCell::safe_call; |
| /*! |
| * \brief Invoke the function using the packed calling convention. |
| * \param args Pointer to the packed arguments |
| * \param num_args Number of entries in args |
| * \param result Output slot that receives the return value |
| */ |
| TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const { |
| // Take the direct C++ entry when the cell provides one, otherwise bridge through safe_call. |
| // Written as a conditional expression on purpose so the select stays branchless. |
| const FCall target = this->cpp_call != nullptr ? reinterpret_cast<FCall>(this->cpp_call) |
| : &FunctionObj::CppCallRedirectToSafeCall; |
| target(this, args, num_args, result); |
| } |
| /// \cond Doxygen_Suppress |
| static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction; |
| TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIFunction, FunctionObj, Object); |
| /// \endcond |
| |
| protected: |
| /*! \brief Default construction is reserved for subclasses and ffi::Function. */ |
| FunctionObj() {} |
| friend class Function; |
| |
| private: |
| // Adapter with the FCall signature that forwards to safe_call and turns a |
| // non-zero return code into a thrown error via TVM_FFI_CHECK_SAFE_CALL. |
| static void CppCallRedirectToSafeCall(const FunctionObj* func, const AnyView* args, |
| int32_t num_args, Any* rv) { |
| // safe_call is invoked with a non-const handle, so constness is dropped first |
| FunctionObj* handle = const_cast<FunctionObj*>(func); |
| const TVMFFIAny* c_args = reinterpret_cast<const TVMFFIAny*>(args); |
| TVMFFIAny* c_result = reinterpret_cast<TVMFFIAny*>(rv); |
| TVM_FFI_CHECK_SAFE_CALL(handle->safe_call(handle, c_args, num_args, c_result)); |
| } |
| }; |
| |
| namespace details { |
| /*! |
| * \brief FunctionObj subclass that stores a TCallable by value and dispatches calls to it. |
| * |
| * The constructor installs both the C++ entry (cpp_call) and the C entry (safe_call). |
| * Invariant: TCallable must be a plain value type, neither const nor a reference. |
| */ |
| template <typename TCallable> |
| class FunctionObjImpl : public FunctionObj { |
| public: |
| static_assert(std::is_same_v<TCallable, std::remove_cv_t<std::remove_reference_t<TCallable>>>, |
| "TCallable of FunctionObjImpl cannot be const or reference type"); |
| |
| /*! \brief Shorthand for this instantiation */ |
| using TSelf = FunctionObjImpl<TCallable>; |
| |
| /*! |
| * \brief Construct the stored callable in place and register both entry points. |
| * \param args Arguments forwarded to the constructor of TCallable |
| */ |
| template <typename... Args> |
| explicit FunctionObjImpl(Args&&... args) : callable_(std::forward<Args>(args)...) { |
| this->cpp_call = reinterpret_cast<void*>(CppCall); |
| this->safe_call = SafeCall; |
| } |
| |
| FunctionObjImpl(const FunctionObjImpl&) = delete; |
| FunctionObjImpl& operator=(const FunctionObjImpl&) = delete; |
| |
| /*! \return Pointer to the callable stored inside this object */ |
| TCallable* GetCallable() { return &callable_; } |
| |
| private: |
| // C++ entry point: forwards the packed arguments to the stored callable |
| static void CppCall(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* result) { |
| const TSelf* impl = static_cast<const TSelf*>(func); |
| impl->callable_(args, num_args, result); |
| } |
| /// \cond Doxygen_Suppress |
| // C entry point: runs the cpp_call entry inside the safe-call guard |
| static int SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { |
| TVM_FFI_SAFE_CALL_BEGIN(); |
| TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin); |
| FunctionObj* fn_obj = static_cast<FunctionObj*>(func); |
| FCall invoke = reinterpret_cast<FCall>(fn_obj->cpp_call); |
| invoke(fn_obj, reinterpret_cast<const AnyView*>(args), num_args, reinterpret_cast<Any*>(result)); |
| TVM_FFI_SAFE_CALL_END(); |
| } |
| /// \endcond |
| /*! \brief The stored callable; mutable so it can be invoked through a const object pointer */ |
| mutable TCallable callable_; |
| }; |
| |
| /*! |
| * \brief FunctionObj for a raw C style callback that has neither a handle nor a deleter. |
| */ |
| class ExternCFunctionObjNullHandleImpl : public FunctionObj { |
| public: |
| /*! \param safe_call The C callback stored directly as this object's safe_call entry */ |
| explicit ExternCFunctionObjNullHandleImpl(TVMFFISafeCallType safe_call) { |
| // No C++ entry point is registered for this kind of function |
| this->cpp_call = nullptr; |
| this->safe_call = safe_call; |
| } |
| }; |
| |
| /*! |
| * \brief FunctionObj specialization that leverages C-style callback definitions. |
| * |
| * Stores a resource handle, the C callback to invoke with that handle, and a |
| * deleter. The deleter, when non-null, is called on the handle by the destructor. |
| */ |
| class ExternCFunctionObjImpl : public FunctionObj { |
| public: |
| /*! |
| * \brief Wrap a C callback together with its resource handle. |
| * \param self Resource handle passed as first argument to safe_call |
| * \param safe_call The C callback to invoke |
| * \param deleter Called with self on destruction; may be nullptr |
| */ |
| ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self)) |
| : self_(self), safe_call_(safe_call), deleter_(deleter) { |
| this->safe_call = SafeCall; |
| this->cpp_call = nullptr; |
| } |
| |
| // The destructor releases self_ through deleter_, so a copy would release it twice. |
| ExternCFunctionObjImpl(const ExternCFunctionObjImpl&) = delete; |
| ExternCFunctionObjImpl& operator=(const ExternCFunctionObjImpl&) = delete; |
| |
| ~ExternCFunctionObjImpl() { |
| if (deleter_) deleter_(self_); |
| } |
| |
| private: |
| // Trampoline stored in the function cell: swaps the object pointer for the |
| // user supplied handle and forwards the call to the wrapped C callback. |
| static int32_t SafeCall(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* rv) { |
| // func is the type-erased pointer to this FunctionObj; recover the derived type from it |
| ExternCFunctionObjImpl* self = |
| static_cast<ExternCFunctionObjImpl*>(static_cast<FunctionObj*>(func)); |
| return self->safe_call_(self->self_, args, num_args, rv); |
| } |
| |
| /*! \brief Resource handle forwarded to safe_call_ and deleter_ */ |
| void* self_; |
| /*! \brief The wrapped C callback */ |
| TVMFFISafeCallType safe_call_; |
| /*! \brief Releases self_; may be nullptr */ |
| void (*deleter_)(void* self); |
| }; |
| |
| // Functor that writes the i-th argument into a caller provided AnyView array |
| class PackedArgsSetter { |
| public: |
| explicit PackedArgsSetter(AnyView* args) : args_(args) {} |
| |
| // NOTE: the setter is deliberately written with T&& plus std::forward so that |
| // no temporary conversion happens on the way in (e.g. an lvalue being turned |
| // into an rvalue) before the value reaches AnyView's assignment operator. |
| template <typename T> |
| TVM_FFI_INLINE void operator()(size_t i, T&& value) const { |
| AnyView& slot = args_[i]; |
| slot.operator=(std::forward<T>(value)); |
| } |
| |
| private: |
| // Destination array; not owned by the setter |
| AnyView* args_; |
| }; |
| } // namespace details |
| |
| /*! |
| * \brief A non-owning (pointer, count) view over arguments packed as AnyView |
| * \note This class represents the packed arguments passed to ffi::Function |
| */ |
| class PackedArgs { |
| public: |
| /*! |
| * \brief Build a view over an existing argument array |
| * \param data Pointer to the first argument |
| * \param size Number of arguments reachable through data |
| */ |
| PackedArgs(const AnyView* data, int32_t size) : data_(data), size_(size) {} |
| |
| /*! \return The number of arguments */ |
| int size() const { return size_; } |
| |
| /*! \return Pointer to the first argument */ |
| const AnyView* data() const { return data_; } |
| |
| /*! |
| * \brief Create a view over a sub-range of the arguments |
| * \param begin Index of the first argument to keep |
| * \param end One past the last argument to keep; -1 (the default) means size() |
| * \return The view over [begin, end) |
| * \note No bounds checking is performed. |
| */ |
| PackedArgs Slice(int begin, int end = -1) const { |
| const int stop = (end == -1) ? size_ : end; |
| return PackedArgs(data_ + begin, stop - begin); |
| } |
| |
| /*! |
| * \brief Access one argument by position |
| * \param i Zero-based index of the argument |
| * \return A copy of the i-th AnyView |
| */ |
| AnyView operator[](int i) const { return *(data_ + i); } |
| |
| /*! |
| * \brief Pack a list of values into an AnyView array |
| * \param data Destination array; must have room for every value in args |
| * \param args The values to pack, stored in order starting at data[0] |
| * \note The caller must keep every value in args alive for as long as data is used. |
| * A common pitfall is passing local variables that are destroyed |
| * right after Fill returns. |
| */ |
| template <typename... Args> |
| TVM_FFI_INLINE static void Fill(AnyView* data, Args&&... args) { |
| const details::PackedArgsSetter setter(data); |
| details::for_each(setter, std::forward<Args>(args)...); |
| } |
| |
| private: |
| /*! \brief Pointer to the first argument (not owned) */ |
| const AnyView* data_; |
| /*! \brief Number of arguments in the view */ |
| int32_t size_; |
| }; |
| |
| /*! |
| * \brief ffi::Function is a type-erased function. |
| * The arguments are passed by "packed format" via AnyView |
| */ |
| class Function : public ObjectRef { |
| public: |
| /*! \brief Constructor from null */ |
| Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*) |
| /*! |
| * \brief Constructing a packed function from a callable type |
| * whose signature is consistent with `ffi::Function` |
| * \param packed_call The packed function signature |
| * \note legacy purpose, should change to Function::FromPacked for most future use. |
| */ |
| template <typename TCallable, |
| typename = std::enable_if_t<!std::is_same_v<std::decay_t<TCallable>, Function>>> |
| explicit Function(TCallable&& packed_call) { |
| *this = FromPacked(std::forward<TCallable>(packed_call)); |
| } |
| /*! |
| * \brief Constructing a packed function from a callable type |
| * whose signature is consistent with `ffi::Function` |
| * |
| * Accepts either `void(const AnyView*, int32_t, Any*)` or `void(PackedArgs, Any*)`. |
| * The second form is adapted to the first through a wrapping lambda. |
| * \param packed_call The packed function signature |
| */ |
| template <typename TCallable> |
| static Function FromPacked(TCallable&& packed_call) { |
| static_assert( |
| std::is_convertible_v<TCallable, std::function<void(const AnyView*, int32_t, Any*)>> || |
| std::is_convertible_v<TCallable, std::function<void(PackedArgs args, Any*)>>, |
| "tvm::ffi::Function::FromPacked requires input function signature to match packed func " |
| "format"); |
| if constexpr (std::is_convertible_v<TCallable, std::function<void(PackedArgs args, Any*)>>) { |
| return FromPackedInternal( |
| [packed_call = std::forward<TCallable>(packed_call)]( |
| const AnyView* args, int32_t num_args, Any* rv) mutable -> void { |
| packed_call(PackedArgs{args, num_args}, rv); |
| }); |
| } else { |
| return FromPackedInternal(std::forward<TCallable>(packed_call)); |
| } |
| } |
| |
| /*! |
| * \brief Constructing a packed function from a callable type |
| * whose signature is consistent with `ffi::Function`. |
| * It will create the Callable object with the given arguments, |
| * and return the inplace constructed Function along with |
| * the pointer to the callable object. The lifetime of the callable |
| * object is managed by the returned Function. |
| * \param args The arguments to construct TCallable |
| * \return A tuple of (Function, TCallable*) |
| */ |
| template <typename TCallable, typename... Args> |
| static auto FromPackedInplace(Args&&... args) { |
| // We must ensure TCallable is a value type (decay_t) that can hold the callable object |
| static_assert(std::is_same_v<TCallable, std::decay_t<TCallable>>); |
| static_assert(std::is_invocable_v<TCallable, const AnyView*, int32_t, Any*>); |
| using ObjType = details::FunctionObjImpl<TCallable>; |
| Function func; |
| auto obj_ptr = make_object<ObjType>(std::forward<Args>(args)...); |
| // grab the callable pointer before ownership of the object moves into func |
| auto* call_ptr = obj_ptr->GetCallable(); |
| func.data_ = std::move(obj_ptr); |
| return std::make_tuple(std::move(func), call_ptr); |
| } |
| |
| /*! |
| * \brief Create ffi::Function from a C style callbacks. |
| * |
| * self and deleter can be nullptr if the function do not need closure support |
| * and corresponds to a raw function pointer. |
| * |
| * \param self Resource handle to the function |
| * \param safe_call The safe_call definition in C. |
| * \param deleter The deleter to release the resource of self. |
| * \return The created function. |
| */ |
| static Function FromExternC(void* self, TVMFFISafeCallType safe_call, |
| void (*deleter)(void* self)) { |
| // the other function comes from a different library |
| Function func; |
| if (self == nullptr && deleter == nullptr) { |
| func.data_ = make_object<details::ExternCFunctionObjNullHandleImpl>(safe_call); |
| } else { |
| func.data_ = make_object<details::ExternCFunctionObjImpl>(self, safe_call, deleter); |
| } |
| return func; |
| } |
| /*! |
| * \brief Get global function by name |
| * \param name The function name |
| * \return The global function. |
| * \note This function will return std::nullopt if the function is not found. |
| */ |
| static std::optional<Function> GetGlobal(std::string_view name) { |
| TVMFFIObjectHandle handle; |
| TVMFFIByteArray name_arr{name.data(), name.size()}; |
| TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle)); |
| if (handle != nullptr) { |
| return Function( |
| details::ObjectUnsafe::ObjectPtrFromOwned<FunctionObj>(static_cast<Object*>(handle))); |
| } else { |
| return std::nullopt; |
| } |
| } |
| |
| /*! |
| * \brief Get global function by name |
| * \param name The name of the function |
| * \return The global function |
| * \note This function will return std::nullopt if the function is not found. |
| */ |
| static std::optional<Function> GetGlobal(const std::string& name) { |
| return GetGlobal(std::string_view(name.data(), name.length())); |
| } |
| |
| /*! |
| * \brief Get global function by name |
| * \param name The name of the function |
| * \return The global function |
| * \note This function will return std::nullopt if the function is not found. |
| */ |
| static std::optional<Function> GetGlobal(const String& name) { |
| return GetGlobal(std::string_view(name.data(), name.length())); |
| } |
| |
| /*! |
| * \brief Get global function by name |
| * \param name The name of the function |
| * \return The global function |
| * \note This function will return std::nullopt if the function is not found. |
| */ |
| static std::optional<Function> GetGlobal(const char* name) { |
| return GetGlobal(std::string_view(name)); |
| } |
| /*! |
| * \brief Get global function by name and throw an error if it is not found. |
| * \param name The name of the function |
| * \return The global function |
| * \note This function will throw an error if the function is not found. |
| */ |
| static Function GetGlobalRequired(std::string_view name) { |
| std::optional<Function> res = GetGlobal(name); |
| if (!res.has_value()) { |
| TVM_FFI_THROW(ValueError) << "Function " << name << " not found"; |
| } |
| // res is a local about to be destroyed: move the Function out instead of copying it |
| return *std::move(res); |
| } |
| |
| /*! |
| * \brief Get global function by name |
| * \param name The name of the function |
| * \return The global function |
| * \note This function will throw an error if the function is not found. |
| */ |
| static Function GetGlobalRequired(const std::string& name) { |
| return GetGlobalRequired(std::string_view(name.data(), name.length())); |
| } |
| |
| /*! |
| * \brief Get global function by name |
| * \param name The name of the function |
| * \return The global function |
| * \note This function will throw an error if the function is not found. |
| */ |
| static Function GetGlobalRequired(const String& name) { |
| return GetGlobalRequired(std::string_view(name.data(), name.length())); |
| } |
| |
| /*! |
| * \brief Get global function by name |
| * \param name The name of the function |
| * \return The global function |
| * \note This function will throw an error if the function is not found. |
| */ |
| static Function GetGlobalRequired(const char* name) { |
| return GetGlobalRequired(std::string_view(name)); |
| } |
| /*! |
| * \brief Set global function by name |
| * \param name The name of the function |
| * \param func The function |
| * \param override Whether to override when there is duplication. |
| */ |
| static void SetGlobal(std::string_view name, |
| Function func, // NOLINT(performance-unnecessary-value-param) |
| bool override = false) { |
| TVMFFIByteArray name_arr{name.data(), name.size()}; |
| TVM_FFI_CHECK_SAFE_CALL( |
| TVMFFIFunctionSetGlobal(&name_arr, details::ObjectUnsafe::GetHeader(func.get()), override)); |
| } |
| /*! |
| * \brief List all global names |
| * \return A vector of all global names |
| * \note This function does not depend on Array so core does not have container dep. |
| */ |
| static std::vector<String> ListGlobalNames() { |
| Function fname_functor = |
| GetGlobalRequired("ffi.FunctionListGlobalNamesFunctor")().cast<Function>(); |
| std::vector<String> names; |
| // the functor is queried with -1 for the number of names, then with each index |
| int len = fname_functor(-1).cast<int>(); |
| names.reserve(len); |
| for (int i = 0; i < len; ++i) { |
| names.push_back(fname_functor(i).cast<String>()); |
| } |
| return names; |
| } |
| /** |
| * \brief Remove a global function by name |
| * \param name The name of the function |
| */ |
| static void RemoveGlobal(const String& name) { |
| static Function fremove = GetGlobalRequired("ffi.FunctionRemoveGlobal"); |
| fremove(name); |
| } |
| /*! |
| * \brief Constructing a packed function from a normal function. |
| * |
| * \param callable the internal container of packed function. |
| */ |
| template <typename TCallable> |
| static Function FromTyped(TCallable&& callable) { |
| using FuncInfo = details::FunctionInfo<std::decay_t<TCallable>>; |
| // Callable is always captured by value here to avoid possible dangling reference |
| auto call_packed = [callable = std::forward<TCallable>(callable)]( |
| const AnyView* args, int32_t num_args, Any* rv) mutable -> void { |
| details::unpack_call<typename FuncInfo::RetType>( |
| std::make_index_sequence<FuncInfo::num_args>{}, nullptr, callable, args, num_args, rv); |
| }; |
| return FromPackedInternal(std::move(call_packed)); |
| } |
| /*! |
| * \brief Constructing a packed function from a normal function. |
| * |
| * \param callable the internal container of packed function. |
| * \param name optional name attached to the function. |
| */ |
| template <typename TCallable> |
| static Function FromTyped(TCallable&& callable, std::string name) { |
| using FuncInfo = details::FunctionInfo<std::decay_t<TCallable>>; |
| // Callable is always captured by value here to avoid possible dangling reference |
| auto call_packed = [callable = std::forward<TCallable>(callable), name = std::move(name)]( |
| const AnyView* args, int32_t num_args, Any* rv) mutable -> void { |
| details::unpack_call<typename FuncInfo::RetType>( |
| std::make_index_sequence<FuncInfo::num_args>{}, &name, callable, args, num_args, rv); |
| }; |
| return FromPackedInternal(std::move(call_packed)); |
| } |
| |
| /*! |
| * \brief Directly invoke an extern "C" function that follows the TVM FFI SafeCall convention. |
| * |
| * This function can be useful to turn an existing exported symbol into a typed function. |
| * |
| * \code{.cpp} |
| * // An extern "C" function, matching TVMFFISafeCallType |
| * extern "C" int __tvm_ffi_add( |
| * void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny*result |
| * ); |
| * // redirect an existing symbol into a typed function |
| * inline int add(int a, int b) { |
| * return tvm::ffi::Function::InvokeExternC(nullptr, __tvm_ffi_add, a, b).cast<int>(); |
| * } |
| * \endcode |
| * |
| * \tparam Args The types of the arguments to the extern function. |
| * \param handle The handle argument, for exported symbols this is usually nullptr. |
| * \param safe_call The function pointer to the extern "C" function. |
| * \param args The arguments to pass to the function. |
| * \return The return value, wrapped in a tvm::ffi::Any. |
| */ |
| template <typename... Args> |
| TVM_FFI_INLINE static Any InvokeExternC(void* handle, TVMFFISafeCallType safe_call, |
| Args&&... args) { |
| constexpr int kNumArgs = sizeof...(Args); |
| // zero-length arrays are not allowed, keep at least one slot |
| constexpr int kArraySize = kNumArgs > 0 ? kNumArgs : 1; |
| AnyView args_pack[kArraySize]; |
| PackedArgs::Fill(args_pack, std::forward<Args>(args)...); |
| Any result; |
| TVM_FFI_CHECK_SAFE_CALL(safe_call(handle, reinterpret_cast<const TVMFFIAny*>(args_pack), |
| kNumArgs, reinterpret_cast<TVMFFIAny*>(&result))); |
| return result; |
| } |
| /*! |
| * \brief Call function by directly passing in unpacked arguments. |
| * |
| * \param args Arguments to be passed. |
| * \tparam Args arguments to be passed. |
| * |
| * \code{.cpp} |
| * // Example code on how to call packed function |
| * void CallFFIFunction(tvm::ffi::Function f) { |
| * // call like normal functions by pass in arguments |
| * // return value is automatically converted back |
| * int rvalue = f(1, 2.0); |
| * } |
| * \endcode |
| */ |
| template <typename... Args> |
| TVM_FFI_INLINE Any operator()(Args&&... args) const { |
| constexpr int kNumArgs = sizeof...(Args); |
| // zero-length arrays are not allowed, keep at least one slot |
| constexpr int kArraySize = kNumArgs > 0 ? kNumArgs : 1; |
| AnyView args_pack[kArraySize]; |
| PackedArgs::Fill(args_pack, std::forward<Args>(args)...); |
| Any result; |
| static_cast<FunctionObj*>(data_.get())->CallPacked(args_pack, kNumArgs, &result); |
| return result; |
| } |
| /*! |
| * \brief Call the function in packed format. |
| * \param args The arguments |
| * \param num_args The number of arguments |
| * \param result The return value. |
| */ |
| TVM_FFI_INLINE void CallPacked(const AnyView* args, int32_t num_args, Any* result) const { |
| static_cast<FunctionObj*>(data_.get())->CallPacked(args, num_args, result); |
| } |
| /*! |
| * \brief Call the function in packed format. |
| * \param args The arguments |
| * \param result The return value. |
| */ |
| TVM_FFI_INLINE void CallPacked(PackedArgs args, Any* result) const { |
| static_cast<FunctionObj*>(data_.get())->CallPacked(args.data(), args.size(), result); |
| } |
| |
| /*! |
| * \brief Call the function and return Expected<T> for exception-free error handling. |
| * \tparam T The expected return type (default: Any). |
| * \param args The arguments to pass to the function. |
| * \return Expected<T> containing either the result or an error. |
| * |
| * The call goes through the safe_call entry of the function. A non-zero return |
| * code is reported as the Error obtained from details::MoveFromSafeCallRaised(). |
| * When the call succeeds and T is not Any, the result is converted to T. A result |
| * that holds an Error is returned as the error, and any other mismatch yields a |
| * "TypeError". |
| * |
| * \note The function must not be null; no null check is performed here. |
| * |
| * \code |
| * Function func = Function::GetGlobalRequired("risky_function"); |
| * Expected<int> result = func.CallExpected<int>(arg1, arg2); |
| * if (result.is_ok()) { |
| * int value = result.value(); |
| * } else { |
| * Error err = result.error(); |
| * } |
| * \endcode |
| */ |
| template <typename T = Any, typename... Args> |
| TVM_FFI_INLINE Expected<T> CallExpected(Args&&... args) const { |
| constexpr int kNumArgs = sizeof...(Args); |
| // zero-length arrays are not allowed, keep at least one slot |
| AnyView args_pack[kNumArgs > 0 ? kNumArgs : 1]; |
| PackedArgs::Fill(args_pack, std::forward<Args>(args)...); |
| |
| Any result; |
| FunctionObj* func_obj = static_cast<FunctionObj*>(data_.get()); |
| |
| // Use safe_call path to catch exceptions |
| int ret_code = func_obj->safe_call(func_obj, reinterpret_cast<const TVMFFIAny*>(args_pack), |
| kNumArgs, reinterpret_cast<TVMFFIAny*>(&result)); |
| |
| if (ret_code == 0) { |
| if constexpr (std::is_same_v<T, Any>) { |
| return std::move(result); |
| } else { |
| // Try T first (fast path), then Error |
| if (auto val = result.template try_cast<T>()) { |
| return *std::move(val); |
| } |
| if (auto err = result.template try_cast<Error>()) { |
| return Unexpected(std::move(*err)); |
| } |
| return Unexpected(Error("TypeError", |
| "CallExpected: result type mismatch, expected " + |
| TypeTraits<T>::TypeStr() + ", but got " + result.GetTypeKey(), |
| "")); |
| } |
| } else { |
| return Unexpected(details::MoveFromSafeCallRaised()); |
| } |
| } |
| |
| /*! \return Whether the packed function is nullptr */ |
| TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; } |
| /*! \return Whether the packed function is not nullptr */ |
| TVM_FFI_INLINE bool operator!=(std::nullptr_t) const { return data_ != nullptr; } |
| |
| /// \cond Doxygen_Suppress |
| TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, ObjectRef, FunctionObj); |
| /// \endcond |
| |
| class Registry; |
| |
| private: |
| /*! |
| * \brief Constructing a packed function from a callable type |
| * whose signature is consistent with `ffi::Function` |
| * \param packed_call The packed function signature |
| */ |
| template <typename TCallable> |
| static Function FromPackedInternal(TCallable&& packed_call) { |
| // We must make TCallable a value type (decay_t) that can hold the callable object |
| using ObjType = typename details::FunctionObjImpl<std::decay_t<TCallable>>; |
| Function func; |
| func.data_ = make_object<ObjType>(std::forward<TCallable>(packed_call)); |
| return func; |
| } |
| }; |
| |
| /*! |
| * \brief Primary template of TypedFunction, declared but not defined. |
| * |
| * Only the partial specialization for function types `R(Args...)` is defined. |
| * Please refer to \ref TypedFunctionAnchor "TypedFunction<R(Args..)>" |
| */ |
| template <typename FType> |
| class TypedFunction; |
| |
| /*! |
| * \anchor TypedFunctionAnchor |
| * \brief A ffi::Function wrapper to provide typed function signature. |
| * It is backed by a ffi::Function internally. |
| * |
| * TypedFunction enables compile time type checking. |
| * TypedFunction works with the runtime system: |
| * - It can be passed as an argument of ffi::Function. |
| * - It can be assigned to ffi::Any. |
| * - It can be directly converted to a type-erased ffi::Function. |
| * |
| * Developers should prefer TypedFunction over ffi::Function in C++ code |
| * as it enables compile time checking. |
| * We can construct a TypedFunction from a lambda function |
| * with the same signature. |
| * |
| * \code{.cpp} |
| * // user defined lambda function. |
| * auto addone = [](int x)->int { return x + 1; }; |
| * // We can directly convert |
| * // lambda function to TypedFunction |
| * TypedFunction<int(int)> ftyped(addone); |
| * // invoke the function. |
| * int y = ftyped(1); |
| * // Can be directly converted to ffi::Function |
| * ffi::Function packed = ftype; |
| * \endcode |
| * \tparam R The return value of the function. |
| * \tparam Args The argument signature of the function. |
| */ |
| template <typename R, typename... Args> |
| class TypedFunction<R(Args...)> { |
| public: |
| /*! \brief short hand for this function type */ |
| using TSelf = TypedFunction<R(Args...)>; |
| /*! \brief default constructor */ |
| TypedFunction() = default; |
| /*! \brief constructor from null */ |
| TypedFunction(std::nullptr_t null) {} // NOLINT(*) |
| /*! |
| * \brief constructor from a function |
| * \param packed The function |
| */ |
| TypedFunction(Function packed) : packed_(std::move(packed)) {} // NOLINT(*) |
| /*! |
| * \brief construct from a lambda function with the same signature. |
| * |
| * Example usage: |
| * \code{.cpp} |
| * auto typed_lambda = [](int x)->int { return x + 1; } |
| * // construct from packed function |
| * TypedFunction<int(int)> ftyped(typed_lambda, "add_one"); |
| * // call the typed version. |
| * CHECK_EQ(ftyped(1), 2); |
| * \endcode |
| * |
| * \param typed_lambda typed lambda function. |
| * \param name the name of the lambda function. |
| * \tparam FLambda the type of the lambda function. |
| */ |
| template <typename FLambda, |
| typename = std::enable_if_t<std::is_convertible_v<FLambda, std::function<R(Args...)>>>> |
| TypedFunction(FLambda&& typed_lambda, std::string name) { |
| packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda), std::move(name)); |
| } |
| /*! |
| * \brief construct from a lambda function with the same signature. |
| * |
| * This version does not take a name. It is highly recommend you use the |
| * version that takes a name for the lambda. |
| * |
| * Example usage: |
| * \code{.cpp} |
| * auto typed_lambda = [](int x)->int { return x + 1; } |
| * // construct from packed function |
| * TypedFunction<int(int)> ftyped(typed_lambda); |
| * // call the typed version. |
| * CHECK_EQ(ftyped(1), 2); |
| * \endcode |
| * |
| * \param typed_lambda typed lambda function. |
| * \tparam FLambda the type of the lambda function. |
| */ |
| template <typename FLambda, |
| typename = std::enable_if_t<std::is_convertible_v<FLambda, std::function<R(Args...)>> && |
| !std::is_same_v<std::decay_t<FLambda>, TSelf>>> |
| TypedFunction(FLambda&& typed_lambda) { // NOLINT(google-explicit-constructor) |
| packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda)); |
| } |
| /*! |
| * \brief copy assignment operator from typed lambda |
| * |
| * Example usage: |
| * \code{.cpp} |
| * // construct from packed function |
| * TypedFunction<int(int)> ftyped; |
| * ftyped = [](int x) { return x + 1; } |
| * // call the typed version. |
| * CHECK_EQ(ftyped(1), 2); |
| * \endcode |
| * |
| * \param typed_lambda typed lambda function. |
| * \tparam FLambda the type of the lambda function. |
| * \returns reference to self. |
| */ |
| template <typename FLambda, |
| typename = std::enable_if_t<std::is_convertible_v<FLambda, std::function<R(Args...)>> && |
| !std::is_same_v<std::decay_t<FLambda>, TSelf>>> |
| TSelf& operator=(FLambda&& typed_lambda) { |
| packed_ = Function::FromTyped(std::forward<FLambda>(typed_lambda)); |
| return *this; |
| } |
| /*! |
| * \brief copy assignment operator from ffi::Function. |
| * \param packed The packed function. |
| * \returns reference to self. |
| */ |
| TSelf& operator=(Function packed) { |
| packed_ = std::move(packed); |
| return *this; |
| } |
| /*! |
| * \brief Invoke the operator. |
| * \param args The arguments |
| * \returns The return value. |
| */ |
| TVM_FFI_INLINE R operator()(Args... args) const { // NOLINT(performance-unnecessary-value-param) |
| if constexpr (std::is_same_v<R, void>) { |
| packed_(std::forward<Args>(args)...); |
| } else { |
| Any res = packed_(std::forward<Args>(args)...); |
| if constexpr (std::is_same_v<R, Any>) { |
| return res; |
| } else { |
| return std::move(res).cast<R>(); |
| } |
| } |
| } |
| /*! |
| * \brief convert to ffi::Function |
| * \return the internal ffi::Function |
| */ |
| operator Function() const { return packed(); } // NOLINT(google-explicit-constructor) |
| /*! |
| * \return reference the internal ffi::Function |
| */ |
| const Function& packed() const& { return packed_; } |
| /*! |
 * \return r-value reference to the internal ffi::Function
| */ |
| constexpr Function&& packed() && { return std::move(packed_); } |
| /*! \return Whether the packed function is nullptr */ |
| bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } |
| /*! \return Whether the packed function is not nullptr */ |
| bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } |
| /*! |
| * \brief Get the type schema of `TypedFunction<R(Args...)>` in json format. |
| * \return The type schema of the function in json format. |
| */ |
| static std::string TypeSchema() { return details::FuncFunctorImpl<R, Args...>::TypeSchema(); } |
| |
| private: |
| /*! \brief The internal packed function */ |
| Function packed_; |
| }; |
| |
| template <typename FType> |
| inline constexpr bool use_default_type_traits_v<TypedFunction<FType>> = false; |
| |
| template <typename FType> |
| struct TypeTraits<TypedFunction<FType>> : public TypeTraitsBase { |
| static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction; |
| |
| TVM_FFI_INLINE static void CopyToAnyView(const TypedFunction<FType>& src, TVMFFIAny* result) { |
| TypeTraits<Function>::CopyToAnyView(src.packed(), result); |
| } |
| |
| TVM_FFI_INLINE static void MoveToAny(TypedFunction<FType> src, TVMFFIAny* result) { |
| // Move from rvalue to trigger TypedFunction rvalue packed() overload |
| TypeTraits<Function>::MoveToAny(std::move(src).packed(), result); |
| } |
| |
| TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { |
| return src->type_index == TypeIndex::kTVMFFIFunction; |
| } |
| |
| TVM_FFI_INLINE static TypedFunction<FType> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { |
| return TypedFunction<FType>(TypeTraits<Function>::CopyFromAnyViewAfterCheck(src)); |
| } |
| |
| TVM_FFI_INLINE static std::optional<TypedFunction<FType>> TryCastFromAnyView( |
| const TVMFFIAny* src) { |
| std::optional<Function> opt = TypeTraits<Function>::TryCastFromAnyView(src); |
| if (opt.has_value()) { |
| return TypedFunction<FType>(*std::move(opt)); |
| } else { |
| return std::nullopt; |
| } |
| } |
| |
| TVM_FFI_INLINE static std::string TypeStr() { return details::FunctionInfo<FType>::Sig(); } |
| TVM_FFI_INLINE static std::string TypeSchema() { return TypedFunction<FType>::TypeSchema(); } |
| }; |
| |
| /*! |
| * \brief helper function to get type index from key |
| */ |
| inline int32_t TypeKeyToIndex(std::string_view type_key) { |
| int32_t type_index; |
| TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; |
| TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); |
| return type_index; |
| } |
| |
| /// \cond Doxygen_Suppress |
| // Internal implementation macros used by TVM_FFI_DLL_EXPORT_TYPED_FUNC and related macros. |
| // These should not be used directly; use the public macros instead. |
| |
// Internal implementation macro: emits the exported C ABI wrapper `__tvm_ffi_<ExportName>`,
// which forwards the incoming arguments to `Function` through details::unpack_call and
// passes the stringified export name along with the call.
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function) \
  extern "C" TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName( \
      void* self, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { \
    TVM_FFI_SAFE_CALL_BEGIN(); \
    using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \
    static std::string name = #ExportName; \
    ::tvm::ffi::details::unpack_call<typename FuncInfo::RetType>( \
        std::make_index_sequence<FuncInfo::num_args>{}, \
        &name, \
        Function, \
        reinterpret_cast<const ::tvm::ffi::AnyView*>(args), \
        num_args, \
        reinterpret_cast<::tvm::ffi::Any*>(result)); \
    TVM_FFI_SAFE_CALL_END(); \
  }
| /// \endcond |
| |
| /*! |
| * \brief Export typed function as a SafeCallType symbol that follows the FFI ABI. |
| * |
 * This macro exports the function and additionally exports metadata when
 * TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is set to a non-zero value.
| * |
| * \param ExportName The symbol name to be exported. |
| * \param Function The typed function. |
| * |
| * \sa ffi::TypedFunction, TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC |
| * |
| * \code{.cpp} |
| * int AddOne_(int x) { |
| * return x + 1; |
| * } |
| * // Expose the function as "AddOne" |
| * TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_); |
| * // Expose the function as "SubOne" |
| * TVM_FFI_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) { |
| * return x - 1; |
| * }); |
| * \endcode |
| * |
| * \note The final symbol names are: |
| * - `__tvm_ffi_<ExportName>` (function) |
 * - `__tvm_ffi__metadata_<ExportName>` (metadata - only when
 * TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is set to a non-zero value)
| */ |
#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
// Metadata-enabled variant: besides the function wrapper, a second exported symbol
// `__tvm_ffi__metadata_<ExportName>` returns a JSON object carrying the type schema.
// Implementation note: TVMFFIStringFromByteArray is used on purpose so that the returned
// string metadata is allocated in libtvm_ffi and is long lived.
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
  TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function) \
  extern "C" TVM_FFI_DLL_EXPORT int __tvm_ffi__metadata_##ExportName( \
      void* self, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { \
    TVM_FFI_SAFE_CALL_BEGIN(); \
    using FuncInfo = ::tvm::ffi::details::FunctionInfo<decltype(Function)>; \
    std::ostringstream json; \
    json << R"({"type_schema":)"; \
    json << ::tvm::ffi::EscapeStringJSON(::tvm::ffi::String(FuncInfo::TypeSchema())); \
    json << R"(})"; \
    const std::string payload = json.str(); \
    TVMFFIByteArray payload_bytes{payload.data(), payload.size()}; \
    return TVMFFIStringFromByteArray(&payload_bytes, result); \
    TVM_FFI_SAFE_CALL_END(); \
  }
#else
// Metadata disabled: only the function wrapper is exported.
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
  TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function)
#endif
| |
| /*! |
| * \brief Export documentation string for a typed function. |
| * |
| * This macro exports a documentation string associated with a function export name. |
| * The docstring can be used by stub generators and documentation tools. |
| * This macro only exports the docstring; it does not export the function itself. |
| * |
| * \param ExportName The symbol name that the docstring is associated with. |
| * \param DocString The documentation string (C string literal). |
| * |
| * \sa ffi::TypedFunction, TVM_FFI_DLL_EXPORT_TYPED_FUNC |
| * |
| * \code{.cpp} |
| * int Add(int a, int b) { |
| * return a + b; |
| * } |
| * |
| * TVM_FFI_DLL_EXPORT_TYPED_FUNC(add, Add); |
| * TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC( |
| * add, |
| * R"(Add two integers and return the sum. |
| * |
| * Parameters |
| * ---------- |
| * a : int |
| * First integer |
| * b : int |
| * Second integer |
| * |
| * Returns |
| * ------- |
| * result : int |
| * Sum of a and b)"); |
| * \endcode |
| * |
 * \note The exported symbol name is `__tvm_ffi__doc_<ExportName>` (docstring getter function).
 * This symbol is only exported when TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is set to a non-zero value.
| */ |
#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
// Metadata-enabled variant: exports `__tvm_ffi__doc_<ExportName>`, a getter that
// returns DocString as a string.
// Implementation note: TVMFFIStringFromByteArray is used on purpose so that the returned
// string metadata is allocated in libtvm_ffi and is long lived.
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(ExportName, DocString) \
  extern "C" TVM_FFI_DLL_EXPORT int __tvm_ffi__doc_##ExportName( \
      void* self, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { \
    TVM_FFI_SAFE_CALL_BEGIN(); \
    const std::string_view doc_text(DocString); \
    TVMFFIByteArray doc_bytes{doc_text.data(), doc_text.size()}; \
    return TVMFFIStringFromByteArray(&doc_bytes, result); \
    TVM_FFI_SAFE_CALL_END(); \
  }
#else
// Metadata disabled: the macro expands to nothing.
#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(ExportName, DocString)
#endif
| } // namespace ffi |
| } // namespace tvm |
| #endif // TVM_FFI_FUNCTION_H_ |