blob: c4e5db89de07e6af3e0a40b4188af2dd7adac248 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <gtest/gtest.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/generic/injective.h>

#include <string>
using namespace tvm;
// Packed function "test.seq.strategy": builds an OpStrategy with a single implementation
// named "test.strategy" (priority 10).
//  - compute:  requires exactly two input tensors and returns topi::add of them.
//  - schedule: applies topi::generic::schedule_injective while `target` is the current target.
TVM_REGISTER_GLOBAL("test.seq.strategy")
    .set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs, const Type& out_type,
                       const Target& target) {
      // Compute rule: elementwise (broadcasting) addition of the two operands.
      relay::FTVMCompute add_compute = [](const Attrs&, const Array<te::Tensor>& operands,
                                          const Type&) -> Array<te::Tensor> {
        ICHECK_EQ(operands.size(), 2U);
        te::Tensor sum = topi::add(operands[0], operands[1]);
        return {sum};
      };

      // Schedule rule: generic injective schedule, built inside a scope where the
      // requested target is active.
      relay::FTVMSchedule injective_schedule =
          [](const Attrs&, const Array<te::Tensor>& outs, const Target& tgt) -> te::Schedule {
        With<Target> active_target(tgt);
        return topi::generic::schedule_injective(tgt, outs);
      };

      auto strategy = relay::OpStrategy(make_object<relay::OpStrategyNode>());
      strategy.AddImplementation(add_compute, injective_schedule, "test.strategy", 10);
      return strategy;
    });
// Runs a Sequential pass pipeline (InferType, DeadCodeElimination, EliminateCommonSubexpr,
// AlterOpLayout) over a small Relay function and checks that the result is structurally equal
// to a hand-built expected function.
//
// NOTE: this test resets the "FTVMStrategy" attribute of the global "add" op and re-registers
// it with a GenericFunc whose default is the "test.seq.strategy" packed function. The previous
// attribute is not restored afterwards.
TEST(Relay, Sequential) {
  auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32));
  auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
  // Create a function for optimization:
  //   y  = x + (c + c)
  //   z2 = (y + c) + (y + c)   -- the two (y + c) calls are common subexpressions.
  auto c = relay::Constant(c_data);
  auto a = relay::Var("a", tensor_type);
  auto x = relay::Var("x", tensor_type);
  auto add_op = relay::Op::Get("add");
  auto y = relay::Call(add_op, {c, c});
  y = relay::Call(add_op, {x, y});
  auto z = relay::Call(add_op, {y, c});
  auto z1 = relay::Call(add_op, {y, c});
  auto z2 = relay::Call(add_op, {z, z1});
  // The let expression and variable `a` are unused and should be removed by dead-code
  // elimination.
  auto z3 = relay::Let(a, c, z2);
  relay::Function func = relay::Function(relay::FreeVars(z3), z3, relay::Type(), {});

  // Install a test FTVMStrategy for "add", built by the "test.seq.strategy" packed function.
  // gtest assertions report a missing registry entry as a test failure and stop this test.
  const auto* reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr");
  ASSERT_NE(reg, nullptr) << "Register is not defined.";
  const auto* reset = tvm::runtime::Registry::Get("ir.OpResetAttr");
  ASSERT_NE(reset, nullptr) << "Reset is not defined.";
  const auto* fs = tvm::runtime::Registry::Get("test.seq.strategy");
  ASSERT_NE(fs, nullptr) << "Strategy is not defined.";
  auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs, true);
  (*reset)(add_op, "FTVMStrategy");
  (*reg)("add", "FTVMStrategy", fgeneric, 10);

  // Run the sequential passes at opt_level 3 under an LLVM target scope.
  tvm::Array<relay::transform::Pass> pass_seqs{
      relay::transform::InferType(), relay::transform::DeadCodeElimination(),
      relay::transform::EliminateCommonSubexpr(), relay::transform::AlterOpLayout()};
  relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
  auto mod = IRModule::FromExpr(func);
  auto pass_ctx = relay::transform::PassContext::Create();
  pass_ctx->opt_level = 3;
  pass_ctx->config.Set("relay.fallback_device_type", Integer(1));
  {
    tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
    tvm::With<tvm::Target> tctx(tvm::Target("llvm"));
    mod = seq(mod);
  }
  ASSERT_TRUE(mod.defined());
  // Check for "main" explicitly: GetGlobalVar/Lookup would raise rather than fail cleanly.
  ASSERT_TRUE(mod->ContainGlobalVar("main"));
  relay::Function f = Downcast<relay::Function>(mod->Lookup("main"));
  ASSERT_TRUE(f.defined());

  // Expected function: the let binding is gone and the duplicated (y + c) is shared.
  auto c1 = relay::Constant(c_data);
  auto x1 = relay::Var("x", tensor_type);
  auto y1 = relay::Call(add_op, {c1, c1});
  y1 = relay::Call(add_op, {x1, y1});
  auto zz = relay::Call(add_op, {y1, c1});
  zz = relay::Call(add_op, {zz, zz});
  relay::Function expected_func = relay::Function(relay::FreeVars(zz), zz, relay::Type(), {});
  // Infer types for the expected function so it is comparable with the optimized one.
  auto mod1 = IRModule::FromExpr(expected_func);
  mod1 = relay::transform::InferType()(mod1);
  auto expected = mod1->Lookup("main");
  EXPECT_TRUE(tvm::StructuralEqual()(f, expected))
      << "Optimized function:\n"
      << f << "\nExpected function:\n"
      << expected;
}
// Checks that PassContext::ListConfigs() returns a non-empty set of registered options and
// reports "IntImm" as the value type of "relay.backend.use_auto_scheduler".
TEST(PassContextListConfigs, Basic) {
  Map<String, Map<String, String>> configs = relay::transform::PassContext::ListConfigs();
  ASSERT_FALSE(configs.empty());

  // Check presence before indexing, so a missing option is reported as a gtest failure
  // (presumably Map::operator[] errors on an absent key rather than inserting one).
  const char* key = "relay.backend.use_auto_scheduler";
  ASSERT_EQ(configs.count(key), 1U) << "Missing pass config option: " << key;
  auto config = configs[key];
  ASSERT_EQ(config.count("type"), 1U) << "No \"type\" entry for option: " << key;
  EXPECT_EQ(std::string(config["type"]), "IntImm");
}