blob: 138e0b242e0223847b28e57fb0b87bf96df8eed8 [file] [log] [blame]
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/attrs.h>
#include <tvm/ir.h>
namespace tvm {
namespace test {
// test example usage docs
struct TestAttrs : public AttrsNode<TestAttrs> {
int axis;
std::string name;
Expr expr;
double learning_rate;
TVM_DECLARE_ATTRS(TestAttrs, "attrs.cpptest.TestAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(10)
.set_lower_bound(1)
.set_upper_bound(10)
.describe("axis field");
TVM_ATTR_FIELD(name)
.describe("name of the field");
TVM_ATTR_FIELD(expr)
.describe("expression field")
.set_default(make_const(Int(32), 1));
TVM_ATTR_FIELD(learning_rate)
.describe("learning_rate")
.set_default(0.1);
}
};
}
}
TEST(Attrs, Basic) {
using namespace tvm;
using namespace tvm::test;
std::shared_ptr<TestAttrs> n = std::make_shared<TestAttrs>();
try {
n->InitBySeq("axis", 10);
LOG(FATAL) << "bad";
} catch (const tvm::AttrError& e) {
}
try {
n->InitBySeq("axis", 12, "name", "111");
LOG(FATAL) << "bad";
} catch (const tvm::AttrError& e) {
}
try {
n->InitBySeq("axisx", 12, "name", "111");
LOG(FATAL) << "bad";
} catch (const tvm::AttrError& e) {
std::string what = e.what();
CHECK(what.find("expr : Expr, default=1") != std::string::npos);
CHECK(what.find("axisx") != std::string::npos);
}
n->InitBySeq("learning_rate", Expr(1), "expr", 128, "name", "xx");
CHECK_EQ(n->learning_rate, 1.0);
n->InitBySeq("name", "xxx", "expr", 128);
CHECK_EQ(n->name, "xxx");
CHECK_EQ(n->axis, 10);
CHECK_EQ(n->expr.as<tvm::ir::IntImm>()->value, 128);
// Check docstring
std::ostringstream os;
n->PrintDocString(os);
LOG(INFO) << "docstring\n"<< os.str();
CHECK(os.str().find("expr : Expr, default=1") != std::string::npos);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}