blob: b10aff96a11698b1014410f7e20708a01deed4a6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/runtime/object.h
* \brief A managed object in the TVM runtime.
*/
#ifndef TVM_RUNTIME_OBJECT_H_
#define TVM_RUNTIME_OBJECT_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>
#include <string>
#include <type_traits>
#include <utility>
/*!
 * \brief Whether or not to use an atomic reference counter (defaults to 1).
 *
 * If the reference counter is not atomic, an object cannot be owned by
 * multiple threads at the same time. It can, however, still be moved
 * from one thread to another.
 */
#if !defined(TVM_OBJECT_ATOMIC_REF_COUNTER)
#define TVM_OBJECT_ATOMIC_REF_COUNTER 1
#endif
#if TVM_OBJECT_ATOMIC_REF_COUNTER
// std::atomic is only needed when the atomic counter is enabled.
#include <atomic>
#endif  // TVM_OBJECT_ATOMIC_REF_COUNTER
namespace tvm {
namespace runtime {
/*!
 * \brief Holder for the built-in object type indices.
 * \note A struct (rather than a namespace) is used so that callers must write
 *  TypeIndex::kEnumName to refer to a constant, while the constants remain
 *  usable as a plain (unscoped) enum.
 */
struct TypeIndex {
  enum {
    /*! \brief Root object type. */
    kRoot = 0,
    // Standard static index assignments; frontends can rely on these constants.
    /*! \brief runtime::Module. */
    kRuntimeModule = 1,
    /*! \brief runtime::NDArray. */
    kRuntimeNDArray = 2,
    /*! \brief runtime::String. */
    kRuntimeString = 3,
    /*! \brief runtime::Array. */
    kRuntimeArray = 4,
    /*! \brief runtime::Map. */
    kRuntimeMap = 5,
    /*! \brief runtime::ShapeTuple. */
    kRuntimeShapeTuple = 6,
    /*! \brief runtime::PackedFunc. */
    kRuntimePackedFunc = 7,
    // Static assignments that may be subject to change (implicitly numbered).
    kRuntimeClosure,
    kRuntimeADT,
    kStaticIndexEnd,
    /*! \brief Type index is allocated during runtime. */
    kDynamic = kStaticIndexEnd
  };
};  // struct TypeIndex
/*!
* \brief base class of all object containers.
*
* Sub-class of objects should declare the following static constexpr fields:
*
* - _type_index:
* Static type index of the object, if assigned to TypeIndex::kDynamic
* the type index will be assigned during runtime.
* Runtime type index can be accessed by ObjectType::TypeIndex();
* - _type_key:
* The unique string identifier of the type.
* - _type_final:
* Whether the type is terminal type(there is no subclass of the type in the object system).
* This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO
* It is still OK to sub-class a terminal object type T and construct it using make_object.
* But IsInstance check will only show that the object type is T(instead of the sub-class).
*
* The following two fields are necessary for base classes that can be sub-classed.
*
* - _type_child_slots:
* Number of reserved type index slots for child classes.
* Used for runtime optimization for type checking in IsInstance.
* If an object's type_index is within range of [type_index, type_index + _type_child_slots]
* Then the object can be quickly decided as sub-class of the current object class.
* If not, a fallback mechanism is used to check the global type table.
* Recommendation: set to the estimated number of children needed.
* - _type_child_slots_can_overflow:
* Whether we can add additional child classes even if the number of child classes
* exceeds the _type_child_slots. A fallback mechanism to check global type table will be
* used. Recommendation: set to false for optimal runtime speed if we know exact number of children.
*
* Two macros are used to declare helper functions in the object:
* - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed.
* - Use TVM_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed.
*
* New objects can be created using make_object function.
* Which will automatically populate the type_index and deleter of the object.
*
* \sa make_object
* \sa ObjectPtr
* \sa ObjectRef
*
* \code
*
* // Create a base object
* class BaseObj : public Object {
* public:
* // object fields
* int field0;
*
* // object properties
* static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
* static constexpr const char* _type_key = "test.BaseObj";
* TVM_DECLARE_BASE_OBJECT_INFO(BaseObj, Object);
* };
*
* class LeafObj : public BaseObj {
* public:
* // fields
* int child_field0;
* // object properties
* static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
* static constexpr const char* _type_key = "test.LeafObj";
* TVM_DECLARE_FINAL_OBJECT_INFO(LeafObj, BaseObj);
* };
*
* // The following code should be put into a cc file.
* TVM_REGISTER_OBJECT_TYPE(BaseObj);
* TVM_REGISTER_OBJECT_TYPE(LeafObj);
*
* // Usage example.
* void TestObjects() {
* // create an object
* ObjectRef leaf_ref(make_object<LeafObj>());
* // cast to a specific instance
* const LeafObj* leaf_ptr = leaf_ref.as<LeafObj>();
* ICHECK(leaf_ptr != nullptr);
* // can also cast to the base class.
* ICHECK(leaf_ref.as<BaseObj>() != nullptr);
* }
*
* \endcode
*/
class TVM_DLL Object {
 public:
  /*!
   * \brief Object deleter
   * \param self pointer to the Object.
   */
  typedef void (*FDeleter)(Object* self);
  /*! \return The internal runtime type index of the object. */
  uint32_t type_index() const { return type_index_; }
  /*!
   * \return the type key of the object.
   * \note this operation is expensive, can be used for error reporting.
   */
  std::string GetTypeKey() const { return TypeIndex2Key(type_index_); }
  /*!
   * \return A hash value of the return of GetTypeKey.
   */
  size_t GetTypeKeyHash() const { return TypeIndex2KeyHash(type_index_); }
  /*!
   * \brief Check if the object is an instance of TargetType.
   * \tparam TargetType The target type to be checked.
   * \return Whether the object is an instance of TargetType or of one of its subclasses.
   */
  template <typename TargetType>
  inline bool IsInstance() const;
  /*!
   * \return Whether the cell has only one reference
   * \note We use stl style naming to be consistent with known API in shared_ptr.
   */
  inline bool unique() const;
  /*!
   * \brief Get the type key of the corresponding index from runtime.
   * \param tindex The type index.
   * \return the result.
   */
  static std::string TypeIndex2Key(uint32_t tindex);
  /*!
   * \brief Get the type key hash of the corresponding index from runtime.
   * \param tindex The type index.
   * \return the related key-hash.
   */
  static size_t TypeIndex2KeyHash(uint32_t tindex);
  /*!
   * \brief Get the type index of the corresponding key from runtime.
   * \param key The type key.
   * \return the result.
   */
  static uint32_t TypeKey2Index(const std::string& key);

#if TVM_OBJECT_ATOMIC_REF_COUNTER
  using RefCounterType = std::atomic<int32_t>;
#else
  using RefCounterType = int32_t;
#endif

