blob: de06a0e7189fc3613a33894b0ce01aa1136e15ac [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/node/functor.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
// NodeFunctor dispatches to the handler registered for the runtime node type
// of its first argument: a Var returns the base value, an Add adds 2 to it.
TEST(IRF, Basic) {
  using namespace tvm;
  using namespace tvm::tir;
  Var var_x("x");
  PrimExpr sum = var_x + 1;

  NodeFunctor<int(const ObjectRef& n, int b)> dispatcher;
  dispatcher.set_dispatch<VarNode>([](const ObjectRef&, int base) { return base; });
  dispatcher.set_dispatch<AddNode>([](const ObjectRef&, int base) { return base + 2; });

  CHECK_EQ(dispatcher(var_x, 2), 2);
  CHECK_EQ(dispatcher(sum, 2), 4);
}
// PostOrderVisit over `x + 1 + y + y` reaches two VarNodes. The expected count
// of 2 shows the repeated `y` is only counted once.
TEST(IRF, CountVar) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");
  Var y;
  PrimExpr expr = x + 1 + y + y;

  int var_count = 0;
  auto count_vars = [&var_count](const ObjectRef& node) {
    if (node.as<VarNode>() != nullptr) {
      ++var_count;
    }
  };
  tir::PostOrderVisit(expr, count_vars);

  CHECK_EQ(var_count, 2);
}
// An ExprFunctor "evaluates" an expression:
//   - a Var yields the extra argument `b`;
//   - an IntImm yields its value;
//   - an Add yields the sum of its operands.
// Node types without an override (here: Sub) must make the functor fail.
TEST(IRF, ExprTransform) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");
  auto z = x + 1;
  class MyExprFunctor : public tir::ExprFunctor<int(const PrimExpr&, int)> {
   public:
    int VisitExpr_(const VarNode* op, int b) final { return b; }
    int VisitExpr_(const IntImmNode* op, int b) final { return op->value; }
    int VisitExpr_(const AddNode* op, int b) final {
      return VisitExpr(op->a, b) + VisitExpr(op->b, b);
    }
  };
  MyExprFunctor f;
  CHECK_EQ(f(x, 2), 2);
  CHECK_EQ(f(z, 2), 3);
  // `z - 1` builds a Sub node, which MyExprFunctor does not handle, so the
  // call is expected to raise dmlc::Error. The previous try/LOG(FATAL)/catch
  // form could never fail: LOG(FATAL) itself throws dmlc::Error, which the
  // same catch clause swallowed.
  EXPECT_THROW(f(z - 1, 2), dmlc::Error);
}
// A visitor built from both ExprFunctor and StmtFunctor walks Evaluate(x + 1)
// and counts the Var leaves it reaches (exactly one: `x`).
TEST(IRF, ExprVisit) {
  using namespace tvm;
  using namespace tvm::tir;
  class MyVisitor : public tir::ExprFunctor<void(const PrimExpr&)>,
                    public tir::StmtFunctor<void(const Stmt&)> {
   public:
    int count = 0;

    // Statements: descend into the evaluated expression.
    void VisitStmt_(const EvaluateNode* op) final { VisitExpr(op->value); }
    // Expressions: recurse through Add, ignore integer constants, count Vars.
    void VisitExpr_(const AddNode* op) final {
      VisitExpr(op->a);
      VisitExpr(op->b);
    }
    void VisitExpr_(const IntImmNode* op) final {}
    void VisitExpr_(const VarNode* op) final { count += 1; }
  };

  Var x("x");
  Stmt stmt = Evaluate(x + 1);
  MyVisitor counter;
  counter.VisitStmt(stmt);
  CHECK_EQ(counter.count, 1);
}
// StmtExprVisitor reaches expressions inside an Allocate. With `x + 1` used as
// both extents and inside the body, `x` is reached three times.
TEST(IRF, StmtVisitor) {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x");

  // Counts every VarNode occurrence the visitor reaches.
  class MyVisitor : public StmtExprVisitor {
   public:
    int count = 0;
    void VisitExpr_(const VarNode* op) final { count += 1; }
  };
  MyVisitor v;

  PrimExpr extent = x + 1;
  Stmt body = Evaluate(extent);
  Var buffer("b", DataType::Handle());
  Stmt alloc = Allocate(buffer, DataType::Float(32), {extent, extent}, const_true(), body);

  v(alloc);
  CHECK_EQ(v.count, 3);
}
// Exercises StmtMutator/ExprMutator copy-on-write behaviour. Nodes and arrays
// that are uniquely referenced are expected to be updated in place (pointer
// identity preserved). Ones that are shared must be copied before modification.
TEST(IRF, StmtMutator) {
using namespace tvm;
using namespace tvm::tir;
Var x("x");
// Mutator combining statement and expression mutation:
//  - every AddNode is replaced by its left operand (so `x + 1` becomes `x`);
//  - SeqStmt nodes are visited with flattening enabled, so nested SeqStmts
//    are spliced into their parent;
//  - expression visits are routed to ExprMutator.
class MyVisitor : public tir::StmtMutator, public tir::ExprMutator {
public:
// Re-expose both call operators; otherwise the overloads inherited from the
// two bases would be ambiguous.
using StmtMutator::operator();
using ExprMutator::operator();
protected:
// implementation
PrimExpr VisitExpr_(const AddNode* op) final { return op->a; }
Stmt VisitStmt_(const SeqStmtNode* op) final { return StmtMutator::VisitSeqStmt_(op, true); }
// Single override for the VisitExpr declared by both bases. It routes to
// ExprMutator so the Add rewrite above also applies to expressions embedded
// in statements.
PrimExpr VisitExpr(const PrimExpr& expr) final { return ExprMutator::VisitExpr(expr); }
};
// Builds a fresh Allocate with extents {1, x + 1} and body Evaluate(x + 1).
auto fmakealloc = [&]() {
auto z = x + 1;
Stmt body = Evaluate(z);
Var buffer("b", DataType::Handle());
return Allocate(buffer, DataType::Float(32), {1, z}, const_true(), body);
};
// Builds a fresh IfThenElse(x, Evaluate(0), Evaluate(x + 1)).
auto fmakeif = [&]() {
auto z = x + 1;
Stmt body = Evaluate(z);
return IfThenElse(x, Evaluate(0), body);
};
MyVisitor v;
{
// Case 1: the array is the only owner of the Allocate (`body` is moved in),
// but `bref` holds an extra reference to the Allocate's body.
auto body = fmakealloc();
Stmt body2 = Evaluate(1);
Stmt bref = body.as<AllocateNode>()->body;
auto* extentptr = body.as<AllocateNode>()->extents.get();
Array<Stmt> arr{std::move(body), body2, body2};
auto* arrptr = arr.get();
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
// The array was not shared, so it is updated in place.
CHECK(arr.get() == arrptr);
// The Allocate and its extents array were unique, so they are updated in
// place: extent `x + 1` became `x` without reallocating the extents array.
CHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x));
CHECK(arr[0].as<AllocateNode>()->extents.get() == extentptr);
// The body is also referenced by `bref`, so it is copied rather than
// mutated in place; `bref` still sees the original Add expression.
CHECK(!arr[0].as<AllocateNode>()->body.same_as(bref));
CHECK(arr[0].as<AllocateNode>()->body.as<EvaluateNode>()->value.same_as(x));
CHECK(bref.as<EvaluateNode>()->value.as<AddNode>());
}
{
Array<Stmt> arr{fmakealloc()};
// Case 2: the array is also referenced by `arr2`, so mutating it triggers a copy.
Array<Stmt> arr2 = arr;
auto* arrptr = arr.get();
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr.get() != arrptr);
CHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x));
// The shared copy held by `arr2` is left unmodified.
CHECK(!arr2[0].as<AllocateNode>()->extents[1].same_as(x));
// Mutating again changes nothing (the Add is already gone), so no copy is
// made and the array stays shared with `arr2`.
arr2 = arr;
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr2.get() == arr.get());
}
{
// Case 3: the Add inside the else branch of an IfThenElse is rewritten.
Array<Stmt> arr{fmakeif()};
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr[0].as<IfThenElseNode>()->else_case.as<EvaluateNode>()->value.same_as(x));
// Mutating again changes nothing, so the array stays shared with `arr2`.
auto arr2 = arr;
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr2.get() == arr.get());
}
{
// Case 4: call arguments are mutated (`x + 1` -> `x`).
auto body =
Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1}));
auto res = v(std::move(body));
CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[1].same_as(x));
}
{
// Case 5: nested SeqStmts with no outside references are flattened, and
// unique sub-nodes are reused in place.
Stmt body = fmakealloc();
Stmt body2 = Evaluate(1);
auto* ref2 = body2.get();
auto* extentptr = body.as<AllocateNode>()->extents.get();
// construct a recursive SeqStmt.
body = SeqStmt({body});
body = SeqStmt({body, body2});
body = SeqStmt({body, body2});
body = v(std::move(body));
// The nested SeqStmts are flattened into [Allocate, body2, body2].
CHECK(body.as<SeqStmtNode>()->size() == 3);
// The Allocate's extents array was unique, so it was updated in place.
CHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() == extentptr);
// Unchanged statements are kept as the very same node.
CHECK(body.as<SeqStmtNode>()->seq[1].get() == ref2);
}
{
// Case 6: in-place copy-on-write is impossible because `bref` keeps a
// reference to the inner SeqStmt that contains the Allocate.
Stmt body = fmakealloc();
Stmt body2 = Evaluate(1);
auto* extentptr = body.as<AllocateNode>()->extents.get();
// construct a recursive SeqStmt.
body = SeqStmt({body});
auto bref = body;
body = SeqStmt({body, body2});
body = v(std::move(body));
// The Allocate is reachable only through the shared inner SeqStmt, so its
// extents array is copied instead of being updated in place.
CHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() != extentptr);
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}