blob: c5e6d4f75843cbc469c357db3e6ff2923ef0410b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <gtest/gtest.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/op.h>
using namespace tvm;
using namespace tvm::tir;
using namespace tvm::runtime;
// Binary TIR expression reference types that the typed suite below is
// instantiated with, grouped by flavour for readability.
using BinaryOpTypes = ::testing::Types<
    // arithmetic
    Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod, Min, Max,
    // comparison
    EQ, NE, LT, LE, GT, GE>;

// Stateless fixture; it exists only so that TYPED_TEST can be parameterised
// over every entry of BinaryOpTypes.
template <typename OpRef>
struct DataTypeLegalizerBinaryOp : ::testing::Test {};

TYPED_TEST_SUITE(DataTypeLegalizerBinaryOp, BinaryOpTypes);
// Builds a binary op with an int32 variable on the left and an int64 constant
// on the right, runs it through DataTypeLegalizer, and checks that both
// operands of the result are int64.
TYPED_TEST(DataTypeLegalizerBinaryOp, Basic) {
  using OpRef = TypeParam;
  using OpNode = typename OpRef::ContainerType;

  const DataType expected = DataType::Int(64);

  // Assemble the node by hand so the operand dtypes are deliberately mismatched.
  auto raw = make_object<OpNode>();
  raw->a = Var("a", DataType::Int(32));
  raw->b = IntImm(expected, 2);

  DataTypeLegalizer legalizer;
  const OpRef legalized = Downcast<OpRef>(legalizer(OpRef(raw)));

  ASSERT_EQ(legalized->a.dtype(), expected);
  ASSERT_EQ(legalized->b.dtype(), expected);
}
// A Select whose branches are int64 (true) and int32 (false) is legalized;
// the test checks that both branches and the Select itself report int64 and
// that the condition keeps the dtype it was built with.
TEST(DataTypeLegalizer, Select) {
  const DataType wide = DataType::Int(64);
  const DataType narrow = DataType::Int(32);

  auto select_node = make_object<SelectNode>();
  select_node->condition = Var("cond", DataType::Bool());
  select_node->true_value = Var("a", wide);
  select_node->false_value = IntImm(narrow, 2);

  DataTypeLegalizer legalizer;
  const PrimExpr legalized = legalizer(Select(select_node));
  const Select result = Downcast<Select>(legalized);

  ASSERT_EQ(result->true_value.dtype(), wide);
  ASSERT_EQ(result->false_value.dtype(), wide);
  ASSERT_EQ(result.dtype(), wide);
  // Compare against the node we built rather than a hard-coded dtype.
  ASSERT_EQ(result->condition.dtype(), select_node->condition.dtype());
}
// Same scenario as the Select test, expressed through the
// builtin::if_then_else() intrinsic: an int32 call with an int64 "then"
// argument and an int32 "else" argument. After legalization both value
// arguments and the call dtype are expected to be int64.
TEST(DataTypeLegalizer, IfThenElse) {
  const DataType wide = DataType::Int(64);
  const DataType narrow = DataType::Int(32);

  Var cond("cond", DataType::Bool());
  Var then_value("a", wide);
  IntImm else_value(narrow, 2);
  PrimExpr original = Call(narrow, builtin::if_then_else(), {cond, then_value, else_value});

  DataTypeLegalizer legalizer;
  const Call result = Downcast<Call>(legalizer(original));

  // args[0] is the condition; args[1] and args[2] are the two values.
  ASSERT_EQ(result->args[1].dtype(), wide);
  ASSERT_EQ(result->args[2].dtype(), wide);
  ASSERT_EQ(result->dtype, wide);
}
// Legalizes a For loop wrapping a BlockRealize/Block pair in which the loop
// variable and the block iteration variable are int32 while every range bound
// is int64. The test checks that, after legalization, the loop bounds, the
// block binding value, and the block iter var's domain and variable all carry
// the loop variable's dtype.
TEST(DataTypeLegalizer, Block) {
  // Block with a single data-parallel iter var: int32 variable, int64 domain.
  auto block_node = make_object<BlockNode>();
  auto iter_var_node = make_object<IterVarNode>();
  iter_var_node->var = Var("i", DataType::Int(32));
  iter_var_node->dom =
      Range::FromMinExtent(IntImm(DataType::Int(64), 0), IntImm(DataType::Int(64), 10));
  iter_var_node->iter_type = IterVarType::kDataPar;
  block_node->iter_vars = {IterVar(iter_var_node)};
  block_node->reads = {};
  block_node->writes = {};
  block_node->name_hint = "block";
  block_node->body = Evaluate(Integer(0));

  // Bind the block iter var to the enclosing loop variable.
  auto block_realize_node = make_object<BlockRealizeNode>();
  auto loop_var = Var("i", DataType::Int(32));
  block_realize_node->iter_values = {loop_var};
  block_realize_node->predicate = const_true();
  block_realize_node->block = Block(block_node);

  // Serial loop: int32 loop variable, int64 min/extent.
  auto for_node = make_object<ForNode>();
  for_node->loop_var = loop_var;
  for_node->min = IntImm(DataType::Int(64), 0);
  for_node->extent = IntImm(DataType::Int(64), 10);
  for_node->kind = ForKind::kSerial;
  for_node->body = BlockRealize(block_realize_node);
  Stmt stmt = For(for_node);

  DataTypeLegalizer legalizer;
  DataType target_dtype = loop_var->dtype;
  Stmt new_stmt = legalizer(stmt);

  // as<T>() may hand back nullptr when the node is not a T; assert on each
  // cast so an unexpected node kind fails the test cleanly instead of
  // dereferencing a null pointer.
  const ForNode* new_for = new_stmt.as<ForNode>();
  ASSERT_NE(new_for, nullptr);
  ASSERT_EQ(new_for->loop_var.dtype(), target_dtype);
  ASSERT_EQ(new_for->min.dtype(), target_dtype);
  ASSERT_EQ(new_for->extent.dtype(), target_dtype);

  const BlockRealizeNode* new_block_realize = new_for->body.as<BlockRealizeNode>();
  ASSERT_NE(new_block_realize, nullptr);
  ASSERT_EQ(new_block_realize->iter_values[0].dtype(), target_dtype);

  const BlockNode* new_block = new_block_realize->block.as<BlockNode>();
  ASSERT_NE(new_block, nullptr);
  ASSERT_EQ(new_block->iter_vars[0]->dom->min.dtype(), target_dtype);
  ASSERT_EQ(new_block->iter_vars[0]->dom->extent.dtype(), target_dtype);
  ASSERT_EQ(new_block->iter_vars[0]->var.dtype(), target_dtype);
}
// A bare For loop with an int32 loop variable and int64 min/extent is
// legalized; the test checks that min, extent and the loop variable are all
// int32 afterwards.
TEST(DataTypeLegalizer, For) {
  const DataType index_dtype = DataType::Int(32);
  const DataType bound_dtype = DataType::Int(64);

  auto for_node = make_object<ForNode>();
  for_node->loop_var = Var("i", index_dtype);
  for_node->min = IntImm(bound_dtype, 0);
  for_node->extent = IntImm(bound_dtype, 10);
  for_node->body = Evaluate(Integer(0));

  DataTypeLegalizer legalizer;
  const Stmt legalized = legalizer(For(for_node));
  const For result = Downcast<For>(legalized);

  ASSERT_EQ(result->min.dtype(), index_dtype);
  ASSERT_EQ(result->extent.dtype(), index_dtype);
  ASSERT_EQ(result->loop_var.dtype(), index_dtype);
}
// A 4-lane Ramp with an int64 base and an int32 stride is legalized; the test
// checks that base and stride are both int64 and that the Ramp's own dtype is
// int64 with the same lane count.
TEST(DataTypeLegalizer, Ramp) {
  constexpr int kLanes = 4;
  const DataType wide = DataType::Int(64);

  auto ramp_node = make_object<RampNode>();
  ramp_node->base = IntImm(wide, 0);
  ramp_node->stride = IntImm(DataType::Int(32), 1);
  ramp_node->lanes = kLanes;

  DataTypeLegalizer legalizer;
  const Ramp result = Downcast<Ramp>(legalizer(Ramp(ramp_node)));

  ASSERT_EQ(result->base.dtype(), wide);
  ASSERT_EQ(result->stride.dtype(), wide);
  ASSERT_EQ(result->dtype, wide.with_lanes(kLanes));
}