// blob: 7890a67eef958f5b9861e7e06eb4bf2db130320b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/node/repr_printer.h>
#include <tvm/runtime/container/map.h>
#include <tvm/script/printer/traced_object.h>
using namespace tvm;
namespace {
// Minimal Object subclass with no data members. The tests below use it as the
// value of ObjectWithAttrsNode::obj_attr and as the concrete type behind
// ObjectRef handles in the IsInstance/Downcast/TryDowncast tests.
class DummyObjectNode : public Object {
public:
// There are no fields to report, so the visitor is left untouched.
void VisitAttrs(AttrVisitor* v) {}
// Type key; the TracedObject_Downcast test looks for this exact string in the
// message of the failed-downcast exception.
static constexpr const char* _type_key = "TracedObjectTestDummyObject";
TVM_DECLARE_FINAL_OBJECT_INFO(DummyObjectNode, Object);
};
// Reference (handle) type paired with DummyObjectNode; its members are
// generated entirely by the TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS macro.
class DummyObject : public ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DummyObject, ObjectRef, DummyObjectNode);
};
// NOTE(review): presumably registers DummyObjectNode with TVM's node/reflection
// registry; the macro definition is not visible in this file -- confirm.
TVM_REGISTER_NODE_TYPE(DummyObjectNode);
// Test fixture node carrying one attribute of each kind exercised by the
// GetAttr tests: a plain integer, a Map, an Array and a nested object reference.
class ObjectWithAttrsNode : public Object {
public:
// Defaults to 5; TracedObject_GetAttr_Int64 checks for exactly this value.
int64_t int64_attr = 5;
// Left default-constructed here; individual tests populate it as needed.
Map<String, String> map_attr;
// Left default-constructed here; individual tests populate it as needed.
Array<String> array_attr;
// Nested object attribute, initialized by the constructor below.
DummyObject obj_attr;
// Gives obj_attr a freshly allocated DummyObjectNode.
ObjectWithAttrsNode() : obj_attr(make_object<DummyObjectNode>()) {}
// Reports every field to the visitor under a name equal to the member name;
// the tests expect attribute paths built from these same names.
void VisitAttrs(AttrVisitor* v) {
v->Visit("int64_attr", &int64_attr);
v->Visit("map_attr", &map_attr);
v->Visit("array_attr", &array_attr);
v->Visit("obj_attr", &obj_attr);
}
// Type key; it appears as the target type in the failed-downcast message that
// the TracedObject_Downcast test matches.
static constexpr const char* _type_key = "TracedObjectTestObjectWithAttrs";
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectWithAttrsNode, Object);
};
// Reference (handle) type paired with ObjectWithAttrsNode; its members are
// generated entirely by the TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS macro.
class ObjectWithAttrs : public ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ObjectWithAttrs, ObjectRef, ObjectWithAttrsNode);
};
// NOTE(review): presumably registers ObjectWithAttrsNode with TVM's
// node/reflection registry; the macro definition is not visible here -- confirm.
TVM_REGISTER_NODE_TYPE(ObjectWithAttrsNode);
} // anonymous namespace
// MakeTraced called without an explicit path must yield a TracedObject of the
// same reference type, located at the root path and wrapping the same node.
TEST(TracedObjectTest, MakeTraced_RootObject) {
  ObjectWithAttrs obj(make_object<ObjectWithAttrsNode>());
  auto traced = MakeTraced(obj);

  using ExpectedType = TracedObject<ObjectWithAttrs>;
  static_assert(std::is_same<decltype(traced), ExpectedType>::value);

  const auto expected_path = ObjectPath::Root();
  ICHECK(traced.GetPath()->PathsEqual(expected_path));
  // Pointer identity: the wrapper refers to the original node.
  ICHECK_EQ(traced.Get().get(), obj.get());
}
// MakeTraced called with an explicit path must keep that path and wrap the
// same underlying node.
TEST(TracedObjectTest, MakeTraced_WithPath) {
  ObjectWithAttrs target(make_object<ObjectWithAttrsNode>());
  auto wrapped = MakeTraced(target, ObjectPath::Root()->Attr("foo"));

  using ExpectedType = TracedObject<ObjectWithAttrs>;
  static_assert(std::is_same<decltype(wrapped), ExpectedType>::value);

  // Build the expected path independently of the one handed to MakeTraced.
  const auto expected_path = ObjectPath::Root()->Attr("foo");
  ICHECK(wrapped.GetPath()->PathsEqual(expected_path));
  ICHECK_EQ(wrapped.Get().get(), target.get());
}
// A TracedObject of a derived reference type must be usable wherever a
// TracedObject<ObjectRef> is expected.
TEST(TracedObjectTest, TracedObject_ImplicitConversionFromDerived) {
  DummyObject dummy(make_object<DummyObjectNode>());
  auto derived_traced = MakeTraced(dummy);
  static_assert(std::is_same<decltype(derived_traced), TracedObject<DummyObject>>::value);

  // Passing the derived wrapper to a parameter of the base wrapper type forces
  // the implicit conversion; the lambda hands the converted value back by copy.
  const auto accept_base = [](const TracedObject<ObjectRef>& base) { return base; };
  auto base_traced = accept_base(derived_traced);
  static_assert(std::is_same<decltype(base_traced), TracedObject<ObjectRef>>::value);
}
// GetAttr on an object-reference member must return a TracedObject of the
// member's type, with the attribute name appended to the path.
TEST(TracedObjectTest, TracedObject_GetAttr_ObjectRef) {
  ObjectWithAttrs owner(make_object<ObjectWithAttrsNode>());
  auto traced_owner = MakeTraced(owner);

  auto traced_child = traced_owner.GetAttr(&ObjectWithAttrsNode::obj_attr);
  static_assert(std::is_same<decltype(traced_child), TracedObject<DummyObject>>::value);

  const auto expected_path = ObjectPath::Root()->Attr("obj_attr");
  ICHECK(traced_child.GetPath()->PathsEqual(expected_path));
  ICHECK_EQ(traced_child.Get().get(), owner->obj_attr.get());
}
// GetAttr on a Map member must return a TracedMap whose path names the
// attribute, and element lookups must extend that path with the map key.
TEST(TracedObjectTest, TracedObject_GetAttr_Map) {
  ObjectWithAttrs owner(make_object<ObjectWithAttrsNode>());
  owner->map_attr.Set("foo", "bar");
  auto traced_owner = MakeTraced(owner);

  auto traced_map = traced_owner.GetAttr(&ObjectWithAttrsNode::map_attr);
  static_assert(std::is_same<decltype(traced_map), TracedMap<String, String>>::value);

  const auto map_path = ObjectPath::Root()->Attr("map_attr");
  ICHECK(traced_map.GetPath()->PathsEqual(map_path));
  ICHECK_EQ(traced_map.Get().get(), owner->map_attr.get());

  // Look up the single entry and verify both its value and its path.
  auto traced_value = traced_map.at("foo");
  ICHECK_EQ(traced_value.Get(), "bar");
  const auto value_path = ObjectPath::Root()->Attr("map_attr")->MapValue(String("foo"));
  ICHECK(traced_value.GetPath()->PathsEqual(value_path));
}
// GetAttr on an Array member must return a TracedArray whose path names the
// attribute, and indexing must extend that path with the element index.
TEST(TracedObjectTest, TracedObject_GetAttr_Array) {
  ObjectWithAttrs owner(make_object<ObjectWithAttrsNode>());
  owner->array_attr.push_back("foo");
  owner->array_attr.push_back("bar");
  auto traced_owner = MakeTraced(owner);

  auto traced_array = traced_owner.GetAttr(&ObjectWithAttrsNode::array_attr);
  static_assert(std::is_same<decltype(traced_array), TracedArray<String>>::value);

  const auto array_path = ObjectPath::Root()->Attr("array_attr");
  ICHECK(traced_array.GetPath()->PathsEqual(array_path));
  ICHECK_EQ(traced_array.Get().get(), owner->array_attr.get());

  // Fetch the second element and verify both its value and its path.
  auto traced_elem = traced_array[1];
  ICHECK_EQ(traced_elem.Get(), "bar");
  const auto elem_path = ObjectPath::Root()->Attr("array_attr")->ArrayIndex(1);
  ICHECK(traced_elem.GetPath()->PathsEqual(elem_path));
}
// GetAttr on an int64_t member must return a TracedBasicValue carrying the
// member's value and a path that names the attribute.
TEST(TracedObjectTest, TracedObject_GetAttr_Int64) {
  ObjectWithAttrs owner(make_object<ObjectWithAttrsNode>());
  auto traced_owner = MakeTraced(owner);

  auto traced_int = traced_owner.GetAttr(&ObjectWithAttrsNode::int64_attr);
  static_assert(std::is_same<decltype(traced_int), TracedBasicValue<int64_t>>::value);

  // 5 is the in-class default of ObjectWithAttrsNode::int64_attr.
  ICHECK_EQ(traced_int.Get(), 5);
  const auto expected_path = ObjectPath::Root()->Attr("int64_attr");
  ICHECK(traced_int.GetPath()->PathsEqual(expected_path));
}
// IsInstance must report the dynamic type of the wrapped object, not the
// static type of the handle it was wrapped through.
TEST(TracedObjectTest, TracedObject_IsInstance) {
  ObjectRef generic_ref(make_object<DummyObjectNode>());
  auto traced_ref = MakeTraced(generic_ref);

  // Matches the real node type...
  ICHECK(traced_ref.IsInstance<DummyObject>());
  // ...and rejects an unrelated one.
  ICHECK(!traced_ref.IsInstance<ObjectWithAttrs>());
}
// Downcast to the object's real type must succeed and preserve identity;
// Downcast to an unrelated type must throw with a message naming both type keys.
TEST(TracedObjectTest, TracedObject_Downcast) {
  ObjectRef root(make_object<DummyObjectNode>());
  auto traced = MakeTraced(root);

  auto as_dummy = traced.Downcast<DummyObject>();
  static_assert(std::is_same<decltype(as_dummy), TracedObject<DummyObject>>::value);
  ICHECK_EQ(as_dummy.Get(), root);

  // Try downcasting to a wrong type. `caught` only becomes true when an
  // exception is thrown AND its message contains the expected text, so a
  // missing throw and a wrong message both fail the final check.
  bool caught = false;
  try {
    traced.Downcast<ObjectWithAttrs>();
  } catch (const std::exception& e) {  // read-only use of the exception: catch by const&
    caught = strstr(e.what(),
                    "Downcast from TracedObjectTestDummyObject to TracedObjectTestObjectWithAttrs "
                    "failed") != nullptr;
  }
  ICHECK(caught);
}
// TryDowncast must produce a defined TracedOptional for the matching type and
// an undefined one for a non-matching type.
TEST(TracedObjectTest, TracedObject_TryDowncast) {
  ObjectRef generic_ref(make_object<DummyObjectNode>());
  auto traced_ref = MakeTraced(generic_ref);

  auto maybe_dummy = traced_ref.TryDowncast<DummyObject>();
  static_assert(std::is_same<decltype(maybe_dummy), TracedOptional<DummyObject>>::value);
  ICHECK(maybe_dummy.defined());
  ICHECK_EQ(maybe_dummy.value().Get(), generic_ref);

  // A mismatching target type yields an undefined optional.
  ICHECK(!traced_ref.TryDowncast<ObjectWithAttrs>().defined());
}
// TracedMap::at must return the traced value stored under the key, with a
// path that extends the map's path by that key.
TEST(TracedObjectTest, TracedMap_At) {
  Map<String, String> source({{"k1", "foo"}, {"k2", "bar"}});
  auto traced_map = MakeTraced(source);

  auto traced_value = traced_map.at("k1");
  static_assert(std::is_same<decltype(traced_value), TracedObject<String>>::value);
  ICHECK_EQ(traced_value.Get(), "foo");

  const auto expected_path = ObjectPath::Root()->MapValue(String("k1"));
  ICHECK(traced_value.GetPath()->PathsEqual(expected_path));
}
// Iterating a TracedMap must visit every entry exactly once, each value
// carrying a path that extends the map's path by its key.
TEST(TracedObjectTest, TracedMap_Iterator) {
  Map<String, String> source({{"k1", "foo"}, {"k2", "bar"}});
  auto traced_map = MakeTraced(source);

  size_t seen_k1 = 0;
  size_t seen_k2 = 0;
  for (const auto& entry : traced_map) {
    if (entry.first == "k1") {
      ++seen_k1;
      ICHECK_EQ(entry.second.Get(), "foo");
      ICHECK(entry.second.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k1"))));
      continue;
    }
    if (entry.first == "k2") {
      ++seen_k2;
      ICHECK_EQ(entry.second.Get(), "bar");
      ICHECK(entry.second.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k2"))));
      continue;
    }
    // Any other key means the iterator produced something not in the map.
    ICHECK(false);
  }

  ICHECK_EQ(seen_k1, 1);
  ICHECK_EQ(seen_k2, 1);
}
// TracedArray::operator[] must return the traced element at the index, with a
// path that extends the array's path by that index.
TEST(TracedObjectTest, TracedArray_Index) {
  Array<String> source = {"foo", "bar"};
  auto traced_array = MakeTraced(source);

  auto traced_elem = traced_array[1];
  static_assert(std::is_same<decltype(traced_elem), TracedObject<String>>::value);
  ICHECK_EQ(traced_elem.Get(), "bar");

  const auto expected_path = ObjectPath::Root()->ArrayIndex(1);
  ICHECK(traced_elem.GetPath()->PathsEqual(expected_path));
}
// Iterating a TracedArray must yield the elements in order, each carrying a
// path that extends the array's path by its index.
TEST(TracedObjectTest, TracedArray_Iterator) {
  Array<String> source = {"foo", "bar"};
  auto traced_array = MakeTraced(source);

  size_t position = 0;
  for (const auto& elem : traced_array) {
    switch (position) {
      case 0: {
        ICHECK_EQ(elem.Get(), "foo");
        ICHECK(elem.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(0)));
        break;
      }
      case 1: {
        ICHECK_EQ(elem.Get(), "bar");
        ICHECK(elem.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(1)));
        break;
      }
      default: {
        // The array has only two elements; a third visit is an error.
        ICHECK(false);
      }
    }
    ++position;
  }

  ICHECK_EQ(position, 2);
}
// ApplyFunc must apply the callback to the wrapped value, wrap the result in a
// TracedBasicValue of the callback's return type, and keep the original path.
TEST(TracedObjectTest, TracedBasicValue_ApplyFunc) {
  auto traced = MakeTraced(123, ObjectPath::Root()->Attr("foo"));
  static_assert(std::is_same<decltype(traced), TracedBasicValue<int>>::value);

  // int + double yields double, so the result wrapper changes its value type.
  auto transformed = traced.ApplyFunc([](int x) { return x + 4.0; });
  static_assert(std::is_same<decltype(transformed), TracedBasicValue<double>>::value);

  // Verify the callback was actually applied: 123 + 4.0 == 127.0, which is
  // exactly representable as a double, so an exact comparison is safe.
  ICHECK_EQ(transformed.Get(), 127.0);
  ICHECK(transformed.GetPath()->PathsEqual(ObjectPath::Root()->Attr("foo")));
}