blob: b4c53ab845209370eb5d4339b4171d50b92b1f95 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir/expr.h>
#include <tvm/target/target.h>
#include <cmath>
#include <string>
using namespace tvm;
// Registers a test-only target kind named "TestTargetKind" (registered against
// kDLCPU) that the tests below rely on:
//  - a kind-level attribute "Attr1" holding the std::string "Value1", which the
//    TargetKind.GetAttrMap test reads back;
//  - three typed config options that TargetCreation* configs populate:
//    "my_bool" (Bool), "your_names" (Array<String>) and "her_maps"
//    (Map<String, Integer>).
// Any config key not listed here is unregistered for this kind (see the
// UnrecognizedConfigOption test).
TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU)
.set_attr<std::string>("Attr1", "Value1")
.add_attr_option<Bool>("my_bool")
.add_attr_option<Array<String>>("your_names")
.add_attr_option<Map<String, Integer>>("her_maps");
// Fetches the "Attr1" attribute map across all target kinds and checks that the
// entry for "TestTargetKind" is the "Value1" string set at registration time.
TEST(TargetKind, GetAttrMap) {
  auto attr1_by_kind = TargetKind::GetAttrMap<std::string>("Attr1");
  auto test_kind = TargetKind::Get("TestTargetKind").value();
  CHECK_EQ(attr1_by_kind[test_kind], "Value1");
}
// Builds a TestTargetKind target from a config holding a Bool, an
// Array<String> and a nested Map<String, Integer>, then reads each option back
// through GetAttr and verifies it round-tripped unchanged. The config sets no
// "tag" or "keys", so both are expected to be empty.
TEST(TargetCreation, NestedConfig) {
  // Nested option values, built up front and referenced from the config below.
  Array<String> names_in{"junru", "jian"};
  Map<String, Integer> maps_in{
      {"a", 1},
      {"b", 2},
  };
  Map<String, ObjectRef> config = {
      {"my_bool", Bool(true)},
      {"your_names", names_in},
      {"kind", String("TestTargetKind")},
      {"her_maps", maps_in},
  };

  Target built(config);
  CHECK_EQ(built->kind, TargetKind::Get("TestTargetKind").value());
  CHECK_EQ(built->tag, "");
  CHECK(built->keys.empty());

  // Scalar option.
  const Bool flag = built->GetAttr<Bool>("my_bool").value();
  CHECK_EQ(flag.operator bool(), true);

  // Array option: same length and element order as supplied.
  const Array<String> names_out = built->GetAttr<Array<String>>("your_names").value();
  CHECK_EQ(names_out.size(), 2U);
  CHECK_EQ(names_out[0], "junru");
  CHECK_EQ(names_out[1], "jian");

  // Nested map option: same entries as supplied.
  const Map<String, Integer> maps_out = built->GetAttr<Map<String, Integer>>("her_maps").value();
  CHECK_EQ(maps_out.size(), 2U);
  CHECK_EQ(maps_out["a"], 1);
  CHECK_EQ(maps_out["b"], 2);
}
// Constructing a Target must throw when the config contains a key ("bad") that
// was never registered as an option of TestTargetKind. All other entries match
// the registered options and types.
// NOTE(review): the "bad" entry also carries a null ObjectRef, so a rejection
// of the null value itself would satisfy this test too; consider a non-null
// value if the intent is to pin the unrecognized-key path specifically.
TEST(TargetCreationFail, UnrecognizedConfigOption) {
  Map<String, ObjectRef> config = {
      {"my_bool", Bool(true)},
      {"your_names", Array<String>{"junru", "jian"}},
      {"kind", String("TestTargetKind")},
      {"bad", ObjectRef(nullptr)},
      {
          "her_maps",
          Map<String, Integer>{
              {"a", 1},
              {"b", 2},
          },
      },
  };
  // gtest's ASSERT_ANY_THROW replaces the manual try/catch(...) + flag and
  // names the offending statement if construction unexpectedly succeeds.
  ASSERT_ANY_THROW(Target target(config));
}
// Constructing a Target must throw when a registered option is given a value of
// the wrong type: "my_bool" is registered as Bool but is passed a String
// ("true"). Every other entry is well-typed.
TEST(TargetCreationFail, TypeMismatch) {
  Map<String, ObjectRef> config = {
      {"my_bool", String("true")},
      {"your_names", Array<String>{"junru", "jian"}},
      {"kind", String("TestTargetKind")},
      {
          "her_maps",
          Map<String, Integer>{
              {"a", 1},
              {"b", 2},
          },
      },
  };
  // gtest's ASSERT_ANY_THROW replaces the manual try/catch(...) + flag and
  // names the offending statement if construction unexpectedly succeeds.
  ASSERT_ANY_THROW(Target target(config));
}
// Constructing a Target must throw when the config has no "kind" entry. All
// other entries are valid TestTargetKind options, so the missing kind is the
// only problem with this config.
TEST(TargetCreationFail, TargetKindNotFound) {
  Map<String, ObjectRef> config = {
      // Bool(true), not Bool("true"): the latter only compiles through the
      // implicit pointer-to-bool conversion of the string literal and obscures
      // the intent that this entry be a well-formed Bool.
      {"my_bool", Bool(true)},
      {"your_names", Array<String>{"junru", "jian"}},
      {
          "her_maps",
          Map<String, Integer>{
              {"a", 1},
              {"b", 2},
          },
      },
  };
  // gtest's ASSERT_ANY_THROW replaces the manual try/catch(...) + flag and
  // names the offending statement if construction unexpectedly succeeds.
  ASSERT_ANY_THROW(Target target(config));
}
// Builds an "llvm" target with explicit keys {"cpu", "arm_cpu"} plus a
// "device" of "arm_cpu". As the test name suggests, keys coming from different
// sources are expected to be merged without duplicates. The assertions pin the
// final key list to exactly ["cpu", "arm_cpu"] (in that order) and require
// "device" to be the only remaining attribute.
TEST(TargetCreation, DeduplicateKeys) {
  const Array<String> explicit_keys{"cpu", "arm_cpu"};
  Map<String, ObjectRef> config = {
      {"kind", String("llvm")},
      {"keys", explicit_keys},
      {"device", String("arm_cpu")},
  };
  Target llvm_target(config);
  CHECK_EQ(llvm_target->kind, TargetKind::Get("llvm").value());
  CHECK_EQ(llvm_target->tag, "");

  // Exactly two keys survive, in this order.
  const Array<String>& merged_keys = llvm_target->keys;
  CHECK_EQ(merged_keys.size(), 2U);
  CHECK_EQ(merged_keys[0], "cpu");
  CHECK_EQ(merged_keys[1], "arm_cpu");

  // "device" is the only entry left in attrs.
  CHECK_EQ(llvm_target->attrs.size(), 1U);
  CHECK_EQ(llvm_target->GetAttr<String>("device"), "arm_cpu");
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}