  static constexpr const char* _type_key = "runtime.Object";

  static uint32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kRoot; }
  static uint32_t RuntimeTypeIndex() { return TypeIndex::kRoot; }

  // Default object type properties for sub-classes
  static constexpr bool _type_final = false;
  static constexpr uint32_t _type_child_slots = 0;
  static constexpr bool _type_child_slots_can_overflow = true;
  // member information
  static constexpr bool _type_has_method_visit_attrs = true;
  static constexpr bool _type_has_method_sequal_reduce = false;
  static constexpr bool _type_has_method_shash_reduce = false;
  // NOTE: the following field is not type index of Object
  // but was intended to be used by sub-classes as default value.
  // The type index of Object is TypeIndex::kRoot
  static constexpr uint32_t _type_index = TypeIndex::kDynamic;

  // Default constructor and copy constructor
  Object() {}
  // Override the copy and assign constructors to do nothing.
  // This is to make sure only contents, but not deleter and ref_counter
  // are copied when a child class copies itself.
  // This will enable us to use make_object<ObjectClass>(*obj_ptr)
  // to copy an existing object.
  Object(const Object& other) {  // NOLINT(*)
  }
  // The move operations are no-ops as well, so they cannot throw; marking them
  // noexcept lets defaulted moves of sub-classes be noexcept too.
  Object(Object&& other) noexcept {  // NOLINT(*)
  }
  Object& operator=(const Object& other) {  // NOLINT(*)
    return *this;
  }
  Object& operator=(Object&& other) noexcept {  // NOLINT(*)
    return *this;
  }

