blob: 60b5f40b98f16365e94b6714dd9d0eb38c687a31 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Expose container API to frontend.
* \file src/node/container.cc
*/
#include <tvm/node/container.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include "../support/str_escape.h"
namespace tvm {
// SEqualReduce/SHashReduce traits for runtime containers.
/*!
 * \brief Structural hash/equality trait for runtime::StringObj.
 *
 * Strings are hashed and compared purely by their byte content.
 */
struct StringObjTrait {
  static constexpr const std::nullptr_t VisitAttrs = nullptr;

  static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) {
    // Feed a hash of the raw bytes directly into the reducer.
    const auto content_hash = runtime::String::HashBytes(key->data, key->size);
    hash_reduce->SHashReduceHashedValue(content_hash);
  }

  static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs,
                           SEqualReducer equal) {
    if (lhs == rhs) return true;
    if (lhs->size != rhs->size) return false;
    // Two strings backed by the same buffer trivially match; otherwise compare bytes.
    const bool shared_buffer = lhs->data == rhs->data;
    return shared_buffer || std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
  }
};
/*!
 * \brief Helper that returns the ObjectPtr held by an arbitrary ObjectRef.
 *
 * It derives from ObjectRef so that it can call GetDataPtr, presumably because that
 * helper is not publicly accessible from outside ObjectRef subclasses.
 */
struct RefToObjectPtr : public ObjectRef {
  static ObjectPtr<Object> Get(const ObjectRef& ref) {
    ObjectPtr<Object> data = GetDataPtr<Object>(ref);
    return data;
  }
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
    // Construct a StringObj whose content is exactly `bytes`.
    .set_creator([](const std::string& bytes) {
      runtime::String restored(bytes);
      return RefToObjectPtr::Get(restored);
    })
    // Report the string's content as its repr bytes.
    .set_repr_bytes([](const Object* n) -> std::string {
      const auto* str_obj = static_cast<const runtime::StringObj*>(n);
      runtime::String str = GetRef<runtime::String>(str_obj);
      return str.operator std::string();
    });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<runtime::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
      // Print the string wrapped in double quotes, with its content passed through
      // support::StrEscape.
      const auto* str_obj = static_cast<const runtime::StringObj*>(node.get());
      auto escaped = support::StrEscape(str_obj->data, str_obj->size);
      p->stream << '"' << escaped << '"';
    });
/*!
 * \brief Structural hash/equality trait for runtime::ADTObj.
 *
 * Two ADTs are equal when they share the same tag, the same number of fields, and
 * pairwise structurally equal fields. The hash folds in the tag, the field count,
 * and then each field in order.
 */
struct ADTObjTrait {
  static constexpr const std::nullptr_t VisitAttrs = nullptr;

  static void SHashReduce(const runtime::ADTObj* key, SHashReducer hash_reduce) {
    const runtime::ADTObj& adt = *key;
    hash_reduce(adt.tag);
    hash_reduce(static_cast<uint64_t>(adt.size));
    for (uint32_t field = 0; field != adt.size; ++field) {
      hash_reduce(adt[field]);
    }
  }

  static bool SEqualReduce(const runtime::ADTObj* lhs, const runtime::ADTObj* rhs,
                           SEqualReducer equal) {
    if (lhs == rhs) return true;
    if (lhs->tag != rhs->tag || lhs->size != rhs->size) return false;
    for (uint32_t field = 0; field != lhs->size; ++field) {
      if (!equal((*lhs)[field], (*rhs)[field])) return false;
    }
    return true;
  }
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
/*!
 * \brief Structural hash/equality trait for runtime::NDArray::Container.
 *
 * Only contiguous CPU tensors are supported: both hashing and comparison read the
 * tensor's data pointer directly on the host. The dtype, rank, shape and raw data
 * bytes all participate.
 */
struct NDArrayContainerTrait {
  static constexpr const std::nullptr_t VisitAttrs = nullptr;

