| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file tvm/runtime/packed_func.h |
| * \brief Type-erased function used across TVM API. |
| */ |
| #ifndef TVM_RUNTIME_PACKED_FUNC_H_ |
| #define TVM_RUNTIME_PACKED_FUNC_H_ |
| |
| #include <dmlc/logging.h> |
| #include <tvm/runtime/c_runtime_api.h> |
| #include <tvm/runtime/container.h> |
| #include <tvm/runtime/data_type.h> |
| #include <tvm/runtime/module.h> |
| #include <tvm/runtime/ndarray.h> |
| #include <tvm/runtime/object.h> |
| |
| #include <functional> |
| #include <limits> |
| #include <memory> |
| #include <string> |
| #include <tuple> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| // Whether use TVM runtime in header only mode. |
| #ifndef TVM_RUNTIME_HEADER_ONLY |
| #define TVM_RUNTIME_HEADER_ONLY 0 |
| #endif |
| |
| // Always inline macro only use in template |
| // expansion cases where we know inline is important. |
| #ifdef _MSC_VER |
| #define TVM_ALWAYS_INLINE __forceinline |
| #else |
| #define TVM_ALWAYS_INLINE inline __attribute__((always_inline)) |
| #endif |
| |
| namespace tvm { |
| namespace runtime { |
| |
| // forward declarations |
| class TVMArgs; |
| class TVMArgValue; |
| class TVMMovableArgValue_; |
| class TVMRetValue; |
| class TVMArgsSetter; |
| |
| /*! |
| * \brief Packed function is a type-erased function. |
| * The arguments are passed by packed format. |
| * |
| * This is an useful unified interface to call generated functions, |
| * It is the unified function function type of TVM. |
| * It corresponds to TVMFunctionHandle in C runtime API. |
| */ |
| class PackedFunc { |
| public: |
| /*! |
| * \brief The internal std::function |
| * \param args The arguments to the function. |
| * \param rv The return value. |
| * |
| * \code |
| * // Example code on how to implemented FType |
| * void MyPackedFunc(TVMArgs args, TVMRetValue* rv) { |
| * // automatically convert arguments to desired type. |
| * int a0 = args[0]; |
| * float a1 = args[1]; |
| * ... |
| * // automatically assign values to rv |
| * std::string my_return_value = "x"; |
| * *rv = my_return_value; |
| * } |
| * \endcode |
| */ |
| using FType = std::function<void(TVMArgs args, TVMRetValue* rv)>; |
| /*! \brief default constructor */ |
| PackedFunc() {} |
| /*! \brief constructor from null */ |
| PackedFunc(std::nullptr_t null) {} // NOLINT(*) |
| /*! |
| * \brief constructing a packed function from a std::function. |
| * \param body the internal container of packed function. |
| */ |
| explicit PackedFunc(FType body) : body_(body) {} |
| /*! |
| * \brief Call packed function by directly passing in unpacked format. |
| * \param args Arguments to be passed. |
| * \tparam Args arguments to be passed. |
| * |
| * \code |
| * // Example code on how to call packed function |
| * void CallPacked(PackedFunc f) { |
| * // call like normal functions by pass in arguments |
| * // return value is automatically converted back |
| * int rvalue = f(1, 2.0); |
| * } |
| * \endcode |
| */ |
| template <typename... Args> |
| inline TVMRetValue operator()(Args&&... args) const; |
| /*! |
| * \brief Call the function in packed format. |
| * \param args The arguments |
| * \param rv The return value. |
| */ |
| inline void CallPacked(TVMArgs args, TVMRetValue* rv) const; |
| /*! \return the internal body function */ |
| inline FType body() const; |
| /*! \return Whether the packed function is nullptr */ |
| bool operator==(std::nullptr_t null) const { return body_ == nullptr; } |
| /*! \return Whether the packed function is not nullptr */ |
| bool operator!=(std::nullptr_t null) const { return body_ != nullptr; } |
| |
| private: |
| /*! \brief internal container of packed function */ |
| FType body_; |
| }; |
| |
| /*! |
| * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc<R(Args..)>" |
| */ |
| template <typename FType> |
| class TypedPackedFunc; |
| |
| /*! |
| * \anchor TypedPackedFuncAnchor |
| * \brief A PackedFunc wrapper to provide typed function signature. |
| * It is backed by a PackedFunc internally. |
| * |
| * TypedPackedFunc enables compile time type checking. |
| * TypedPackedFunc works with the runtime system: |
| * - It can be passed as an argument of PackedFunc. |
| * - It can be assigned to TVMRetValue. |
| * - It can be directly converted to a type-erased PackedFunc. |
| * |
| * Developers should prefer TypedPackedFunc over PackedFunc in C++ code |
| * as it enables compile time checking. |
| * We can construct a TypedPackedFunc from a lambda function |
| * with the same signature. |
| * |
| * \code |
| * // user defined lambda function. |
| * auto addone = [](int x)->int { |
| * return x + 1; |
| * }; |
| * // We can directly convert |
| * // lambda function to TypedPackedFunc |
| * TypedPackedFunc<int(int)> ftyped(addone); |
| * // invoke the function. |
| * int y = ftyped(1); |
| * // Can be directly converted to PackedFunc |
| * PackedFunc packed = ftype; |
| * \endcode |
| * \tparam R The return value of the function. |
| * \tparam Args The argument signature of the function. |
| */ |
| template <typename R, typename... Args> |
| class TypedPackedFunc<R(Args...)> { |
| public: |
| /*! \brief short hand for this function type */ |
| using TSelf = TypedPackedFunc<R(Args...)>; |
| /*! \brief default constructor */ |
| TypedPackedFunc() {} |
| /*! \brief constructor from null */ |
| TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*) |
| /*! |
| * \brief construct by wrap a PackedFunc |
| * |
| * Example usage: |
| * \code |
| * PackedFunc packed([](TVMArgs args, TVMRetValue *rv) { |
| * int x = args[0]; |
| * *rv = x + 1; |
| * }); |
| * // construct from packed function |
| * TypedPackedFunc<int(int)> ftyped(packed); |
| * // call the typed version. |
| * CHECK_EQ(ftyped(1), 2); |
| * \endcode |
| * |
| * \param packed The packed function |
| */ |
| inline TypedPackedFunc(PackedFunc packed); // NOLINT(*) |
| /*! |
| * \brief constructor from TVMRetValue |
| * \param value The TVMRetValue |
| */ |
| inline TypedPackedFunc(const TVMRetValue& value); // NOLINT(*) |
| /*! |
| * \brief constructor from TVMArgValue |
| * \param value The TVMArgValue |
| */ |
| inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*) |
| /*! |
| * \brief constructor from TVMMovableArgValue_ |
| * \param value The TVMMovableArgValue_ |
| */ |
| inline TypedPackedFunc(TVMMovableArgValue_&& value); // NOLINT(*) |
| /*! |
| * \brief construct from a lambda function with the same signature. |
| * |
| * Example usage: |
| * \code |
| * auto typed_lambda = [](int x)->int { return x + 1; } |
| * // construct from packed function |
| * TypedPackedFunc<int(int)> ftyped(typed_lambda); |
| * // call the typed version. |
| * CHECK_EQ(ftyped(1), 2); |
| * \endcode |
| * |
| * \param typed_lambda typed lambda function. |
| * \tparam FLambda the type of the lambda function. |
| */ |
| template <typename FLambda, typename = typename std::enable_if< |
| std::is_convertible<FLambda, |
| std::function<R(Args...)>>::value>::type> |
| TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*) |
| this->AssignTypedLambda(typed_lambda); |
| } |
| /*! |
| * \brief copy assignment operator from typed lambda |
| * |
| * Example usage: |
| * \code |
| * // construct from packed function |
| * TypedPackedFunc<int(int)> ftyped; |
| * ftyped = [](int x) { return x + 1; } |
| * // call the typed version. |
| * CHECK_EQ(ftyped(1), 2); |
| * \endcode |
| * |
| * \param typed_lambda typed lambda function. |
| * \tparam FLambda the type of the lambda function. |
| * \returns reference to self. |
| */ |
| template <typename FLambda, typename = typename std::enable_if< |
| std::is_convertible<FLambda, |
| std::function<R(Args...)>>::value>::type> |
| TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) |
| this->AssignTypedLambda(typed_lambda); |
| return *this; |
| } |
| /*! |
| * \brief copy assignment operator from PackedFunc. |
| * \param packed The packed function. |
| * \returns reference to self. |
| */ |
| TSelf& operator=(PackedFunc packed) { |
| packed_ = packed; |
| return *this; |
| } |
| /*! |
| * \brief Invoke the operator. |
| * \param args The arguments |
| * \returns The return value. |
| */ |
| TVM_ALWAYS_INLINE R operator()(Args... args) const; |
| /*! |
| * \brief convert to PackedFunc |
| * \return the internal PackedFunc |
| */ |
| operator PackedFunc() const { return packed(); } |
| /*! |
| * \return reference the internal PackedFunc |
| */ |
| const PackedFunc& packed() const { return packed_; } |
| /*! \return Whether the packed function is nullptr */ |
| bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } |
| /*! \return Whether the packed function is not nullptr */ |
| bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } |
| |
| private: |
| friend class TVMRetValue; |
| /*! \brief The internal packed function */ |
| PackedFunc packed_; |
| /*! |
| * \brief Assign the packed field using a typed lambda function. |
| * |
| * \param flambda The lambda function. |
| * \tparam FLambda The lambda function type. |
| * \note We capture the lambda when possible for maximum efficiency. |
| */ |
| template <typename FLambda> |
| inline void AssignTypedLambda(FLambda flambda); |
| }; |
| |
| /*! \brief Arguments into TVM functions. */ |
| class TVMArgs { |
| public: |
| const TVMValue* values; |
| const int* type_codes; |
| int num_args; |
| /*! |
| * \brief constructor |
| * \param values The argument values |
| * \param type_codes The argument type codes |
| * \param num_args number of arguments. |
| */ |
| TVMArgs(const TVMValue* values, const int* type_codes, int num_args) |
| : values(values), type_codes(type_codes), num_args(num_args) {} |
| /*! \return size of the arguments */ |
| inline int size() const; |
| /*! |
| * \brief Get i-th argument |
| * \param i the index. |
| * \return the ith argument. |
| */ |
| inline TVMArgValue operator[](int i) const; |
| }; |
| |
| /*! |
| * \brief Convert argument type code to string. |
| * \param type_code The input type code. |
| * \return The corresponding string repr. |
| */ |
| inline const char* ArgTypeCode2Str(int type_code); |
| |
| // macro to check type code. |
| #define TVM_CHECK_TYPE_CODE(CODE, T) \ |
| CHECK_EQ(CODE, T) << " expected " << ArgTypeCode2Str(T) << " but get " << ArgTypeCode2Str(CODE) |
| |
| /*! |
| * \brief Type traits for runtime type check during FFI conversion. |
| * \tparam T the type to be checked. |
| */ |
| template <typename T> |
| struct ObjectTypeChecker { |
| static bool Check(const Object* ptr) { |
| using ContainerType = typename T::ContainerType; |
| if (ptr == nullptr) return T::_type_is_nullable; |
| return ptr->IsInstance<ContainerType>(); |
| } |
| static std::string TypeName() { |
| using ContainerType = typename T::ContainerType; |
| return ContainerType::_type_key; |
| } |
| }; |
| |
| /*! |
| * \brief Internal base class to |
| * handle conversion to POD values. |
| */ |
| class TVMPODValue_ { |
| public: |
| operator double() const { |
| // Allow automatic conversion from int to float |
| // This avoids errors when user pass in int from |
| // the frontend while the API expects a float. |
| if (type_code_ == kDLInt) { |
| return static_cast<double>(value_.v_int64); |
| } |
| TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); |
| return value_.v_float64; |
| } |
| operator int64_t() const { |
| TVM_CHECK_TYPE_CODE(type_code_, kDLInt); |
| return value_.v_int64; |
| } |
| operator uint64_t() const { |
| TVM_CHECK_TYPE_CODE(type_code_, kDLInt); |
| return value_.v_int64; |
| } |
| operator int() const { |
| TVM_CHECK_TYPE_CODE(type_code_, kDLInt); |
| CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); |
| CHECK_GE(value_.v_int64, std::numeric_limits<int>::min()); |
| return static_cast<int>(value_.v_int64); |
| } |
| operator bool() const { |
| TVM_CHECK_TYPE_CODE(type_code_, kDLInt); |
| return value_.v_int64 != 0; |
| } |
| operator void*() const { |
| if (type_code_ == kTVMNullptr) return nullptr; |
| if (type_code_ == kTVMDLTensorHandle) return value_.v_handle; |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMOpaqueHandle); |
| return value_.v_handle; |
| } |
| operator DLTensor*() const { |
| if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) { |
| return static_cast<DLTensor*>(value_.v_handle); |
| } else { |
| if (type_code_ == kTVMNullptr) return nullptr; |
| LOG(FATAL) << "Expect " |
| << "DLTensor* or NDArray but get " << ArgTypeCode2Str(type_code_); |
| return nullptr; |
| } |
| } |
| operator NDArray() const { |
| if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr<Object>(nullptr)); |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); |
| return NDArray(NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle))); |
| } |
| operator Module() const { |
| if (type_code_ == kTVMNullptr) { |
| return Module(ObjectPtr<Object>(nullptr)); |
| } |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); |
| return Module(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle))); |
| } |
| operator TVMContext() const { |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); |
| return value_.v_ctx; |
| } |
| int type_code() const { return type_code_; } |
| /*! |
| * \brief return handle as specific pointer type. |
| * \tparam T the data type. |
| * \return The pointer type. |
| */ |
| template <typename T> |
| T* ptr() const { |
| return static_cast<T*>(value_.v_handle); |
| } |
| // ObjectRef handling |
| template <typename TObjectRef, |
| typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type> |
| inline bool IsObjectRef() const; |
| template <typename TObjectRef> |
| inline TObjectRef AsObjectRef() const; |
| |
| protected: |
| friend class TVMArgsSetter; |
| friend class TVMRetValue; |
| TVMPODValue_() : type_code_(kTVMNullptr) {} |
| TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} |
| |
| /*! \brief The value */ |
| TVMValue value_; |
| /*! \brief the type code */ |
| int type_code_; |
| }; |
| |
| /*! |
| * \brief A single argument value to PackedFunc. |
| * Containing both type_code and TVMValue |
| * |
| * Provides utilities to do type cast into other types. |
| */ |
| class TVMArgValue : public TVMPODValue_ { |
| public: |
| /*! \brief default constructor */ |
| TVMArgValue() {} |
| /*! |
| * \brief constructor |
| * \param value of the function |
| * \param type_code The type code. |
| */ |
| TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} |
| // reuse converter from parent |
| using TVMPODValue_::operator double; |
| using TVMPODValue_::operator int64_t; |
| using TVMPODValue_::operator uint64_t; |
| using TVMPODValue_::operator int; |
| using TVMPODValue_::operator bool; |
| using TVMPODValue_::operator void*; |
| using TVMPODValue_::operator DLTensor*; |
| using TVMPODValue_::operator NDArray; |
| using TVMPODValue_::operator TVMContext; |
| using TVMPODValue_::operator Module; |
| using TVMPODValue_::AsObjectRef; |
| using TVMPODValue_::IsObjectRef; |
| |
| // conversion operator. |
| operator std::string() const { |
| if (type_code_ == kTVMDataType) { |
| return DLDataType2String(operator DLDataType()); |
| } else if (type_code_ == kTVMBytes) { |
| TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle); |
| return std::string(arr->data, arr->size); |
| } else if (type_code_ == kTVMStr) { |
| return std::string(value_.v_str); |
| } else { |
| CHECK(IsObjectRef<tvm::runtime::String>()); |
| return AsObjectRef<tvm::runtime::String>().operator std::string(); |
| } |
| } |
| operator PackedFunc() const { |
| if (type_code_ == kTVMNullptr) return PackedFunc(); |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); |
| return *ptr<PackedFunc>(); |
| } |
| template <typename FType> |
| operator TypedPackedFunc<FType>() const { |
| return TypedPackedFunc<FType>(operator PackedFunc()); |
| } |
| const TVMValue& value() const { return value_; } |
| |
| template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type> |
| inline operator T() const; |
| inline operator DLDataType() const; |
| inline operator DataType() const; |
| }; |
| |
| /*! |
| * \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument. |
| * |
| * We can only construct a movable argument once from a single argument position. |
| * If the argument is passed as RValue reference, the result will be moved. |
| * We should only construct a MovableArg from an argument once, |
| * as the result will can moved. |
| * |
| * \note For internal development purpose only. |
| */ |
| class TVMMovableArgValue_ : public TVMArgValue { |
| public: |
| TVMMovableArgValue_(TVMValue value, int type_code) : TVMArgValue(value, type_code) {} |
| // reuse converter from parent |
| using TVMArgValue::operator double; |
| using TVMArgValue::operator int64_t; |
| using TVMArgValue::operator uint64_t; |
| using TVMArgValue::operator int; |
| using TVMArgValue::operator bool; |
| using TVMArgValue::operator void*; |
| using TVMArgValue::operator DLTensor*; |
| using TVMArgValue::operator TVMContext; |
| using TVMArgValue::operator std::string; |
| using TVMArgValue::operator DLDataType; |
| using TVMArgValue::operator DataType; |
| using TVMArgValue::operator PackedFunc; |
| /*! |
| * \brief Helper converter function. |
| * Try to move out an argument if possible, |
| * fall back to normal argument conversion rule otherwise. |
| */ |
| template <typename T, |
| typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type> |
| inline operator T() const; |
| }; |
| |
| /*! |
| * \brief Return Value container, |
| * Unlike TVMArgValue, which only holds reference and do not delete |
| * the underlying container during destruction. |
| * |
| * TVMRetValue holds value and will manage the underlying containers |
| * when it stores a complicated data type. |
| */ |
| class TVMRetValue : public TVMPODValue_ { |
| public: |
| /*! \brief default constructor */ |
| TVMRetValue() {} |
| /*! |
| * \brief move constructor from anoter return value. |
| * \param other The other return value. |
| */ |
| TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { |
| other.value_.v_handle = nullptr; |
| other.type_code_ = kTVMNullptr; |
| } |
| /*! \brief destructor */ |
| ~TVMRetValue() { this->Clear(); } |
| // reuse converter from parent |
| using TVMPODValue_::operator double; |
| using TVMPODValue_::operator int64_t; |
| using TVMPODValue_::operator uint64_t; |
| using TVMPODValue_::operator int; |
| using TVMPODValue_::operator bool; |
| using TVMPODValue_::operator void*; |
| using TVMPODValue_::operator DLTensor*; |
| using TVMPODValue_::operator TVMContext; |
| using TVMPODValue_::operator NDArray; |
| using TVMPODValue_::operator Module; |
| using TVMPODValue_::AsObjectRef; |
| using TVMPODValue_::IsObjectRef; |
| |
| TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } |
| // conversion operators |
| operator std::string() const { |
| if (type_code_ == kTVMDataType) { |
| return DLDataType2String(operator DLDataType()); |
| } else if (type_code_ == kTVMBytes) { |
| return *ptr<std::string>(); |
| } |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMStr); |
| return *ptr<std::string>(); |
| } |
| operator DLDataType() const { |
| if (type_code_ == kTVMStr) { |
| return String2DLDataType(operator std::string()); |
| } |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); |
| return value_.v_type; |
| } |
| operator DataType() const { return DataType(operator DLDataType()); } |
| operator PackedFunc() const { |
| if (type_code_ == kTVMNullptr) return PackedFunc(); |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); |
| return *ptr<PackedFunc>(); |
| } |
| template <typename FType> |
| operator TypedPackedFunc<FType>() const { |
| return TypedPackedFunc<FType>(operator PackedFunc()); |
| } |
| // Assign operators |
| TVMRetValue& operator=(TVMRetValue&& other) { |
| this->Clear(); |
| value_ = other.value_; |
| type_code_ = other.type_code_; |
| other.type_code_ = kTVMNullptr; |
| return *this; |
| } |
| TVMRetValue& operator=(double value) { |
| this->SwitchToPOD(kDLFloat); |
| value_.v_float64 = value; |
| return *this; |
| } |
| TVMRetValue& operator=(std::nullptr_t value) { |
| this->SwitchToPOD(kTVMNullptr); |
| value_.v_handle = value; |
| return *this; |
| } |
| TVMRetValue& operator=(void* value) { |
| this->SwitchToPOD(kTVMOpaqueHandle); |
| value_.v_handle = value; |
| return *this; |
| } |
| TVMRetValue& operator=(int64_t value) { |
| this->SwitchToPOD(kDLInt); |
| value_.v_int64 = value; |
| return *this; |
| } |
| TVMRetValue& operator=(int value) { |
| this->SwitchToPOD(kDLInt); |
| value_.v_int64 = value; |
| return *this; |
| } |
| TVMRetValue& operator=(TVMContext value) { |
| this->SwitchToPOD(kTVMContext); |
| value_.v_ctx = value; |
| return *this; |
| } |
| TVMRetValue& operator=(DLDataType t) { |
| this->SwitchToPOD(kTVMDataType); |
| value_.v_type = t; |
| return *this; |
| } |
| TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } |
| TVMRetValue& operator=(bool value) { |
| this->SwitchToPOD(kDLInt); |
| value_.v_int64 = value; |
| return *this; |
| } |
| TVMRetValue& operator=(std::string value) { |
| this->SwitchToClass(kTVMStr, value); |
| return *this; |
| } |
| TVMRetValue& operator=(TVMByteArray value) { |
| this->SwitchToClass(kTVMBytes, std::string(value.data, value.size)); |
| return *this; |
| } |
| TVMRetValue& operator=(NDArray other) { |
| if (other.data_ != nullptr) { |
| this->Clear(); |
| type_code_ = kTVMNDArrayHandle; |
| value_.v_handle = NDArray::FFIGetHandle(other); |
| ObjectRef::FFIClearAfterMove(&other); |
| } else { |
| SwitchToPOD(kTVMNullptr); |
| } |
| return *this; |
| } |
| TVMRetValue& operator=(Module m) { |
| SwitchToObject(kTVMModuleHandle, std::move(m.data_)); |
| return *this; |
| } |
| TVMRetValue& operator=(PackedFunc f) { |
| if (f == nullptr) { |
| this->SwitchToPOD(kTVMNullptr); |
| } else { |
| this->SwitchToClass(kTVMPackedFuncHandle, f); |
| } |
| return *this; |
| } |
| template <typename FType> |
| TVMRetValue& operator=(const TypedPackedFunc<FType>& f) { |
| return operator=(f.packed()); |
| } |
| TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0 |
| this->Assign(other); |
| return *this; |
| } |
| TVMRetValue& operator=(const TVMArgValue& other) { |
| this->Assign(other); |
| return *this; |
| } |
| TVMRetValue& operator=(TVMMovableArgValue_&& other) { |
| this->Assign(other); |
| return *this; |
| } |
| /*! |
| * \brief Move the value back to front-end via C API. |
| * This marks the current container as null. |
| * The managed resources are moved to the front-end. |
| * The front end should take charge in managing them. |
| * |
| * \param ret_value The return value. |
| * \param ret_type_code The return type code. |
| */ |
| void MoveToCHost(TVMValue* ret_value, int* ret_type_code) { |
| // cannot move str; need specially handle. |
| CHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes); |
| *ret_value = value_; |
| *ret_type_code = type_code_; |
| type_code_ = kTVMNullptr; |
| } |
| /*! |
| * \brief Construct a new TVMRetValue by |
| * moving from return value stored via C API. |
| * \param value the value. |
| * \param type_code The type code. |
| * \return The created TVMRetValue. |
| */ |
| static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { |
| // Can move POD and everything under the object system. |
| CHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); |
| TVMRetValue ret; |
| ret.value_ = value; |
| ret.type_code_ = type_code; |
| return ret; |
| } |
| /*! \return The value field, if the data is POD */ |
| const TVMValue& value() const { |
| CHECK(type_code_ != kTVMObjectHandle && type_code_ != kTVMPackedFuncHandle && |
| type_code_ != kTVMModuleHandle && type_code_ != kTVMStr) |
| << "TVMRetValue.value can only be used for POD data"; |
| return value_; |
| } |
| // ObjectRef handling |
| template <typename TObjectRef, |
| typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type> |
| inline TVMRetValue& operator=(TObjectRef other); |
| template <typename T, typename = typename std::enable_if<std::is_class<T>::value>::type> |
| inline operator T() const; |
| |
| private: |
| template <typename T> |
| void Assign(const T& other) { |
| switch (other.type_code()) { |
| case kTVMStr: { |
| SwitchToClass<std::string>(kTVMStr, other); |
| break; |
| } |
| case kTVMBytes: { |
| SwitchToClass<std::string>(kTVMBytes, other); |
| break; |
| } |
| case kTVMPackedFuncHandle: { |
| SwitchToClass<PackedFunc>(kTVMPackedFuncHandle, other); |
| break; |
| } |
| case kTVMModuleHandle: { |
| *this = other.operator Module(); |
| break; |
| } |
| case kTVMNDArrayHandle: { |
| *this = other.operator NDArray(); |
| break; |
| } |
| case kTVMObjectHandle: { |
| // Avoid operator ObjectRef as we already know it is not NDArray/Module |
| SwitchToObject(kTVMObjectHandle, |
| GetObjectPtr<Object>(static_cast<Object*>(other.value_.v_handle))); |
| break; |
| } |
| case kTVMObjectRValueRefArg: { |
| operator=(other.operator ObjectRef()); |
| break; |
| } |
| default: { |
| SwitchToPOD(other.type_code()); |
| value_ = other.value_; |
| break; |
| } |
| } |
| } |
| // get the internal container. |
| void SwitchToPOD(int type_code) { |
| if (type_code_ != type_code) { |
| this->Clear(); |
| type_code_ = type_code; |
| } |
| } |
| template <typename T> |
| void SwitchToClass(int type_code, T v) { |
| if (type_code_ != type_code) { |
| this->Clear(); |
| type_code_ = type_code; |
| value_.v_handle = new T(v); |
| } else { |
| *static_cast<T*>(value_.v_handle) = v; |
| } |
| } |
| void SwitchToObject(int type_code, ObjectPtr<Object> other) { |
| if (other.data_ != nullptr) { |
| this->Clear(); |
| type_code_ = type_code; |
| // move the handle out |
| value_.v_handle = other.data_; |
| other.data_ = nullptr; |
| } else { |
| SwitchToPOD(kTVMNullptr); |
| } |
| } |
| void Clear() { |
| if (type_code_ == kTVMNullptr) return; |
| switch (type_code_) { |
| case kTVMStr: |
| case kTVMBytes: |
| delete ptr<std::string>(); |
| break; |
| case kTVMPackedFuncHandle: |
| delete ptr<PackedFunc>(); |
| break; |
| case kTVMNDArrayHandle: { |
| NDArray::FFIDecRef(static_cast<TVMArrayHandle>(value_.v_handle)); |
| break; |
| } |
| case kTVMModuleHandle: { |
| static_cast<Object*>(value_.v_handle)->DecRef(); |
| break; |
| } |
| case kTVMObjectHandle: { |
| static_cast<Object*>(value_.v_handle)->DecRef(); |
| break; |
| } |
| } |
| type_code_ = kTVMNullptr; |
| } |
| }; |
| |
| /*! |
| * \brief Type trait to specify special value conversion rules from |
| * TVMArgValue and TVMRetValue. |
| * |
| * The trait can be specialized to add type specific conversion logic |
| * from the TVMArgvalue and TVMRetValue. |
| * |
| * \tparam TObjectRef the specific ObjectRefType. |
| */ |
| template <typename TObjectRef> |
| struct PackedFuncValueConverter { |
| /*! |
| * \brief Convert a TObjectRef from an argument value. |
| * \param val The argument value. |
| * \return the converted result. |
| */ |
| static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef<TObjectRef>(); } |
| /*! |
| * \brief Convert a TObjectRef from a return value. |
| * \param val The argument value. |
| * \return the converted result. |
| */ |
| static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef<TObjectRef>(); } |
| }; |
| |
| /*! |
| * \brief Export a function with the PackedFunc signature |
| * as a PackedFunc that can be loaded by LibraryModule. |
| * |
| * \param ExportName The symbol name to be exported. |
| * \param Function The function with PackedFunc signature. |
| * \sa PackedFunc |
| * |
| * \code |
| * |
| * void AddOne_(TVMArgs args, TVMRetValue* rv) { |
| * int value = args[0]; |
| * *rv = value + 1; |
| * } |
| * // Expose the function as "AddOne" |
| * TVM_DLL_EXPORT_PACKED_FUNC(AddOne, AddOne_); |
| * |
| * \endcode |
| */ |
| #define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ |
| extern "C" { \ |
| TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ |
| int* out_type_code); \ |
| int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ |
| int* out_type_code) { \ |
| try { \ |
| ::tvm::runtime::TVMRetValue rv; \ |
| Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ |
| rv.MoveToCHost(out_value, out_type_code); \ |
| return 0; \ |
| } catch (const ::std::runtime_error& _except_) { \ |
| TVMAPISetLastError(_except_.what()); \ |
| return -1; \ |
| } \ |
| } \ |
| } |
| |
| /*! |
| * \brief Export typed function as a PackedFunc |
| * that can be loaded by LibraryModule. |
| * |
| * \param ExportName The symbol name to be exported. |
| * \param Function The typed function. |
| * \note ExportName and Function must be different, |
| * see code examples below. |
| * |
| * \sa TypedPackedFunc |
| * |
| * \code |
| * |
| * int AddOne_(int x) { |
| * return x + 1; |
| * } |
| * |
| * // Expose the function as "AddOne" |
| * TVM_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_); |
| * |
| * // Expose the function as "SubOne" |
| * TVM_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) { |
| * return x - 1; |
| * }); |
| * |
| * // The following code will cause compilation error. |
| * // Because the same Function and ExortName |
| * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, AddOne_); |
| * |
| * // The following code is OK, assuming the macro |
| * // is in a different namespace from xyz |
| * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, xyz::AddOne_); |
| * |
| * \endcode |
| */ |
| #define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ |
| extern "C" { \ |
| TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ |
| int* out_type_code) { \ |
| try { \ |
| auto f = Function; \ |
| using FType = ::tvm::runtime::detail::function_signature<decltype(f)>::FType; \ |
| ::tvm::runtime::TVMRetValue rv; \ |
| ::tvm::runtime::detail::unpack_call_by_signature<FType>::run( \ |
| f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ |
| rv.MoveToCHost(out_value, out_type_code); \ |
| return 0; \ |
| } catch (const ::std::runtime_error& _except_) { \ |
| TVMAPISetLastError(_except_.what()); \ |
| return -1; \ |
| } \ |
| } \ |
| } |
| |
| inline TVMArgValue TVMArgs::operator[](int i) const { |
| CHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed" |
| << " but request arg[" << i << "]."; |
| return TVMArgValue(values[i], type_codes[i]); |
| } |
| |
| inline int TVMArgs::size() const { return num_args; } |
| |
| inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); } |
| |
| inline PackedFunc::FType PackedFunc::body() const { return body_; } |
| |
| // internal namespace |
| inline const char* ArgTypeCode2Str(int type_code) { |
| switch (type_code) { |
| case kDLInt: |
| return "int"; |
| case kDLUInt: |
| return "uint"; |
| case kDLFloat: |
| return "float"; |
| case kTVMStr: |
| return "str"; |
| case kTVMBytes: |
| return "bytes"; |
| case kTVMOpaqueHandle: |
| return "handle"; |
| case kTVMNullptr: |
| return "NULL"; |
| case kTVMDLTensorHandle: |
| return "ArrayHandle"; |
| case kTVMDataType: |
| return "DLDataType"; |
| case kTVMContext: |
| return "TVMContext"; |
| case kTVMPackedFuncHandle: |
| return "FunctionHandle"; |
| case kTVMModuleHandle: |
| return "ModuleHandle"; |
| case kTVMNDArrayHandle: |
| return "NDArrayContainer"; |
| case kTVMObjectHandle: |
| return "Object"; |
| case kTVMObjectRValueRefArg: |
| return "ObjectRValueRefArg"; |
| default: |
| LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code); |
| return ""; |
| } |
| } |
| |
| namespace detail { |
| |
| template <bool stop, std::size_t I, typename F> |
| struct for_each_dispatcher { |
| template <typename T, typename... Args> |
| static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*) |
| f(I, std::forward<T>(value)); |
| for_each_dispatcher<sizeof...(Args) == 0, (I + 1), F>::run(f, std::forward<Args>(args)...); |
| } |
| }; |
| |
| template <std::size_t I, typename F> |
| struct for_each_dispatcher<true, I, F> { |
| static void run(const F& f) {} // NOLINT(*) |
| }; |
| |
| template <typename F, typename... Args> |
| inline void for_each(const F& f, Args&&... args) { // NOLINT(*) |
| for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(f, std::forward<Args>(args)...); |
| } |
| |
| template <typename T> |
| struct func_signature_helper { |
| using FType = void; |
| }; |
| |
| template <typename T, typename R, typename... Args> |
| struct func_signature_helper<R (T::*)(Args...)> { |
| using FType = R(Args...); |
| static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference"); |
| }; |
| |
| template <typename T, typename R, typename... Args> |
| struct func_signature_helper<R (T::*)(Args...) const> { |
| using FType = R(Args...); |
| static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference"); |
| }; |
| |
| /*! |
| * \brief template class to get function signature of a function or functor. |
| * \tparam T The funtion/functor type. |
| */ |
| template <typename T> |
| struct function_signature { |
| using FType = typename func_signature_helper<decltype(&T::operator())>::FType; |
| }; |
| |
| // handle case of function. |
| template <typename R, typename... Args> |
| struct function_signature<R(Args...)> { |
| using FType = R(Args...); |
| static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference"); |
| }; |
| |
| // handle case of function ptr. |
| template <typename R, typename... Args> |
| struct function_signature<R (*)(Args...)> { |
| using FType = R(Args...); |
| static_assert(!std::is_reference<R>::value, "TypedPackedFunc return reference"); |
| }; |
| } // namespace detail |
| |
| /* \brief argument settter to PackedFunc */ |
| class TVMArgsSetter { |
| public: |
| TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {} |
| // setters for POD types |
| template <typename T, typename = typename std::enable_if<std::is_integral<T>::value>::type> |
| TVM_ALWAYS_INLINE void operator()(size_t i, T value) const { |
| values_[i].v_int64 = static_cast<int64_t>(value); |
| type_codes_[i] = kDLInt; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { |
| values_[i].v_int64 = static_cast<int64_t>(value); |
| CHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max())); |
| type_codes_[i] = kDLInt; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, double value) const { |
| values_[i].v_float64 = value; |
| type_codes_[i] = kDLFloat; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const { |
| values_[i].v_handle = value; |
| type_codes_[i] = kTVMNullptr; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const { |
| values_[i] = value.value_; |
| type_codes_[i] = value.type_code_; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const { |
| values_[i].v_handle = value; |
| type_codes_[i] = kTVMOpaqueHandle; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const { |
| values_[i].v_handle = value; |
| type_codes_[i] = kTVMDLTensorHandle; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, TVMContext value) const { |
| values_[i].v_ctx = value; |
| type_codes_[i] = kTVMContext; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const { |
| values_[i].v_type = value; |
| type_codes_[i] = kTVMDataType; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const { |
| operator()(i, dtype.operator DLDataType()); |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const { |
| values_[i].v_str = value; |
| type_codes_[i] = kTVMStr; |
| } |
| // setters for container types |
| TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const { |
| values_[i].v_str = value.c_str(); |
| type_codes_[i] = kTVMStr; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const { |
| values_[i].v_handle = const_cast<TVMByteArray*>(&value); |
| type_codes_[i] = kTVMBytes; |
| } |
| TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const { |
| if (value != nullptr) { |
| values_[i].v_handle = const_cast<PackedFunc*>(&value); |
| type_codes_[i] = kTVMPackedFuncHandle; |
| } else { |
| values_[i].v_handle = nullptr; |
| type_codes_[i] = kTVMNullptr; |
| } |
| } |
| template <typename FType> |
| TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const { |
| operator()(i, value.packed()); |
| } |
| void operator()(size_t i, const TVMRetValue& value) const { |
| if (value.type_code() == kTVMStr) { |
| values_[i].v_str = value.ptr<std::string>()->c_str(); |
| type_codes_[i] = kTVMStr; |
| } else { |
| CHECK_NE(value.type_code(), kTVMBytes) << "not handled."; |
| values_[i] = value.value_; |
| type_codes_[i] = value.type_code(); |
| } |
| } |
| // ObjectRef handling |
| template <typename TObjectRef, |
| typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type> |
| TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const { |
| this->SetObject(i, value); |
| } |
| |
| template <typename TObjectRef, |
| typename = typename std::enable_if<std::is_base_of< |
| ObjectRef, typename std::remove_reference<TObjectRef>::type>::value>::type> |
| TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const { |
| this->SetObject(i, std::forward<TObjectRef>(value)); |
| } |
| |
| private: |
| template <typename TObjectRef> |
| inline void SetObject(size_t i, TObjectRef&& value) const; |
| /*! \brief The values fields */ |
| TVMValue* values_; |
| /*! \brief The type code fields */ |
| int* type_codes_; |
| }; |
| |
| template <typename... Args> |
| inline TVMRetValue PackedFunc::operator()(Args&&... args) const { |
| const int kNumArgs = sizeof...(Args); |
| const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; |
| TVMValue values[kArraySize]; |
| int type_codes[kArraySize]; |
| detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...); |
| TVMRetValue rv; |
| body_(TVMArgs(values, type_codes, kNumArgs), &rv); |
| return rv; |
| } |
| |
| namespace detail { |
| template <typename R, int nleft, int index, typename F> |
| struct unpack_call_dispatcher { |
| template <typename... Args> |
| TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, |
| Args&&... unpacked_args) { |
| // construct a movable argument value |
| // which allows potential move of argument to the input of F. |
| unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run( |
| f, args_pack, rv, std::forward<Args>(unpacked_args)..., |
| TVMMovableArgValue_(args_pack.values[index], args_pack.type_codes[index])); |
| } |
| }; |
| |
| template <typename R, int index, typename F> |
| struct unpack_call_dispatcher<R, 0, index, F> { |
| template <typename... Args> |
| TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, |
| Args&&... unpacked_args) { |
| using RetType = decltype(f(std::forward<Args>(unpacked_args)...)); |
| if (std::is_same<RetType, R>::value) { |
| *rv = f(std::forward<Args>(unpacked_args)...); |
| } else { |
| *rv = R(f(std::forward<Args>(unpacked_args)...)); |
| } |
| } |
| }; |
| |
| template <int index, typename F> |
| struct unpack_call_dispatcher<void, 0, index, F> { |
| template <typename... Args> |
| TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, |
| Args&&... unpacked_args) { |
| f(std::forward<Args>(unpacked_args)...); |
| } |
| }; |
| |
| template <typename R, int nargs, typename F> |
| TVM_ALWAYS_INLINE void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { |
| CHECK_EQ(nargs, args.size()) << "Expect " << nargs << " arguments but get " << args.size(); |
| unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv); |
| } |
| |
| template <typename FType> |
| struct unpack_call_by_signature {}; |
| |
| template <typename R, typename... Args> |
| struct unpack_call_by_signature<R(Args...)> { |
| template <typename F> |
| TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) { |
| unpack_call<R, sizeof...(Args)>(f, args, rv); |
| } |
| }; |
| |
| template <typename R, typename... Args> |
| TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&&... args) { |
| return R(pf(std::forward<Args>(args)...)); |
| } |
| |
| template <typename R> |
| struct typed_packed_call_dispatcher { |
| template <typename... Args> |
| TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&&... args) { |
| return pf(std::forward<Args>(args)...); |
| } |
| }; |
| |
| template <> |
| struct typed_packed_call_dispatcher<void> { |
| template <typename... Args> |
| TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&&... args) { |
| pf(std::forward<Args>(args)...); |
| } |
| }; |
| } // namespace detail |
| |
| template <typename R, typename... Args> |
| TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed) : packed_(packed) {} |
| |
| template <typename R, typename... Args> |
| TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMRetValue& value) |
| : packed_(value.operator PackedFunc()) {} |
| |
| template <typename R, typename... Args> |
| TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMArgValue& value) |
| : packed_(value.operator PackedFunc()) {} |
| |
| template <typename R, typename... Args> |
| TypedPackedFunc<R(Args...)>::TypedPackedFunc(TVMMovableArgValue_&& value) |
| : packed_(value.operator PackedFunc()) {} |
| |
| template <typename R, typename... Args> |
| template <typename FType> |
| inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) { |
| packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) { |
| detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv); |
| }); |
| } |
| |
| template <typename R, typename... Args> |
| TVM_ALWAYS_INLINE R TypedPackedFunc<R(Args...)>::operator()(Args... args) const { |
| return detail::typed_packed_call_dispatcher<R>::run(packed_, std::forward<Args>(args)...); |
| } |
| |
| // ObjectRef related conversion handling |
| // Object can have three possible type codes: |
| // kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle |
| // |
| // We use type traits to eliminate un-necessary checks. |
| template <typename T> |
| inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { |
| using ContainerType = typename std::remove_reference<T>::type::ContainerType; |
| if (value.defined()) { |
| Object* ptr = value.data_.data_; |
| if (std::is_base_of<NDArray::ContainerType, ContainerType>::value || |
| (std::is_base_of<ContainerType, NDArray::ContainerType>::value && |
| ptr->IsInstance<NDArray::ContainerType>())) { |
| values_[i].v_handle = NDArray::FFIGetHandle(value); |
| type_codes_[i] = kTVMNDArrayHandle; |
| } else if (std::is_base_of<Module::ContainerType, ContainerType>::value || |
| (std::is_base_of<ContainerType, Module::ContainerType>::value && |
| ptr->IsInstance<Module::ContainerType>())) { |
| values_[i].v_handle = ptr; |
| type_codes_[i] = kTVMModuleHandle; |
| } else if (std::is_rvalue_reference<decltype(value)>::value) { |
| values_[i].v_handle = const_cast<Object**>(&(value.data_.data_)); |
| type_codes_[i] = kTVMObjectRValueRefArg; |
| } else { |
| values_[i].v_handle = value.data_.data_; |
| type_codes_[i] = kTVMObjectHandle; |
| } |
| } else { |
| type_codes_[i] = kTVMNullptr; |
| } |
| } |
| |
| template <typename TObjectRef, typename> |
| inline bool TVMPODValue_::IsObjectRef() const { |
| using ContainerType = typename TObjectRef::ContainerType; |
| // NOTE: the following code can be optimized by constant folding. |
| if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) { |
| return type_code_ == kTVMNDArrayHandle && |
| TVMArrayHandleToObjectHandle(static_cast<TVMArrayHandle>(value_.v_handle)) |
| ->IsInstance<ContainerType>(); |
| } |
| if (std::is_base_of<Module::ContainerType, ContainerType>::value) { |
| return type_code_ == kTVMModuleHandle && |
| static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>(); |
| } |
| // NOTE: we don't pass NDArray and runtime::Module as RValue ref. |
| if (type_code_ == kTVMObjectRValueRefArg) { |
| return ObjectTypeChecker<TObjectRef>::Check(*static_cast<Object**>(value_.v_handle)); |
| } |
| return (std::is_base_of<ContainerType, NDArray::ContainerType>::value && |
| type_code_ == kTVMNDArrayHandle) || |
| (std::is_base_of<ContainerType, Module::ContainerType>::value && |
| type_code_ == kTVMModuleHandle) || |
| (type_code_ == kTVMObjectHandle && |
| ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle))); |
| } |
| |
| template <typename TObjectRef> |
| inline TObjectRef TVMPODValue_::AsObjectRef() const { |
| static_assert(std::is_base_of<ObjectRef, TObjectRef>::value, |
| "Conversion only works for ObjectRef"); |
| using ContainerType = typename TObjectRef::ContainerType; |
| |
| if (type_code_ == kTVMNullptr) { |
| CHECK(TObjectRef::_type_is_nullable) |
| << "Expect a not null value of " << ContainerType::_type_key; |
| return TObjectRef(ObjectPtr<Object>(nullptr)); |
| } |
| // NOTE: the following code can be optimized by constant folding. |
| if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) { |
| // Casting to a sub-class of NDArray |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); |
| ObjectPtr<Object> data = |
| NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle)); |
| CHECK(data->IsInstance<ContainerType>()) |
| << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); |
| return TObjectRef(data); |
| } |
| if (std::is_base_of<Module::ContainerType, ContainerType>::value) { |
| // Casting to a sub-class of Module |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); |
| ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)); |
| CHECK(data->IsInstance<ContainerType>()) |
| << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); |
| return TObjectRef(data); |
| } |
| if (type_code_ == kTVMObjectHandle) { |
| // normal object type check. |
| Object* ptr = static_cast<Object*>(value_.v_handle); |
| CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr)) |
| << "Expect " << ObjectTypeChecker<TObjectRef>::TypeName() << " but get " |
| << ptr->GetTypeKey(); |
| return TObjectRef(GetObjectPtr<Object>(ptr)); |
| } else if (type_code_ == kTVMObjectRValueRefArg) { |
| Object* ptr = *static_cast<Object**>(value_.v_handle); |
| CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr)) |
| << "Expect " << ObjectTypeChecker<TObjectRef>::TypeName() << " but get " |
| << ptr->GetTypeKey(); |
| return TObjectRef(GetObjectPtr<Object>(ptr)); |
| } else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value && |
| type_code_ == kTVMNDArrayHandle) { |
| // Casting to a base class that NDArray can sub-class |
| ObjectPtr<Object> data = |
| NDArray::FFIDataFromHandle(static_cast<TVMArrayHandle>(value_.v_handle)); |
| return TObjectRef(data); |
| } else if (std::is_base_of<ContainerType, Module::ContainerType>::value && |
| type_code_ == kTVMModuleHandle) { |
| // Casting to a base class that Module can sub-class |
| return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle))); |
| } else { |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); |
| return TObjectRef(ObjectPtr<Object>(nullptr)); |
| } |
| } |
| |
| template <typename TObjectRef, typename> |
| inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { |
| using ContainerType = typename TObjectRef::ContainerType; |
| const Object* ptr = other.get(); |
| if (ptr != nullptr) { |
| if (std::is_base_of<NDArray::ContainerType, ContainerType>::value || |
| (std::is_base_of<ContainerType, NDArray::ContainerType>::value && |
| ptr->IsInstance<NDArray::ContainerType>())) { |
| return operator=(NDArray(std::move(other.data_))); |
| } |
| if (std::is_base_of<Module::ContainerType, ContainerType>::value || |
| (std::is_base_of<ContainerType, Module::ContainerType>::value && |
| ptr->IsInstance<Module::ContainerType>())) { |
| return operator=(Module(std::move(other.data_))); |
| } |
| SwitchToObject(kTVMObjectHandle, std::move(other.data_)); |
| } else { |
| SwitchToPOD(kTVMNullptr); |
| } |
| return *this; |
| } |
| |
| template <typename T, typename> |
| inline TVMArgValue::operator T() const { |
| return PackedFuncValueConverter<T>::From(*this); |
| } |
| |
| template <typename T, typename> |
| inline TVMMovableArgValue_::operator T() const { |
| if (type_code_ == kTVMObjectRValueRefArg) { |
| auto** ref = static_cast<Object**>(value_.v_handle); |
| if (ObjectTypeChecker<T>::Check(*ref)) { |
| return T(ObjectPtr<Object>::MoveFromRValueRefArg(ref)); |
| } |
| } |
| // fallback |
| return PackedFuncValueConverter<T>::From(*this); |
| } |
| |
| template <typename T, typename> |
| inline TVMRetValue::operator T() const { |
| return PackedFuncValueConverter<T>::From(*this); |
| } |
| |
| inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { |
| return (*this)->GetFunction(name, query_imports); |
| } |
| |
| // specializations of PackedFuncValueConverter |
| template <> |
| struct PackedFuncValueConverter<::tvm::runtime::String> { |
| static String From(const TVMArgValue& val) { |
| if (val.IsObjectRef<tvm::runtime::String>()) { |
| return val.AsObjectRef<tvm::runtime::String>(); |
| } else { |
| return tvm::runtime::String(val.operator std::string()); |
| } |
| } |
| |
| static String From(const TVMRetValue& val) { |
| if (val.IsObjectRef<tvm::runtime::String>()) { |
| return val.AsObjectRef<tvm::runtime::String>(); |
| } else { |
| return tvm::runtime::String(val.operator std::string()); |
| } |
| } |
| }; |
| |
| template <typename T> |
| struct PackedFuncValueConverter<Optional<T>> { |
| static Optional<T> From(const TVMArgValue& val) { |
| if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr); |
| return PackedFuncValueConverter<T>::From(val); |
| } |
| static Optional<T> From(const TVMRetValue& val) { |
| if (val.type_code() == kTVMNullptr) return Optional<T>(nullptr); |
| return PackedFuncValueConverter<T>::From(val); |
| } |
| }; |
| |
| inline bool String::CanConvertFrom(const TVMArgValue& val) { |
| return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>(); |
| } |
| |
| inline TVMArgValue::operator DLDataType() const { |
| if (String::CanConvertFrom(*this)) { |
| return String2DLDataType(PackedFuncValueConverter<String>::From(*this).operator std::string()); |
| } |
| // None type |
| if (type_code_ == kTVMNullptr) { |
| DLDataType t; |
| t.code = kTVMOpaqueHandle; |
| t.bits = 0; |
| t.lanes = 0; |
| return t; |
| } |
| TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); |
| return value_.v_type; |
| } |
| |
| inline TVMArgValue::operator DataType() const { return DataType(operator DLDataType()); } |
| |
| } // namespace runtime |
| } // namespace tvm |
| #endif // TVM_RUNTIME_PACKED_FUNC_H_ |