 protected:
  // The fields of the base object cell.
  /*! \brief Type index(tag) that indicates the type of the object. */
  uint32_t type_index_{0};
  /*! \brief The internal reference counter */
  RefCounterType ref_counter_{0};
  /*!
   * \brief deleter of this object to enable customized allocation.
   * If the deleter is nullptr, no deletion will be performed.
   * The creator of the object must always set the deleter field properly.
   */
  FDeleter deleter_ = nullptr;
  // Invariant checks: the reference counter must have both the size and the
  // alignment of a plain int32_t (the non-atomic RefCounterType), so that the
  // object layout does not depend on TVM_OBJECT_ATOMIC_REF_COUNTER.
  // (Previously the alignment of int32_t was compared against the *size* of
  // RefCounterType, which never checked the counter's alignment.)
  static_assert(sizeof(int32_t) == sizeof(RefCounterType) &&
                    alignof(int32_t) == alignof(RefCounterType),
                "RefCounter ABI check.");

  /*!
   * \brief Get the type index using type key.
   *
   *  When the function is first time called for a type,
   *  it will register the type to the type table in the runtime.
   *  If the static_tindex is TypeIndex::kDynamic, the function will
   *  allocate a runtime type index.
   *  Otherwise, we will populate the type table and return the static index.
   *
   * \param key the type key.
   * \param static_tindex The current _type_index field.
   *                      can be TypeIndex::kDynamic.
   * \param parent_tindex The index of the parent.
   * \param type_child_slots Number of slots reserved for its children.
   * \param type_child_slots_can_overflow Whether to allow child to overflow the slots.
   * \return The allocated type index.
   */
  static uint32_t GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex,
                                             uint32_t parent_tindex, uint32_t type_child_slots,
                                             bool type_child_slots_can_overflow);

  // reference counter related operations
  /*! \brief developer function, increases reference counter. */
  inline void IncRef();
  /*!
   * \brief developer function, decrease reference counter.
   * \note The deleter will be called when ref_counter_ becomes zero.
   */
  inline void DecRef();

 private:
  /*!
   * \return The usage count of the cell.
   * \note We use stl style naming to be consistent with known API in shared_ptr.
   */
  inline int use_count() const;
  /*!
   * \brief Check if this object is derived from the parent.
   * \param parent_tindex The parent type index.
   * \return The derivation results.
   */
  bool DerivedFrom(uint32_t parent_tindex) const;
  // friend classes
  template <typename>
  friend class ObjAllocatorBase;
  template <typename>
  friend class ObjectPtr;
  friend class TVMRetValue;
  friend class ObjectInternal;
};
/*!
 * \brief Get a reference type from a raw object pointer.
 *
 * Obtaining a reference type is required whenever a value must be returned
 * as a reference, or the object must be kept alive beyond the scope of the
 * current function.
 *
 * \tparam RefType The reference type to produce.
 * \tparam ObjectType The type of the pointed-to object.
 * \param ptr The object pointer.
 * \return The corresponding RefType.
 */
template <class RefType, class ObjectType>
inline RefType GetRef(const ObjectType* ptr);
/*!
 * \brief Downcast a base reference type to a more specific reference type.
 *
 * \tparam SubRef The target, more specific reference type.
 * \tparam BaseRef The current reference type.
 * \param ref The input reference.
 * \return The corresponding SubRef.
 */
template <class SubRef, class BaseRef>
inline SubRef Downcast(BaseRef ref);
/*!
 * \brief A custom smart pointer for Object.
 *
 * An ObjectPtr holds one reference to the pointed-to object: construction
 * from a non-null raw pointer and copying call IncRef, destruction and
 * reset() call DecRef, and moves transfer the reference without touching
 * the counter.
 *
 * \tparam T the content data type.
 * \sa make_object
 */