  static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) {
    const DLTensor& tensor = key->dl_tensor;
    // The data buffer is hashed byte-for-byte below, so it must be CPU-resident and dense.
    CHECK_EQ(tensor.ctx.device_type, kDLCPU) << "can only hash CPU tensor";
    CHECK(runtime::IsContiguous(tensor)) << "Can only hash contiguous tensor";
    hash_reduce(runtime::DataType(tensor.dtype));
    hash_reduce(tensor.ndim);
    for (int i = 0; i < tensor.ndim; ++i) {
      hash_reduce(tensor.shape[i]);
    }
    hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(
        static_cast<const char*>(tensor.data), runtime::GetDataSize(tensor)));
  }

  static bool SEqualReduce(const runtime::NDArray::Container* lhs,
                           const runtime::NDArray::Container* rhs, SEqualReducer equal) {
    if (lhs == rhs) return true;
    const DLTensor& lt = lhs->dl_tensor;
    const DLTensor& rt = rhs->dl_tensor;
    CHECK_EQ(lt.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
    CHECK_EQ(rt.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
    CHECK(runtime::IsContiguous(lt)) << "Can only compare contiguous tensor";
    CHECK(runtime::IsContiguous(rt)) << "Can only compare contiguous tensor";
    if (lt.ndim != rt.ndim) return false;
    for (int i = 0; i < lt.ndim; ++i) {
      if (!equal(lt.shape[i], rt.shape[i])) return false;
    }
    const DLDataType& ldt = lt.dtype;
    const DLDataType& rdt = rt.dtype;
    if (ldt.code != rdt.code || ldt.lanes != rdt.lanes || ldt.bits != rdt.bits) return false;
    // Same dtype and shape, so both buffers span the same number of bytes.
    return std::memcmp(lt.data, rt.data, runtime::GetDataSize(lt)) == 0;
  }
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait);
/*!
 * \brief Structural hash/equality trait for ArrayNode.
 *
 * Arrays are equal when they have the same length and pairwise structurally equal
 * elements. The hash folds in the length followed by every element in order.
 */
struct ArrayNodeTrait {
  static constexpr const std::nullptr_t VisitAttrs = nullptr;

  static void SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) {
    const size_t length = key->size();
    hash_reduce(static_cast<uint64_t>(length));
    for (size_t pos = 0; pos != length; ++pos) {
      hash_reduce(key->at(pos));
    }
  }

  static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) {
    const size_t length = lhs->size();
    if (length != rhs->size()) return false;
    for (size_t pos = 0; pos != length; ++pos) {
      if (!equal(lhs->at(pos), rhs->at(pos))) return false;
    }
    return true;
  }
};
TVM_REGISTER_OBJECT_TYPE(ArrayNode);
TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
    // Create a default-constructed ArrayNode; the repr bytes argument is ignored.
    .set_creator([](const std::string&) -> ObjectPtr<Object> {
      return ::tvm::runtime::make_object<ArrayNode>();
    });
TVM_REGISTER_GLOBAL("node.Array").set_body([](TVMArgs args, TVMRetValue* ret) {
  // Build an Array from all positional arguments. An argument whose type code is
  // kTVMNullptr becomes an undefined (null) ObjectRef element.
  std::vector<ObjectRef> elems;
  for (int idx = 0; idx < args.size(); ++idx) {
    const bool is_null = args[idx].type_code() == kTVMNullptr;
    elems.push_back(is_null ? ObjectRef(nullptr) : args[idx].operator ObjectRef());
  }
  *ret = Array<ObjectRef>(elems);
});
TVM_REGISTER_GLOBAL("node.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
  // args[0]: the Array object handle; args[1]: the element index.
  const int64_t index = args[1];
  CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
  auto* obj = static_cast<Object*>(args[0].value().v_handle);
  CHECK(obj->IsInstance<ArrayNode>());
  const auto* arr = static_cast<const ArrayNode*>(obj);
  // A negative index wraps to a huge size_t and is therefore rejected here as well.
  CHECK_LT(static_cast<size_t>(index), arr->size()) << "out of bound of array";
  *ret = arr->at(index);
});
TVM_REGISTER_GLOBAL("node.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) {
  // Returns the element count of the Array passed as args[0].
  CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
  auto* obj = static_cast<Object*>(args[0].value().v_handle);
  CHECK(obj->IsInstance<ArrayNode>());
  const auto* arr = static_cast<const ArrayNode*>(obj);
  const auto count = static_cast<int64_t>(arr->size());
  *ret = count;
});
/*!
 * \brief Structural hash/equality traits for MapNode.
 *
 * Two strategies are used, chosen by whether every key is a StringObj:
 *  - string maps ("SMap"): keys are hashed/compared by their string content;
 *  - other maps ("OMap"): only keys already known to the reducer participate, via
 *    SHashReducer::LookupHashedValue and SEqualReducer::MapLhsToRhs respectively.
 */
