blob: 45366f3ad55a82b018294ef6203926005dd03b73 [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file tvm/packed_func_ext.h
* \brief Extension package to PackedFunc
* This enales pass NodeRef types into/from PackedFunc.
*/
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_
#include <sstream>
#include <string>
#include <memory>
#include <limits>
#include <type_traits>
#include "base.h"
#include "expr.h"
#include "tensor.h"
#include "runtime/packed_func.h"
namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct NodeTypeChecker {
static inline bool Check(Node* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType;
// always allow nullptr.
if (sptr == nullptr) return true;
return sptr->derived_from<ContainerType>();
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return true;
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
if (!NodeTypeChecker<T>::Check(p.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "array<";
NodeTypeChecker<T>::PrintName(os);
os << ">";
}
};
template<typename V>
struct NodeTypeChecker<Map<std::string, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return true;
if (!sptr->is_type<StrMapNode>()) return false;
StrMapNode* n = static_cast<StrMapNode*>(sptr);
for (const auto& kv : n->data) {
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<string";
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return true;
if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
NodeTypeChecker<K>::PrintName(os);
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
NodeTypeChecker<T>::PrintName(os);
return os.str();
}
// extensions for tvm arg value
template<typename TNodeRef>
inline TNodeRef TVMArgValue::AsNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TNodeRef>()
<< " but get " << sptr->type_key();
return TNodeRef(sptr);
}
inline TVMArgValue::operator HalideIR::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Expr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kDLFloat) {
return Expr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
}
if (sptr->is_type<TensorNode>()) {
return Tensor(sptr)();
}
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<Expr>()
<< " but get " << sptr->type_key();
return Expr(sptr);
}
inline TVMArgValue::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<Integer>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<Expr>()
<< " but get " << sptr->type_key();
return Integer(sptr);
}
inline NodePtr<Node>& TVMArgValue::node_sptr() {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return *ptr<NodePtr<Node> >();
}
template<typename TNodeRef, typename>
inline bool TVMArgValue::IsNodeType() const {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
NodePtr<Node>& sptr =
*ptr<NodePtr<Node> >();
return NodeTypeChecker<TNodeRef>::Check(sptr.get());
}
// extensions for TVMRetValue
inline TVMRetValue& TVMRetValue::operator=(
const NodePtr<Node>& other) {
if (other.get() == nullptr) {
SwitchToPOD(kNull);
} else {
SwitchToClass<NodePtr<Node> >(kNodeHandle, other);
}
return *this;
}
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
if (!other.defined()) {
SwitchToPOD(kNull);
} else {
SwitchToClass<NodePtr<Node> >(kNodeHandle, other.node_);
}
return *this;
}
template<typename TNodeRef>
inline TNodeRef TVMRetValue::AsNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TNodeRef>()
<< " but get " << sptr->type_key();
return TNodeRef(sptr);
}
inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*)
if (other.defined()) {
values_[i].v_handle = const_cast<NodePtr<Node>*>(&(other.node_));
type_codes_[i] = kNodeHandle;
} else {
type_codes_[i] = kNull;
}
}
// type related stuffs
inline TVMRetValue& TVMRetValue::operator=(const HalideIR::Type& t) {
return this->operator=(Type2TVMType(t));
}
inline TVMRetValue::operator HalideIR::Type() const {
return TVMType2Type(operator TVMType());
}
inline TVMArgValue::operator HalideIR::Type() const {
return TVMType2Type(operator TVMType());
}
inline void TVMArgsSetter::operator()(
size_t i, const HalideIR::Type& t) const {
this->operator()(i, Type2TVMType(t));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_