template <typename T>
class ObjectPtr {
 public:
  /*! \brief default constructor, creates a null pointer */
  ObjectPtr() {}
  /*! \brief construct a null pointer from nullptr */
  ObjectPtr(std::nullptr_t) {}  // NOLINT(*)
  /*!
   * \brief copy constructor, shares ownership with other
   * \param other The value to be copied
   */
  ObjectPtr(const ObjectPtr<T>& other)  // NOLINT(*)
      : ObjectPtr(other.data_) {}
  /*!
   * \brief converting copy constructor from a pointer to a subclass
   * \param other The value to be copied
   */
  template <typename U>
  ObjectPtr(const ObjectPtr<U>& other)  // NOLINT(*)
      : ObjectPtr(other.data_) {
    static_assert(std::is_base_of<T, U>::value,
                  "can only assign of child class ObjectPtr to parent");
  }
  /*!
   * \brief move constructor
   * \param other The value to be moved; it is left null.
   * \note noexcept so that standard containers move (instead of copy, which
   *  costs an atomic increment/decrement per element) when they reallocate.
   */
  ObjectPtr(ObjectPtr<T>&& other) noexcept  // NOLINT(*)
      : data_(other.data_) {
    other.data_ = nullptr;
  }
  /*!
   * \brief converting move constructor from a pointer to a subclass
   * \param other The value to be moved; it is left null.
   */
  template <typename Y>
  ObjectPtr(ObjectPtr<Y>&& other) noexcept  // NOLINT(*)
      : data_(other.data_) {
    static_assert(std::is_base_of<T, Y>::value,
                  "can only assign of child class ObjectPtr to parent");
    other.data_ = nullptr;
  }
  /*! \brief destructor, releases the held reference (if any) */
  ~ObjectPtr() { this->reset(); }
  /*!
   * \brief Swap the managed pointer with another ObjectPtr
   * \param other The other ObjectPtr
   */
  void swap(ObjectPtr<T>& other) noexcept {  // NOLINT(*)
    std::swap(data_, other.data_);
  }
  /*!
   * \return Get the content of the pointer
   */
  T* get() const { return static_cast<T*>(data_); }
  /*!
   * \return The pointer
   */
  T* operator->() const { return get(); }
  /*!
   * \return The reference
   */
  T& operator*() const {  // NOLINT(*)
    return *get();
  }
  /*!
   * \brief copy assignment
   * \param other The value to be assigned.
   * \return reference to self.
   */
  ObjectPtr<T>& operator=(const ObjectPtr<T>& other) {  // NOLINT(*)
    // copy-and-swap idiom: the temporary takes the new reference before the
    // old one is released (when the temporary is destroyed), which also makes
    // self-assignment safe.
    ObjectPtr(other).swap(*this);  // NOLINT(*)
    return *this;
  }
  /*!
   * \brief move assignment
   * \param other The value to be assigned.
   * \return reference to self.
   */
  ObjectPtr<T>& operator=(ObjectPtr<T>&& other) noexcept {  // NOLINT(*)
    // move-and-swap idiom: the temporary ends up holding our old reference.
    ObjectPtr(std::move(other)).swap(*this);  // NOLINT(*)
    return *this;
  }
  /*!
   * \brief nullptr check
   * \return result of comparison of internal pointer with nullptr.
   */
  explicit operator bool() const { return get() != nullptr; }
  /*! \brief reset the content of ptr to be nullptr */
  void reset() {
    if (data_ != nullptr) {
      data_->DecRef();
      data_ = nullptr;
    }
  }
  /*! \return The use count of the ptr, for debug purposes */
  int use_count() const { return data_ != nullptr ? data_->use_count() : 0; }
  /*! \return whether the reference is unique */
  bool unique() const { return data_ != nullptr && data_->use_count() == 1; }
  /*! \return Whether two ObjectPtr equal each other (point to the same object) */
  bool operator==(const ObjectPtr<T>& other) const { return data_ == other.data_; }
  /*! \return Whether two ObjectPtr do not equal each other */
  bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; }
  /*! \return Whether the pointer is nullptr */
  bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
  /*! \return Whether the pointer is not nullptr */
  bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }

 private:
  /*! \brief internal pointer field */
  Object* data_{nullptr};
  /*!
   * \brief constructor from Object; takes a new reference when data is non-null
   * \param data The data pointer
   */
  explicit ObjectPtr(Object* data) : data_(data) {
    if (data != nullptr) {
      data_->IncRef();
    }
  }
  /*!
   * \brief Move an ObjectPtr from an RValueRef argument.
   * \param ref The rvalue reference.
   * \return the moved result.
   */
  static ObjectPtr<T> MoveFromRValueRefArg(Object** ref) {
    ObjectPtr<T> ptr;
    ptr.data_ = *ref;
    *ref = nullptr;
    return ptr;
  }
  // friend classes
  friend class Object;
  friend class ObjectRef;
  friend struct ObjectPtrHash;
  template <typename>
  friend class ObjectPtr;
  template <typename>
  friend class ObjAllocatorBase;
  friend class TVMPODValue_;
  friend class TVMArgsSetter;
  friend class TVMRetValue;
  friend class TVMArgValue;
  friend class TVMMovableArgValue_;
  template <typename RelayRefType, typename ObjType>
  friend RelayRefType GetRef(const ObjType* ptr);
  template <typename BaseType, typename ObjType>
  friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
};
// Forward declaration of Optional<T>, used by ObjectRef::as<ObjectRefType>().
// It is declared here rather than included to prevent circular includes.
template <class T>
class Optional;
/*! \brief Base class of all object reference */
class ObjectRef {
 public:
  /*! \brief default constructor */
  ObjectRef() = default;
  /*!
   * \brief Constructor from existing object ptr
   * \param data The pointer to take; it is moved in, so no extra reference
   *  count increment/decrement is paid for the by-value parameter.
   */
  explicit ObjectRef(ObjectPtr<Object> data) : data_(std::move(data)) {}
  /*!
   * \brief Comparator
   * \param other Another object ref.
   * \return the compare result.
   */
  bool same_as(const ObjectRef& other) const { return data_ == other.data_; }
  /*!
   * \brief Comparator
   * \param other Another object ref.
   * \return the compare result.
   */
  bool operator==(const ObjectRef& other) const { return data_ == other.data_; }
  /*!
   * \brief Comparator
   * \param other Another object ref.
   * \return the compare result.
   */
  bool operator!=(const ObjectRef& other) const { return data_ != other.data_; }
  /*!
   * \brief Comparator
   * \param other Another object ref by address.
   * \return the compare result.
   */
  bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); }
  /*!
   * \return whether the object is defined(not null).
   */
  bool defined() const { return data_ != nullptr; }
  /*! \return the internal object pointer */
  const Object* get() const { return data_.get(); }
  /*! \return the internal object pointer */
  const Object* operator->() const { return get(); }
  /*! \return whether the reference is unique */
  bool unique() const { return data_.unique(); }
  /*! \return The use count of the ptr, for debug purposes */
  int use_count() const { return data_.use_count(); }

  /*!
   * \brief Try to downcast the internal Object to a
   *  raw pointer of a corresponding type.
   *
   *  The function will return a nullptr if the cast failed.
   *
   *      if (const AddNode *ptr = node_ref.as<AddNode>()) {
   *        // This is an add node
   *      }
   *
   * \tparam ObjectType the target type, must be a subtype of Object
   */
  template <typename ObjectType, typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>>
  inline const ObjectType* as() const;

  /*!
   * \brief Try to downcast the ObjectRef to a
   *    Optional<T> of the requested type.
   *
   *  The function will return a NullOpt if the cast failed.
   *
   *      if (Optional<Add> opt = node_ref.as<Add>()) {
   *        // This is an add node
   *      }
   *
   * \note While this method is declared in <tvm/runtime/object.h>,
   * the implementation is in <tvm/runtime/container/optional.h> to
   * prevent circular includes.  This additional include file is only
   * required in compilation units that uses this method.
   *
   * \tparam ObjectRefType the target type, must be a subtype of ObjectRef
   */
  template <typename ObjectRefType,
            typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
  inline Optional<ObjectRefType> as() const;

  /*! \brief type indicate the container type. */
  using ContainerType = Object;
  // Default type properties for the reference class.
  static constexpr bool _type_is_nullable = true;

 protected:
  /*! \brief Internal pointer that backs the reference. */
  ObjectPtr<Object> data_;
  /*! \return return a mutable internal ptr, can be used by sub-classes. */
  Object* get_mutable() const { return data_.get(); }
  /*!
   * \brief Internal helper function downcast a ref without check.
   * \note Only used for internal dev purposes.
   * \tparam T The target reference type.
   * \return The casted result.
   */
  template <typename T>
  static T DowncastNoCheck(ObjectRef ref) {
    return T(std::move(ref.data_));
  }
  /*!
   * \brief Clear the object ref data field without DecRef
   *        after we successfully moved the field.
   * \param ref The reference data.
   */
  static void FFIClearAfterMove(ObjectRef* ref) { ref->data_.data_ = nullptr; }
  /*!
   * \brief Internal helper function get data_ as ObjectPtr of ObjectType.
   * \note only used for internal dev purpose.
   * \tparam ObjectType The corresponding object type.
   * \return the corresponding type.
   */
  template <typename ObjectType>
  static ObjectPtr<ObjectType> GetDataPtr(const ObjectRef& ref) {
    return ObjectPtr<ObjectType>(ref.data_.data_);
  }
  // friend classes.
  friend struct ObjectPtrHash;
  friend class TVMRetValue;
  friend class TVMArgsSetter;
  friend class ObjectInternal;
  template <typename SubRef, typename BaseRef>
  friend SubRef Downcast(BaseRef ref);
};
/*!
 * \brief Get an ObjectPtr to a base type from a raw object pointer.
 *
 * \tparam BaseType The pointee type of the returned ObjectPtr; must be a base
 *  of (or the same as) ObjType.
 * \tparam ObjType The type of the pointed-to object.
 * \param ptr The object pointer.
 * \return An ObjectPtr<BaseType> that shares ownership of *ptr.
 */