struct MapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
/*!
 * \brief Hash a map that has at least one non-string key.
 *
 * The map's full size() is always hashed. After that, only entries whose key already
 * has a hashed value recorded in hash_reduce (LookupHashedValue succeeds) contribute
 * their key hash and value, visited in ascending key-hash order. Entries whose key
 * hashes collide with another entry's are skipped entirely.
 */
static void SHashReduceForOMap(const MapNode* key, SHashReducer hash_reduce) {
// SHash's var handling depends on the determinism of traversal.
// NOTE: only book-keep the mapped hash keys.
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
using KV = std::pair<size_t, ObjectRef>;
std::vector<KV> temp;
for (const auto& kv : *key) {
size_t hashed_value;
// Keys with no recorded hashed value are left out of the content hash.
if (hash_reduce->LookupHashedValue(kv.first, &hashed_value)) {
temp.emplace_back(hashed_value, kv.second);
}
}
// sort by the hash key of the keys.
std::sort(temp.begin(), temp.end(),
[](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
// add size to the hash
hash_reduce(static_cast<uint64_t>(key->size()));
// hash the content
for (size_t i = 0; i < temp.size();) {
// Advance k past every entry that shares temp[i]'s key hash.
size_t k = i + 1;
for (; k < temp.size() && temp[k].first == temp[i].first; ++k) {
}
// ties are rare, but we need to skip them to make the hash deterministic
// (std::sort is not stable, so the relative order of tied entries is unspecified).
if (k == i + 1) {
hash_reduce->SHashReduceHashedValue(temp[i].first);
hash_reduce(temp[i].second);
}
i = k;
}
}
/*!
 * \brief Hash a map whose keys are all strings.
 *
 * Every entry contributes: size() is hashed first, then each (key, value) pair in
 * ascending key order, so the result does not depend on the map's iteration order.
 */
static void SHashReduceForSMap(const MapNode* key, SHashReducer hash_reduce) {
// String keys are hashed by content, so unlike the OMap case every entry is
// book-kept here and no hashed-value lookup is needed.
using KV = std::pair<String, ObjectRef>;
std::vector<KV> temp;
for (const auto& kv : *key) {
temp.push_back(std::make_pair(Downcast<String>(kv.first), kv.second));
}
// sort by the content of the String keys (String's operator<).
std::sort(temp.begin(), temp.end(),
[](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
// NOTE: we won't have ties (presumably because a Map never holds two keys with equal
// string content -- confirm against MapNode's key hashing).
// add size to the hash after sorting.
hash_reduce(static_cast<uint64_t>(key->size()));
// hash the content
for (size_t i = 0; i < temp.size(); ++i) {
hash_reduce(temp[i].first);
hash_reduce(temp[i].second);
}
}
/*!
 * \brief Entry point: dispatch to the SMap or OMap hashing strategy.
 *
 * An empty map takes the SMap path, since std::all_of over an empty range is true.
 */
static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) {
bool is_str_map = std::all_of(key->begin(), key->end(), [](const auto& v) {
return v.first->template IsInstance<StringObj>();
});
if (is_str_map) {
SHashReduceForSMap(key, hash_reduce);
} else {
SHashReduceForOMap(key, hash_reduce);
}
}
/*!
 * \brief Compare two maps whose keys are not all strings.
 *
 * Each lhs key is translated with equal->MapLhsToRhs. The comparison fails if a key has
 * no mapping, if the mapped key is absent from rhs, or if the values differ. Equal sizes
 * are checked by the caller (SEqualReduce).
 */
static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
for (const auto& kv : *lhs) {
// Only allow equal checking if the keys are already mapped
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
if (!rhs_key.defined()) return false;
auto it = rhs->find(rhs_key);
if (it == rhs->end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}
/*!
 * \brief Compare two string-keyed maps: every lhs key must be found directly in rhs with
 * a structurally equal value. Equal sizes are checked by the caller (SEqualReduce).
 */
static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
for (const auto& kv : *lhs) {
auto it = rhs->find(kv.first);
if (it == rhs->end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}
/*!
 * \brief Entry point: maps are equal when they have the same size, agree on whether all
 * keys are strings, and pass the matching per-strategy comparison.
 */
static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
if (rhs->size() != lhs->size()) return false;
if (rhs->size() == 0) return true;
bool ls = std::all_of(lhs->begin(), lhs->end(),
[](const auto& v) { return v.first->template IsInstance<StringObj>(); });
bool rs = std::all_of(rhs->begin(), rhs->end(),
[](const auto& v) { return v.first->template IsInstance<StringObj>(); });
if (ls != rs) {
return false;
}
// ls == rs at this point, so the condition below is equivalent to just `ls`.
return (ls && rs) ? SEqualReduceForSMap(lhs, rhs, equal) : SEqualReduceForOMap(lhs, rhs, equal);
}
};
TVM_REGISTER_OBJECT_TYPE(MapNode);
TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> { return MapNode::Empty(); });
TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) {
  // Arguments alternate key, value, key, value, ... Keys convertible to String are
  // stored as String; every other key is stored as a plain ObjectRef.
  CHECK_EQ(args.size() % 2, 0);
  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> entries;
  for (int pos = 0; pos < args.num_args; pos += 2) {
    ObjectRef map_key;
    if (String::CanConvertFrom(args[pos])) {
      map_key = args[pos].operator String();
    } else {
      map_key = args[pos].operator ObjectRef();
    }
    ObjectRef map_value = args[pos + 1];
    // emplace leaves an existing entry untouched, so the first occurrence of a key wins.
    entries.emplace(std::move(map_key), std::move(map_value));
  }
  *ret = Map<ObjectRef, ObjectRef>(std::move(entries));
});
TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) {
  // Returns the number of entries of the Map passed as args[0].
  CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
  auto* obj = static_cast<Object*>(args[0].value().v_handle);
  CHECK(obj->IsInstance<MapNode>());
  const auto* map_node = static_cast<const MapNode*>(obj);
  const auto count = static_cast<int64_t>(map_node->size());
  *ret = count;
});
TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
  // args[0]: the Map object handle; args[1]: the key to look up.
  CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
  auto* obj = static_cast<Object*>(args[0].value().v_handle);
  CHECK(obj->IsInstance<MapNode>());
  const auto* map_node = static_cast<const MapNode*>(obj);
  // String-convertible keys are looked up as String, anything else as a plain ObjectRef.
  const ObjectRef lookup_key = String::CanConvertFrom(args[1]) ? args[1].operator String()
                                                               : args[1].operator ObjectRef();
  auto found = map_node->find(lookup_key);
  CHECK(found != map_node->end()) << "cannot find the corresponding key in the Map";
  *ret = (*found).second;
});
TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) {
  // Returns how many entries of the Map in args[0] match the key in args[1].
  CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
  auto* obj = static_cast<Object*>(args[0].value().v_handle);
  CHECK(obj->IsInstance<MapNode>());
  const auto* map_node = static_cast<const MapNode*>(obj);
  // String-convertible keys are looked up as String, anything else as a plain ObjectRef.
  const ObjectRef lookup_key = String::CanConvertFrom(args[1]) ? args[1].operator String()
                                                               : args[1].operator ObjectRef();
  *ret = static_cast<int64_t>(map_node->count(lookup_key));
});
TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) {
  // Returns a flat Array [k0, v0, k1, v1, ...] of the Map's entries in the Map's
  // iteration order. Keys that are StringObj are pushed as String.
  CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
  Object* ptr = static_cast<Object*>(args[0].value().v_handle);
  // Validate the dynamic type before downcasting, as the other node.Map* accessors do;
  // without it a non-Map handle would be reinterpreted as a MapNode and iterated.
  CHECK(ptr->IsInstance<MapNode>());
  auto* n = static_cast<const MapNode*>(ptr);
  Array<ObjectRef> rkvs;
  for (const auto& kv : *n) {
    if (kv.first->IsInstance<StringObj>()) {
      rkvs.push_back(Downcast<String>(kv.first));
    } else {
      rkvs.push_back(kv.first);
    }
    rkvs.push_back(kv.second);
  }
  *ret = std::move(rkvs);
});
// Namespace-scope definition of the static constexpr member array
// DenseMapNode::kNextProbeLocation. Before C++17 (inline variables), an ODR-used static
// constexpr data member needs exactly one out-of-line definition; presumably that is why
// it lives here. Only compiled when USE_FALLBACK_STL_MAP == 0.
#if (USE_FALLBACK_STL_MAP == 0)
TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[];
#endif
} // namespace tvm