// blob: 1daf1e7925534e0bf474437497a3d9aead98f71c [file] [log] [blame]
/*!
* Copyright (c) 2018 by Contributors
* \file attrs.cc
*/
#include <tvm/attrs.h>
#include <tvm/api_registry.h>
#include "attr_functor.h"
namespace tvm {
// Presents this node's only field, the attribute dictionary, to the visitor
// under the reserved name "__dict__".
void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
  const char* field_name = "__dict__";
  v->Visit(field_name, &dict);
}
// Same traversal as VisitAttrs: the dictionary is always visited, under the
// name "__dict__", with no filtering applied here.
void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) {
  const char* field_name = "__dict__";
  v->Visit(field_name, &dict);
}
// Fills the dictionary from a flat argument list laid out as
// (key0, value0, key1, value1, ...).
//
// Each value is stored according to its type code:
//   - kNodeHandle: stored as a NodeRef.
//   - kStr:        wrapped into an Expr built from the string.
//   - otherwise:   converted to an Expr.
//
// `allow_unknown` is not consulted by this implementation.
// Note: args[pos + 1] is read for every key; this function does not itself
// verify that the argument count is even.
void DictAttrsNode::InitByPackedArgs(
    const runtime::TVMArgs& args, bool allow_unknown) {
  for (int pos = 0; pos < args.size(); pos += 2) {
    std::string name = args[pos];
    runtime::TVMArgValue arg = args[pos + 1];
    const int code = arg.type_code();
    NodeRef entry;
    if (code == kNodeHandle) {
      entry = arg.operator NodeRef();
    } else if (code == kStr) {
      entry = Expr(arg.operator std::string());
    } else {
      entry = arg.operator Expr();
    }
    dict.Set(name, entry);
  }
}
// Returns an empty array; no field info entries are reported for this node.
Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
  return Array<AttrFieldInfo>();
}
// Builds a new DictAttrsNode that takes ownership of `dict` (moved in) and
// returns it wrapped as an Attrs reference.
Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) {
  auto node = make_node<DictAttrsNode>();
  node->dict = std::move(dict);
  return Attrs(node);
}
// Printer dispatch for DictAttrsNode: writes the node's dictionary to the
// printer's output stream.
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const DictAttrsNode* node, IRPrinter* printer) {
    printer->stream << node->dict;
  });
// Node type registration for the attribute node classes.
// NOTE(review): the exact effect is defined by TVM_REGISTER_NODE_TYPE, which
// is not visible in this file.
TVM_REGISTER_NODE_TYPE(DictAttrsNode);
TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
// The handlers below name IR node types (IntImm, Add, Call, ...) unqualified;
// presumably they are resolved through this using-directive.
using namespace ir;
// Equal handler.
// Entry point for comparing two node references.
//   - References for which same_as() holds are equal.
//   - Otherwise, if either side is undefined, they are not equal.
//   - Otherwise the comparison is delegated to VisitAttr(lhs, rhs).
bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) {
  if (lhs.same_as(rhs)) {
    return true;
  }
  const bool both_defined = lhs.defined() && rhs.defined();
  return both_defined && this->VisitAttr(lhs, rhs);
}
// Fallback comparison for node types without a dedicated VisitAttr_ overload.
// Nodes derived from BaseAttrsNode are compared through ContentEqual, using an
// AttrsEqual whose handler is this object; any other node is equal only to
// itself (pointer identity).
bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) {
  if (!lhs->derived_from<BaseAttrsNode>()) {
    return lhs == other.get();
  }
  AttrsEqual equal;
  equal.handler_ = this;
  const auto* attrs = static_cast<const BaseAttrsNode*>(lhs);
  return attrs->ContentEqual(other.get(), equal);
}
// IntImm: equal when `other` is also an IntImm holding the same value.
bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) {
  const auto* rhs = other.as<IntImm>();
  return rhs != nullptr && lhs->value == rhs->value;
}
// UIntImm: equal when `other` is also a UIntImm holding the same value.
bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) {
  const auto* rhs = other.as<UIntImm>();
  return rhs != nullptr && lhs->value == rhs->value;
}
// FloatImm: equal when `other` is also a FloatImm whose value compares equal
// with operator== (exact comparison, no tolerance).
bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) {
  const auto* rhs = other.as<FloatImm>();
  return rhs != nullptr && lhs->value == rhs->value;
}
// StringImm: equal when `other` is also a StringImm holding the same value.
bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) {
  const auto* rhs = other.as<StringImm>();
  return rhs != nullptr && lhs->value == rhs->value;
}
// ArrayNode: equal when `other` is also an ArrayNode of the same length and
// every pair of elements at the same index compares equal via Equal().
//
// A type mismatch is reported as "not equal", in line with the other
// VisitAttr_ overloads in this file; previously a non-array `other` fell
// through to `return true`.
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) {
  const auto* rhs = other.as<ArrayNode>();
  if (rhs == nullptr) return false;
  if (rhs->data.size() != lhs->data.size()) return false;
  for (size_t i = 0; i < lhs->data.size(); ++i) {
    if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false;
  }
  return true;
}
// StrMapNode: equal when `other` is also a StrMapNode with the same number of
// entries, and every key of `lhs` is present in `other` with a value that
// compares equal via Equal().
//
// A type mismatch is reported as "not equal", in line with the other
// VisitAttr_ overloads in this file; previously a non-map `other` fell
// through to `return true`.
bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) {
  const auto* rhs = other.as<StrMapNode>();
  if (rhs == nullptr) return false;
  if (rhs->data.size() != lhs->data.size()) return false;
  for (const auto& kv : lhs->data) {
    auto it = rhs->data.find(kv.first);
    if (it == rhs->data.end()) return false;
    if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false;
  }
  return true;
}
#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \
bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \
if (const auto* rhs = other.as<NodeName>()) { \
if (!Equal(lhs->a, rhs->a)) return false; \
if (!Equal(lhs->b, rhs->b)) return false; \
return true; \
} else { \
return false; \
} \
} \
TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(And);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Or);
// Not: equal when `other` is also a Not and the operands compare equal.
bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Not>();
  if (rhs == nullptr) {
    return false;
  }
  return Equal(lhs->a, rhs->a);
}
// Cast: equal when `other` is also a Cast with the same target type and the
// cast operands compare equal. The type is checked first, so operands are not
// visited when the types differ.
bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Cast>();
  if (rhs == nullptr) {
    return false;
  }
  if (lhs->type != rhs->type) {
    return false;
  }
  return Equal(lhs->value, rhs->value);
}
// Call: equal when `other` is also a Call with the same name, type and
// call_type, and the argument lists compare equal. Fields are checked in that
// order and evaluation stops at the first mismatch.
bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Call>();
  if (rhs == nullptr) {
    return false;
  }
  const bool same_signature =
      lhs->name == rhs->name &&
      lhs->type == rhs->type &&
      lhs->call_type == rhs->call_type;
  return same_signature && Equal(lhs->args, rhs->args);
}
// Select: equal when `other` is also a Select and the condition, true branch
// and false branch each compare equal (checked in that order, stopping at the
// first mismatch).
bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Select>();
  if (rhs == nullptr) {
    return false;
  }
  if (!Equal(lhs->condition, rhs->condition)) {
    return false;
  }
  if (!Equal(lhs->true_value, rhs->true_value)) {
    return false;
  }
  return Equal(lhs->false_value, rhs->false_value);
}
// Hash Handler.
// Fallback hash for node types without a dedicated VisitAttr_ overload.
// Nodes derived from BaseAttrsNode are hashed through ContentHash, using an
// AttrsHash whose handler is this object; any other node is hashed with
// NodeHash on a reference to it.
size_t AttrsHashHandler::VisitAttrDefault_(const Node* value) {
  if (!value->derived_from<BaseAttrsNode>()) {
    return NodeHash()(GetRef<NodeRef>(value));
  }
  AttrsHash hasher;
  hasher.handler_ = this;
  const auto* attrs = static_cast<const BaseAttrsNode*>(value);
  return attrs->ContentHash(hasher);
}
// IntImm: hash of the stored value, computed with std::hash<int64_t>.
size_t AttrsHashHandler::VisitAttr_(const IntImm* op) {
  std::hash<int64_t> hash_fn;
  return hash_fn(op->value);
}
// UIntImm: hash of the stored value, computed with std::hash<uint64_t>.
size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) {
  std::hash<uint64_t> hash_fn;
  return hash_fn(op->value);
}
// FloatImm: hash of the stored value, computed with std::hash<double>.
size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) {
  std::hash<double> hash_fn;
  return hash_fn(op->value);
}
// StringImm: hash of the stored string, computed with std::hash<std::string>.
size_t AttrsHashHandler::VisitAttr_(const StringImm* op) {
  std::hash<std::string> hash_fn;
  return hash_fn(op->value);
}
// ArrayNode: the hash is seeded with the element count and then combined, in
// index order, with the hash of every element.
size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
  size_t acc = op->data.size();
  for (const auto& element : op->data) {
    acc = Combine(acc, this->Hash(NodeRef(element)));
  }
  return acc;
}
size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
using Entry = std::pair<std::string, NodePtr<Node> >;
std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
return a.first < b.first;
});
size_t result = 0;
for (const Entry& kv : data) {
result = Combine(result, std::hash<std::string>()(kv.first));
result = Combine(result, this->Hash(NodeRef(kv.second)));
}
return result;
}
#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName) \
size_t AttrsHashHandler::VisitAttr_(const NodeName* op) { \
static size_t key = std::hash<std::string>()(NodeName::_type_key); \
return Combine(key, Combine(Hash(op->a), Hash(op->b))); \
} \
TVM_DEFINE_ATTRS_BINOP_HASH(Add);
TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
TVM_DEFINE_ATTRS_BINOP_HASH(Div);
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
TVM_DEFINE_ATTRS_BINOP_HASH(GE);
TVM_DEFINE_ATTRS_BINOP_HASH(GT);
TVM_DEFINE_ATTRS_BINOP_HASH(LE);
TVM_DEFINE_ATTRS_BINOP_HASH(LT);
TVM_DEFINE_ATTRS_BINOP_HASH(EQ);
TVM_DEFINE_ATTRS_BINOP_HASH(NE);
TVM_DEFINE_ATTRS_BINOP_HASH(And);
TVM_DEFINE_ATTRS_BINOP_HASH(Or);
// Not: combines a per-type key (string hash of Not::_type_key, computed once)
// with the hash of the operand.
size_t AttrsHashHandler::VisitAttr_(const Not* op) {
  static size_t key = std::hash<std::string>()(Not::_type_key);
  const size_t operand_hash = Hash(op->a);
  return Combine(key, operand_hash);
}
// Cast: starting from a per-type key (string hash of Cast::_type_key, computed
// once), combines in the hash of the target type and then the hash of the
// operand. The type is hashed with a default-constructed AttrsHash.
size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
  static size_t key = std::hash<std::string>()(Cast::_type_key);
  AttrsHash hasher;
  size_t acc = Combine(key, hasher(op->type));
  acc = Combine(acc, Hash(op->value));
  return acc;
}
// Call: starting from a per-type key (string hash of Call::_type_key, computed
// once), combines in the hashes of the name, the type and the argument list,
// in that order. Name and type are hashed with a default-constructed
// AttrsHash; call_type does not contribute to the hash.
size_t AttrsHashHandler::VisitAttr_(const Call* op) {
  static size_t key = std::hash<std::string>()(Call::_type_key);
  AttrsHash hasher;
  size_t acc = Combine(key, hasher(op->name));
  acc = Combine(acc, hasher(op->type));
  acc = Combine(acc, Hash(op->args));
  return acc;
}
// Select: starting from a per-type key (string hash of Select::_type_key,
// computed once), combines in the hashes of the condition, the true branch and
// the false branch, in that order.
size_t AttrsHashHandler::VisitAttr_(const Select* op) {
  static size_t key = std::hash<std::string>()(Select::_type_key);
  size_t acc = Combine(key, Hash(op->condition));
  acc = Combine(acc, Hash(op->true_value));
  acc = Combine(acc, Hash(op->false_value));
  return acc;
}
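// Functor entry points: use the attached handler if there is one, otherwise
// fall back to a freshly constructed default handler.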
// Compares two node references. References for which same_as() holds are
// equal; otherwise the comparison goes to the attached handler, or to a
// temporary AttrsEqualHandler when no handler is attached.
bool AttrsEqual::operator()(const NodeRef& lhs, const NodeRef& rhs) const {
  if (lhs.same_as(rhs)) {
    return true;
  }
  if (handler_ != nullptr) {
    return handler_->Equal(lhs, rhs);
  }
  return AttrsEqualHandler().Equal(lhs, rhs);
}
// Hashes a node reference. An undefined reference hashes to 0; otherwise the
// hash comes from the attached handler, or from a temporary AttrsHashHandler
// when no handler is attached.
size_t AttrsHash::operator()(const NodeRef& node) const {
  if (!node.defined()) {
    return 0;
  }
  if (handler_ != nullptr) {
    return handler_->Hash(node);
  }
  return AttrsHashHandler().Hash(node);
}
// The content hash of a DictAttrsNode is the hash of its dictionary, as
// computed by the supplied hasher.
size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
  const size_t dict_hash = hasher(this->dict);
  return dict_hash;
}
bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const {
if (this == other) return true;
if (other == nullptr) return false;
if (this->type_index() != other->type_index()) return false;
return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
}
// Registers "_AttrsListFieldInfo": converts the first argument to Attrs and
// returns the result of calling ListFieldInfo() on it.
TVM_REGISTER_API("_AttrsListFieldInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Attrs attrs = args[0].operator Attrs();
    *ret = attrs->ListFieldInfo();
  });
} // namespace tvm