template <class BaseType, class ObjType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
/*! \brief ObjectRef hash functor */
struct ObjectPtrHash {
size_t operator()(const ObjectRef& a) const { return operator()(a.data_); }
template <typename T>
size_t operator()(const ObjectPtr<T>& a) const {
return std::hash<Object*>()(a.get());
}
};
/*! \brief ObjectRef equal functor */
struct ObjectPtrEqual {
bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); }
template <typename T>
size_t operator()(const ObjectPtr<T>& a, const ObjectPtr<T>& b) const {
return a == b;
}
};
/*!
 * \brief helper macro to declare a base object type that can be inherited.
 *
 * Expands, inside the class body, to:
 *  - a static_assert that ParentType is not marked final;
 *  - RuntimeTypeIndex(): returns TypeName::_type_index when it is a static
 *    index (not TypeIndex::kDynamic), and otherwise the runtime-allocated index
 *    from _GetOrAllocRuntimeTypeIndex(). It also static_asserts that, when both
 *    TypeName and ParentType reserve child slots, TypeName reserves fewer slots
 *    than ParentType;
 *  - _GetOrAllocRuntimeTypeIndex(): on its first call, registers the type through
 *    Object::GetOrAllocRuntimeTypeIndex and caches the result in a function-local
 *    static. The parent's index is computed as an argument first, so ancestors
 *    are registered before their children.
 *
 * \param TypeName The name of the current type.
 * \param ParentType The name of the ParentType
 */
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj marked as final"); \
static uint32_t RuntimeTypeIndex() { \
static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \
} \
return _GetOrAllocRuntimeTypeIndex(); \
} \
static uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \
return tindex; \
}
/*!
 * \brief helper macro to declare type information in a final class.
 *
 * Sets _type_final to true and _type_child_slots to 0 (no child slots are
 * reserved), then expands TVM_DECLARE_BASE_OBJECT_INFO. Because the type is
 * final, IsInstance only checks for an exact type-index match.
 *
 * \param TypeName The name of the current type.
 * \param ParentType The name of the ParentType
 */
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
static const constexpr bool _type_final = true; \
static const constexpr int _type_child_slots = 0; \
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
/*! \brief helper macro to suppress unused warning */
#ifdef __GNUC__
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
/*!
 * \brief Token-pasting helpers. The two-level form makes the arguments
 *  (e.g. __COUNTER__ or another macro) expand before they are concatenated.
 */
#define TVM_STR_CONCAT_(x, y) x##y
#define TVM_STR_CONCAT(x, y) TVM_STR_CONCAT_(x, y)
/*! \brief Declaration prefix of the variables defined by TVM_REGISTER_OBJECT_TYPE. */
#define TVM_OBJECT_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid
/*!
 * \brief Helper macro to register the object type to runtime.
 *  Makes sure that the runtime type table is correctly populated.
 *
 *  Expands to the definition of a static uint32_t variable (uniquely named via
 *  __COUNTER__) initialized with TypeName::_GetOrAllocRuntimeTypeIndex(). When
 *  used at namespace scope, registration therefore happens during static
 *  initialization of that translation unit.
 *
 *  Use this macro in the cc file for each object class, e.g.
 *  TVM_REGISTER_OBJECT_TYPE(MyObj);
 */
#define TVM_REGISTER_OBJECT_TYPE(TypeName) \
TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex()
/*!
 * \brief Define the default copy/move constructor and assign operator
 *
 *  Explicitly defaults all four copy/move special members of TypeName.
 *
 * \param TypeName The class typename.
 */
#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
TypeName(const TypeName& other) = default; \
TypeName(TypeName&& other) = default; \
TypeName& operator=(const TypeName& other) = default; \
TypeName& operator=(TypeName&& other) = default;
/*!
 * \brief Define object reference methods.
 *
 *  Expands to a defaulted default constructor, an explicit constructor from
 *  ObjectPtr<Object>, defaulted copy/move operations, const accessors
 *  operator->() / get() returning const ObjectName*, and the ContainerType alias.
 *  The by-value ObjectPtr argument is moved (not copied) into ParentType, which
 *  avoids a redundant reference-count increment/decrement per construction.
 *
 * \param TypeName The object type name
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 */
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
  TypeName() = default; \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
      : ParentType(::std::move(n)) {} \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
  const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
  const ObjectName* get() const { return operator->(); } \
  using ContainerType = ObjectName;
/*!
 * \brief Define object reference methods that is not nullable.
 *
 *  Like TVM_DEFINE_OBJECT_REF_METHODS but without a default constructor and
 *  with _type_is_nullable set to false. The by-value ObjectPtr argument is
 *  moved into ParentType, which avoids a redundant reference-count round-trip.
 *
 * \param TypeName The object type name
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 */
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
      : ParentType(::std::move(n)) {} \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
  const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
  const ObjectName* get() const { return operator->(); } \
  static constexpr bool _type_is_nullable = false; \
  using ContainerType = ObjectName;
/*!
 * \brief Define object reference methods of whose content is mutable.
 *
 *  operator->() returns a non-const ObjectName*. The by-value ObjectPtr
 *  argument is moved into ParentType, which avoids a redundant
 *  reference-count round-trip.
 *
 * \param TypeName The object type name
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 * \note We recommend making objects immutable when possible.
 *       This macro is only reserved for objects that stores runtime states.
 */
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
  TypeName() = default; \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
      : ParentType(::std::move(n)) {} \
  ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
  using ContainerType = ObjectName;
/*!
 * \brief Define object reference methods that is both not nullable and mutable.
 *
 *  operator->() / get() return a non-const ObjectName*, there is no default
 *  constructor, and _type_is_nullable is false. The by-value ObjectPtr argument
 *  is moved into ParentType, which avoids a redundant reference-count round-trip.
 *
 * \param TypeName The object type name
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 */
#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
      : ParentType(::std::move(n)) {} \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
  ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
  ObjectName* get() const { return operator->(); } \
  static constexpr bool _type_is_nullable = false; \
  using ContainerType = ObjectName;
/*!
 * \brief Define CopyOnWrite function in an ObjectRef.
 * \param ObjectName The Type of the Node.
 *
 *  CopyOnWrite will generate a unique copy of the internal node.
 *  The node will be copied if it is referenced by multiple places.
 *  The function returns the raw pointer to the node to allow modification
 *  of the content.
 *
 *  Details of the expansion: it ICHECKs that the reference is not null. When
 *  data_ is not unique, a new node is built with
 *  make_object<ObjectName>(*operator->()) and swapped into data_, so other
 *  references keep observing the old node.
 *  NOTE(review): make_object is not declared in this header; the expanding
 *  translation unit must have it in scope (presumably tvm/runtime/memory.h — confirm).
 *
 * \code
 *
 *  MyCOWObjectRef ref, ref2;
 *  ref2 = ref;
 *  ref.CopyOnWrite()->value = new_value;
 *  assert(ref2->value == old_value);
 *  assert(ref->value == new_value);
 *
 * \endcode
 */
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \
ObjectName* CopyOnWrite() { \
ICHECK(data_ != nullptr); \
if (!data_.unique()) { \
auto n = make_object<ObjectName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<ObjectName*>(data_.get()); \
}
// Implementation details below.
// Object reference counting.
#if TVM_OBJECT_ATOMIC_REF_COUNTER
// Atomic variant (ref_counter_ is std::atomic<int32_t>).
// Increments use relaxed ordering: a new reference is always created from an
// existing one, so no other memory accesses need to be ordered by the increment.
inline void Object::IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); }
// Decrements use release ordering to publish this owner's prior writes. The owner
// that sees the count drop from 1 to 0 issues an acquire fence so that those writes
// are visible before the deleter runs. A null deleter_ means nothing is freed here.
inline void Object::DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter_ != nullptr) {
(*this->deleter_)(this);
}
}
}
// Relaxed load: the result may be stale while other threads change the count;
// it is meant for uniqueness checks and debugging.
inline int Object::use_count() const { return ref_counter_.load(std::memory_order_relaxed); }
#else
// Non-atomic variant: only valid when an object is never owned by multiple threads.
inline void Object::IncRef() { ++ref_counter_; }
// Calls the deleter (if set) when the last reference is released.
inline void Object::DecRef() {
if (--ref_counter_ == 0) {
if (this->deleter_ != nullptr) {
(*this->deleter_)(this);
}
}
}
inline int Object::use_count() const { return ref_counter_; }
#endif  // TVM_OBJECT_ATOMIC_REF_COUNTER
template <typename TargetType>
inline bool Object::IsInstance() const {
  const Object* self = this;
  // Guard clause: a null object is never an instance of anything.
  if (self == nullptr) return false;
  // Everything is a subclass of Object; this folds away for known TargetType.
  if (std::is_same<TargetType, Object>::value) return true;

  const uint32_t target_index = TargetType::RuntimeTypeIndex();
  // A final type has no registered subclasses: only an exact match counts.
  if (TargetType::_type_final) return self->type_index_ == target_index;

  // Fast path: is the index inside the slots reserved for TargetType's children?
  // (The branch on _type_child_slots is resolved by constant folding.)
  bool in_reserved_range;
  if (TargetType::_type_child_slots != 0) {
    const uint32_t range_end = target_index + TargetType::_type_child_slots;
    in_reserved_range = self->type_index_ >= target_index && self->type_index_ < range_end;
  } else {
    in_reserved_range = self->type_index_ == target_index;
  }
  if (in_reserved_range) return true;

  // Without overflow, every subclass would have been inside the reserved range.
  if (!TargetType::_type_child_slots_can_overflow) return false;
  // Invariant: a parent's index is always smaller than its children's.
  if (self->type_index_ < target_index) return false;
  // Rare slow path: walk the runtime type hierarchy.
  return self->DerivedFrom(target_index);
}
inline bool Object::unique() const {
  // Exactly one outstanding reference: nobody else shares this object.
  return 1 == use_count();
}
template <typename ObjectType, typename>
inline const ObjectType* ObjectRef::as() const {
  // Null references and objects of an unrelated type both yield nullptr.
  if (data_ == nullptr || !data_->IsInstance<ObjectType>()) {
    return nullptr;
  }
  return static_cast<const ObjectType*>(data_.get());
}
template <typename RefType, typename ObjType>
inline RefType GetRef(const ObjType* ptr) {
static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
"Can only cast to the ref of same container type");
if (!RefType::_type_is_nullable) {
ICHECK(ptr != nullptr);
}
return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}
template <typename BaseType, typename ObjType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
  static_assert(std::is_base_of<BaseType, ObjType>::value,
                "Can only cast to the ref of same container type");
  // Upcast to Object first; the private ObjectPtr(Object*) constructor then
  // takes a new reference when the pointer is non-null.
  Object* const base = static_cast<Object*>(ptr);
  return ObjectPtr<BaseType>(base);
}
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
  if (!ref.defined()) {
    // A null reference may only become a nullable reference type.
    ICHECK(SubRef::_type_is_nullable) << "Downcast from nullptr to not nullable reference of "
                                      << SubRef::ContainerType::_type_key;
  } else {
    // A non-null reference must point at an instance of the target container type.
    ICHECK(ref->template IsInstance<typename SubRef::ContainerType>())
        << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key
        << " failed.";
  }
  // Transfer ownership out of the by-value argument.
  return SubRef(std::move(ref.data_));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OBJECT_H_