blob: e82c4f554be0ff1c8c722c6cb90459e68ef72b4b [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file ir_mutator.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/packed_func_ext.h>
#include "ir_util.h"
namespace tvm {
namespace ir {
/*!
 * \brief Mutator that runs user supplied callbacks around the default
 *  IRMutator traversal.
 *
 *  For every visited node whose type index passes the filter, f_preorder is
 *  called first; a defined result replaces the node and stops the descent.
 *  Otherwise the children are mutated by IRMutator and f_postorder gets a
 *  chance to replace the rebuilt node.
 *
 *  The callbacks and the filter set are held by reference, not copied, so
 *  the constructor arguments must outlive the transformer.
 */
class IRTransformer final : public IRMutator {
 public:
  /*!
   * \param f_preorder Callback invoked before children are visited.
   * \param f_postorder Callback invoked after children are visited.
   * \param only_enable Type indices the callbacks apply to;
   *        an empty set means the callbacks apply to every node.
   */
  IRTransformer(const runtime::PackedFunc& f_preorder,
                const runtime::PackedFunc& f_postorder,
                const std::unordered_set<uint32_t>& only_enable)
      : f_preorder_(f_preorder),
        f_postorder_(f_postorder),
        only_enable_(only_enable) {}

  Stmt Mutate(Stmt stmt) final { return MutateInternal<Stmt>(stmt); }

  Expr Mutate(Expr expr) final { return MutateInternal<Expr>(expr); }

 private:
  /*! \brief Shared implementation for the Stmt and Expr entry points. */
  template <typename T>
  T MutateInternal(T node) {
    // A non-empty filter that does not list this node type: skip both
    // callbacks but keep walking, since descendants may still match.
    const bool filtered_out =
        !only_enable_.empty() &&
        only_enable_.count(node->type_index()) == 0;
    if (filtered_out) {
      return IRMutator::Mutate(node);
    }
    if (f_preorder_ != nullptr) {
      T replaced = f_preorder_(node);
      // A defined pre-order result short-circuits the traversal of this subtree.
      if (replaced.defined()) return replaced;
    }
    node = IRMutator::Mutate(node);
    if (f_postorder_ != nullptr) {
      T replaced = f_postorder_(node);
      if (replaced.defined()) return replaced;
    }
    return node;
  }

  // User callbacks (borrowed references).
  const runtime::PackedFunc& f_preorder_;
  const runtime::PackedFunc& f_postorder_;
  // Type indices for which the callbacks are enabled (borrowed reference).
  const std::unordered_set<uint32_t>& only_enable_;
};
/*!
 * \brief Transform a statement with pre-order / post-order callbacks.
 * \param ir_node The statement to transform.
 * \param f_preorder Callback run on a node before its children are visited.
 * \param f_postorder Callback run on a node after its children are visited.
 * \param only_enable Type keys (as StringImm) the callbacks are limited to;
 *        when empty the callbacks run on every node.
 * \return The transformed statement.
 */
Stmt IRTransform(const Stmt& ir_node,
                 const runtime::PackedFunc& f_preorder,
                 const runtime::PackedFunc& f_postorder,
                 const Array<Expr>& only_enable) {
  // Translate the textual type keys into runtime type indices.
  std::unordered_set<uint32_t> enabled_indices;
  for (const Expr& key : only_enable) {
    // NOTE(review): as<StringImm>() is dereferenced without a null check, so
    // every entry is assumed to be a StringImm -- confirm against callers.
    const auto* type_key = key.as<StringImm>();
    enabled_indices.insert(Node::TypeKey2Index(type_key->value.c_str()));
  }
  IRTransformer transformer(f_preorder, f_postorder, enabled_indices);
  return transformer.Mutate(ir_node);
}
/*! \brief Access the function-local static dispatch table used for expressions. */
IRMutator::FMutateExpr& IRMutator::vtable_expr() {  // NOLINT(*)
  static FMutateExpr expr_table;
  return expr_table;
}
/*! \brief Access the function-local static dispatch table used for statements. */
IRMutator::FMutateStmt& IRMutator::vtable_stmt() {  // NOLINT(*)
  static FMutateStmt stmt_table;
  return stmt_table;
}
/*!
 * \brief Run the mutator over every expression of an array.
 * \param arr The expressions to mutate.
 * \param m The mutator applied to each element.
 * \return Whatever UpdateArray produces for the per-element results.
 *  NOTE(review): callers compare the result with same_as(), so UpdateArray
 *  presumably hands back the input array when no element changed -- confirm
 *  in ir_util.h.
 */
inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
  auto mutate_one = [m](const Expr& elem) { return m->Mutate(elem); };
  return UpdateArray(arr, mutate_one);
}
/*!
 * \brief Mutate the domain (min / extent) of every IterVar in an array.
 *
 *  An IterVar whose min and extent both come back unchanged is reused as is;
 *  a new IterVar (same var, iter_type and thread_tag) is only built for
 *  entries whose domain actually changed. This avoids allocating throw-away
 *  nodes in the common nothing-changed case and keeps the identity of
 *  untouched axes stable.
 *
 * \param rdom The iteration variables to process.
 * \param m The mutator applied to each min / extent expression.
 * \return rdom itself when no domain changed, otherwise a new array.
 *
 *  NOTE(review): v->dom is dereferenced unconditionally, so every IterVar is
 *  assumed to carry a defined domain -- confirm against callers.
 */
inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
  std::vector<IterVar> new_dom(rdom.size());
  bool changed = false;
  for (size_t i = 0; i < rdom.size(); i++) {
    IterVar v = rdom[i];
    Range r = v->dom;
    Expr new_min = m->Mutate(r->min);
    Expr new_extent = m->Mutate(r->extent);
    if (r->min.same_as(new_min) && r->extent.same_as(new_extent)) {
      // Domain untouched: keep the original IterVar instead of cloning it.
      new_dom[i] = v;
    } else {
      changed = true;
      new_dom[i] = IterVarNode::make(
          Range::make_by_min_extent(new_min, new_extent),
          v->var, v->iter_type, v->thread_tag);
    }
  }
  if (!changed) {
    return rdom;
  }
  return Array<IterVar>(new_dom);
}
// Mutate Stmt
// Helper used in the statement dispatch table registration: registers, for
// node type OP, a lambda that forwards to the mutator's Mutate_(const OP*,
// const Stmt&) overload.
#define DISPATCH_TO_MUTATE_STMT(OP) \
set_dispatch<OP>([](const OP* op, const Stmt& s, IRMutator* m) { \
return m->Mutate_(op, s); \
})
/*!
 * \brief Default mutation of AttrStmt: mutate the value, then the body.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when neither child changed, otherwise a new AttrStmt that keeps
 *  the original node and attr_key.
 */
Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
  Expr new_value = Mutate(op->value);
  Stmt new_body = Mutate(op->body);
  const bool untouched =
      new_value.same_as(op->value) && new_body.same_as(op->body);
  if (untouched) {
    return s;
  }
  return AttrStmt::make(op->node, op->attr_key, new_value, new_body);
}
/*!
 * \brief Default mutation of LetStmt: mutate the bound value, then the body.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when nothing changed, otherwise a new LetStmt binding the same var.
 */
Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
  Expr bound_value = Mutate(op->value);
  Stmt inner = Mutate(op->body);
  if (!bound_value.same_as(op->value) || !inner.same_as(op->body)) {
    return LetStmt::make(op->var, bound_value, inner);
  }
  return s;
}
/*!
 * \brief Default mutation of For: mutate min, extent and body, in that order.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when nothing changed, otherwise a new For with the same
 *  loop_var, for_type and device_api.
 */
Stmt IRMutator::Mutate_(const For *op, const Stmt& s) {
  Expr loop_min = Mutate(op->min);
  Expr loop_extent = Mutate(op->extent);
  Stmt loop_body = Mutate(op->body);
  const bool untouched = loop_min.same_as(op->min) &&
                         loop_extent.same_as(op->extent) &&
                         loop_body.same_as(op->body);
  if (untouched) {
    return s;
  }
  return For::make(op->loop_var, loop_min, loop_extent,
                   op->for_type, op->device_api, loop_body);
}
/*!
 * \brief Default mutation of Allocate.
 *
 *  Children are visited in this order: each extent, the body, the condition
 *  and finally new_expr (only when it is defined).
 *
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when no child changed, otherwise a new Allocate that keeps the
 *  original buffer_var, type and free_function.
 */
Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
  const size_t num_extents = op->extents.size();
  std::vector<Expr> extents;
  extents.reserve(num_extents);
  bool extents_same = true;
  for (size_t i = 0; i < num_extents; ++i) {
    extents.push_back(this->Mutate(op->extents[i]));
    if (!extents.back().same_as(op->extents[i])) {
      extents_same = false;
    }
  }

  Stmt new_body = this->Mutate(op->body);
  Expr new_condition = this->Mutate(op->condition);
  // new_expr is optional; an undefined one stays undefined and therefore
  // still compares same_as the original.
  Expr custom_new;
  if (op->new_expr.defined()) {
    custom_new = this->Mutate(op->new_expr);
  }

  const bool untouched = extents_same &&
                         new_body.same_as(op->body) &&
                         new_condition.same_as(op->condition) &&
                         custom_new.same_as(op->new_expr);
  if (untouched) {
    return s;
  }
  return Allocate::make(op->buffer_var, op->type,
                        extents, new_condition, new_body,
                        custom_new, op->free_function);
}
/*!
 * \brief Default mutation of IfThenElse: mutate the condition, the then
 *  branch and, when present, the else branch.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when nothing changed, otherwise a new IfThenElse.
 */
Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
  Expr cond = Mutate(op->condition);
  Stmt then_branch = Mutate(op->then_case);
  // The else branch is optional; leave it undefined when the original is.
  Stmt else_branch;
  if (op->else_case.defined()) {
    else_branch = Mutate(op->else_case);
  }
  const bool modified = !cond.same_as(op->condition) ||
                        !then_branch.same_as(op->then_case) ||
                        !else_branch.same_as(op->else_case);
  if (modified) {
    return IfThenElse::make(cond, then_branch, else_branch);
  }
  return s;
}
/*!
 * \brief Default mutation of Store: mutate value, index and predicate,
 *  in that order.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when nothing changed, otherwise a new Store to the same buffer_var.
 */
Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
  Expr stored_value = Mutate(op->value);
  Expr store_index = Mutate(op->index);
  Expr store_pred = Mutate(op->predicate);
  const bool untouched = stored_value.same_as(op->value) &&
                         store_index.same_as(op->index) &&
                         store_pred.same_as(op->predicate);
  if (untouched) {
    return s;
  }
  return Store::make(op->buffer_var, stored_value, store_index, store_pred);
}
/*!
 * \brief Default mutation of Provide: mutate the argument list, then the value.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when nothing changed, otherwise a new Provide with the same
 *  func and value_index.
 */
Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
  Array<Expr> args = MutateArray(op->args, this);
  Expr value = Mutate(op->value);
  if (!op->args.same_as(args) || !op->value.same_as(value)) {
    return Provide::make(op->func, op->value_index, value, args);
  }
  return s;
}
/*!
 * \brief Default mutation of Realize.
 *
 *  Visits the min and extent of every bound, then the body, then the
 *  condition.
 *
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when nothing changed, otherwise a new Realize that keeps the
 *  original func, value_index and type.
 */
Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
  HalideIR::Internal::Region bounds;
  bool bounds_same = true;
  for (size_t i = 0; i < op->bounds.size(); ++i) {
    Expr min_before = op->bounds[i]->min;
    Expr extent_before = op->bounds[i]->extent;
    Expr min_after = this->Mutate(min_before);
    Expr extent_after = this->Mutate(extent_before);
    if (!(min_after.same_as(min_before) &&
          extent_after.same_as(extent_before))) {
      bounds_same = false;
    }
    // The range is rebuilt unconditionally; it is only used if something changed.
    bounds.push_back(Range::make_by_min_extent(min_after, extent_after));
  }

  Stmt new_body = this->Mutate(op->body);
  Expr new_condition = this->Mutate(op->condition);

  const bool untouched = bounds_same &&
                         new_body.same_as(op->body) &&
                         new_condition.same_as(op->condition);
  if (untouched) {
    return s;
  }
  return Realize::make(op->func, op->value_index, op->type,
                       bounds, new_condition, new_body);
}
/*!
 * \brief Default mutation of Prefetch: mutate the min and extent of every bound.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when no bound changed, otherwise a new Prefetch that keeps the
 *  original func, value_index and type.
 */
Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
  HalideIR::Internal::Region bounds;
  bool any_changed = false;
  for (size_t i = 0; i < op->bounds.size(); ++i) {
    Expr min_before = op->bounds[i]->min;
    Expr extent_before = op->bounds[i]->extent;
    Expr min_after = this->Mutate(min_before);
    Expr extent_after = this->Mutate(extent_before);
    any_changed = any_changed ||
                  !min_after.same_as(min_before) ||
                  !extent_after.same_as(extent_before);
    bounds.push_back(Range::make_by_min_extent(min_after, extent_after));
  }
  if (any_changed) {
    return Prefetch::make(op->func, op->value_index, op->type, bounds);
  }
  return s;
}
/*!
 * \brief Default mutation of Block: mutate `first`, then `rest`.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when neither part changed, otherwise a new Block.
 */
Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
  Stmt head = Mutate(op->first);
  Stmt tail = Mutate(op->rest);
  if (!head.same_as(op->first) || !tail.same_as(op->rest)) {
    return Block::make(head, tail);
  }
  return s;
}
/*!
 * \brief Default mutation of AssertStmt: mutate condition, message and body,
 *  in that order.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when nothing changed, otherwise a new AssertStmt.
 */
Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) {
  Expr cond = Mutate(op->condition);
  Expr msg = Mutate(op->message);
  Stmt guarded = Mutate(op->body);
  const bool untouched = cond.same_as(op->condition) &&
                         msg.same_as(op->message) &&
                         guarded.same_as(op->body);
  if (untouched) {
    return s;
  }
  return AssertStmt::make(cond, msg, guarded);
}
/*!
 * \brief Default mutation of ProducerConsumer: only the body is mutated.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when the body is unchanged, otherwise a new ProducerConsumer
 *  with the same func and is_producer flag.
 */
Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) {
  Stmt inner = Mutate(op->body);
  if (!inner.same_as(op->body)) {
    return ProducerConsumer::make(op->func, op->is_producer, inner);
  }
  return s;
}
/*!
 * \brief Default mutation of Evaluate: mutate the wrapped expression.
 * \param op The node being visited.
 * \param s The statement handle referring to op.
 * \return s when the expression is unchanged, otherwise a new Evaluate.
 */
Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
  Expr evaluated = Mutate(op->value);
  if (!evaluated.same_as(op->value)) {
    return Evaluate::make(evaluated);
  }
  return s;
}
/*!
 * \brief Default mutation of Free: no child is visited.
 * \param op The node being visited (unused).
 * \param s The statement handle referring to op.
 * \return s unchanged.
 */
Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { return s; }
// Populate IRMutator::vtable_stmt(): each entry routes one statement node
// type to the matching IRMutator::Mutate_ overload through
// DISPATCH_TO_MUTATE_STMT.
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(Free)
.DISPATCH_TO_MUTATE_STMT(AssertStmt)
.DISPATCH_TO_MUTATE_STMT(ProducerConsumer)
.DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Block)
.DISPATCH_TO_MUTATE_STMT(Evaluate)
.DISPATCH_TO_MUTATE_STMT(Prefetch);
// Mutate Expr
// Helper used in the expression dispatch table registration: registers, for
// node type OP, a lambda that forwards to the mutator's Mutate_(const OP*,
// const Expr&) overload.
#define DISPATCH_TO_MUTATE_EXPR(OP) \
set_dispatch<OP>([](const OP* op, const Expr& e, IRMutator* m) { \
return m->Mutate_(op, e); \
})
/*!
 * \brief Default mutation of Variable: nothing is visited.
 * \param op The node being visited (unused).
 * \param e The expression handle referring to op.
 * \return e unchanged.
 */
Expr IRMutator::Mutate_(const Variable *op, const Expr& e) { return e; }
/*!
 * \brief Default mutation of Load: mutate the index, then the predicate.
 * \param op The node being visited.
 * \param e The expression handle referring to op.
 * \return e when nothing changed, otherwise a new Load with the same type
 *  and buffer_var.
 */
Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
  Expr load_index = Mutate(op->index);
  Expr load_pred = Mutate(op->predicate);
  if (!load_index.same_as(op->index) || !load_pred.same_as(op->predicate)) {
    return Load::make(op->type, op->buffer_var, load_index, load_pred);
  }
  return e;
}
/*!
 * \brief Default mutation of Let: mutate the bound value, then the body.
 * \param op The node being visited.
 * \param e The expression handle referring to op.
 * \return e when nothing changed, otherwise a new Let binding the same var.
 */
Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
  Expr bound_value = Mutate(op->value);
  Expr inner = Mutate(op->body);
  const bool untouched =
      bound_value.same_as(op->value) && inner.same_as(op->body);
  if (untouched) {
    return e;
  }
  return Let::make(op->var, bound_value, inner);
}
/*!
 * \brief Default mutation of Call: mutate the argument list.
 * \param op The node being visited.
 * \param e The expression handle referring to op.
 * \return e when the arguments are unchanged, otherwise a new Call that
 *  keeps type, name, call_type, func and value_index.
 */
Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
  Array<Expr> args = MutateArray(op->args, this);
  if (!op->args.same_as(args)) {
    return Call::make(op->type, op->name, args, op->call_type,
                      op->func, op->value_index);
  }
  return e;
}
/*!
 * \brief Define IRMutator::Mutate_ for a binary node type OP.
 *
 *  The generated function mutates operand a, then operand b, and returns the
 *  incoming expression when both are unchanged; otherwise it builds a new
 *  node through OP::make.
 */
#define DEFINE_BIOP_EXPR_MUTATE_(OP)                              \
  Expr IRMutator::Mutate_(const OP* op, const Expr& e) {         \
    Expr lhs = this->Mutate(op->a);                              \
    Expr rhs = this->Mutate(op->b);                              \
    if (!lhs.same_as(op->a) || !rhs.same_as(op->b)) {            \
      return OP::make(lhs, rhs);                                 \
    }                                                            \
    return e;                                                    \
  }

// Arithmetic.
DEFINE_BIOP_EXPR_MUTATE_(Add)
DEFINE_BIOP_EXPR_MUTATE_(Sub)
DEFINE_BIOP_EXPR_MUTATE_(Mul)
DEFINE_BIOP_EXPR_MUTATE_(Div)
DEFINE_BIOP_EXPR_MUTATE_(Mod)
DEFINE_BIOP_EXPR_MUTATE_(Min)
DEFINE_BIOP_EXPR_MUTATE_(Max)
// Comparison.
DEFINE_BIOP_EXPR_MUTATE_(EQ)
DEFINE_BIOP_EXPR_MUTATE_(NE)
DEFINE_BIOP_EXPR_MUTATE_(LT)
DEFINE_BIOP_EXPR_MUTATE_(LE)
DEFINE_BIOP_EXPR_MUTATE_(GT)
DEFINE_BIOP_EXPR_MUTATE_(GE)
// Logical.
DEFINE_BIOP_EXPR_MUTATE_(And)
DEFINE_BIOP_EXPR_MUTATE_(Or)
/*!
 * \brief Default mutation of Reduce: mutate the axis domains, the source
 *  expressions and the condition, in that order.
 * \param op The node being visited.
 * \param e The expression handle referring to op.
 * \return e when nothing changed, otherwise a new Reduce with the same
 *  combiner and value_index.
 */
Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
  Array<IterVar> axis = MutateIterVarArr(op->axis, this);
  Array<Expr> source = MutateArray(op->source, this);
  Expr cond = Mutate(op->condition);
  const bool untouched = op->axis.same_as(axis) &&
                         op->source.same_as(source) &&
                         op->condition.same_as(cond);
  if (untouched) {
    return e;
  }
  return Reduce::make(op->combiner, source, axis, cond, op->value_index);
}
/*!
 * \brief Default mutation of Cast: mutate the operand.
 * \param op The node being visited.
 * \param e The expression handle referring to op.
 * \return e when the operand is unchanged, otherwise a new Cast to the same type.
 */
Expr IRMutator::Mutate_(const Cast *op, const Expr& e) {
  Expr operand = Mutate(op->value);
  if (!operand.same_as(op->value)) {
    return Cast::make(op->type, operand);
  }
  return e;
}
/*!
 * \brief Default mutation of Not: mutate the operand.
 * \param op The node being visited.
 * \param e The expression handle referring to op.
 * \return e when the operand is unchanged, otherwise a new Not.
 */
Expr IRMutator::Mutate_(const Not *op, const Expr& e) {
  Expr operand = Mutate(op->a);
  if (!operand.same_as(op->a)) {
    return Not::make(operand);
  }
  return e;
}
/*!
 * \brief Default mutation of Select: mutate the condition, the true value
 *  and the false value, in that order.
 * \param op The node being visited.
 * \param e The expression handle referring to op.
 * \return e when nothing changed, otherwise a new Select.
 */
Expr IRMutator::Mutate_(const Select *op, const Expr& e) {
  Expr selector = Mutate(op->condition);
  Expr on_true = Mutate(op->true_value);
  Expr on_false = Mutate(op->false_value);
  const bool modified = !selector.same_as(op->condition) ||
                        !on_true.same_as(op->true_value) ||
                        !on_false.same_as(op->false_value);
  if (modified) {
    return Select::make(selector, on_true, on_false);
  }
  return e;
}
/*!
 * \brief Default mutation of Ramp: mutate the base, then the stride.
 * \param op The node being visited.
 * \param e The expression handle referring to op.
 * \return e when nothing changed, otherwise a new Ramp with the same lanes.
 */
Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) {
  Expr ramp_base = Mutate(op->base);
  Expr ramp_stride = Mutate(op->stride);
  if (!ramp_base.same_as(op->base) || !ramp_stride.same_as(op->stride)) {
    return Ramp::make(ramp_base, ramp_stride, op->lanes);
  }
  return e;
}
/*!
 * \brief Default mutation of Broadcast: mutate the scalar value.
 * \param op The node being visited.
 * \param e The expression handle referring to op.
 * \return e when the value is unchanged, otherwise a new Broadcast with the
 *  same lanes.
 */
Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
  Expr scalar = Mutate(op->value);
  if (!scalar.same_as(op->value)) {
    return Broadcast::make(scalar, op->lanes);
  }
  return e;
}
/*!
 * \brief Default mutation of Shuffle: mutate the input vectors.
 *  The indices are forwarded as they are and are not visited.
 * \param op The node being visited.
 * \param e The expression handle referring to op.
 * \return e when the vectors are unchanged, otherwise a new Shuffle.
 */
Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) {
  Array<Expr> vectors = MutateArray(op->vectors, this);
  if (!vectors.same_as(op->vectors)) {
    return Shuffle::make(vectors, op->indices);
  }
  return e;
}
/*!
 * \brief Define IRMutator::Mutate_ for a node type OP whose default mutation
 *  visits nothing and simply hands the incoming expression back.
 */
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP)                          \
  Expr IRMutator::Mutate_(const OP *op, const Expr& e) { return e; }

// Immediate values.
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
// Populate IRMutator::vtable_expr(): each entry routes one expression node
// type to the matching IRMutator::Mutate_ overload through
// DISPATCH_TO_MUTATE_EXPR.
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Variable)
.DISPATCH_TO_MUTATE_EXPR(Load)
.DISPATCH_TO_MUTATE_EXPR(Let)
.DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Add)
.DISPATCH_TO_MUTATE_EXPR(Sub)
.DISPATCH_TO_MUTATE_EXPR(Mul)
.DISPATCH_TO_MUTATE_EXPR(Div)
.DISPATCH_TO_MUTATE_EXPR(Mod)
.DISPATCH_TO_MUTATE_EXPR(Min)
.DISPATCH_TO_MUTATE_EXPR(Max)
.DISPATCH_TO_MUTATE_EXPR(EQ)
.DISPATCH_TO_MUTATE_EXPR(NE)
.DISPATCH_TO_MUTATE_EXPR(LT)
.DISPATCH_TO_MUTATE_EXPR(LE)
.DISPATCH_TO_MUTATE_EXPR(GT)
.DISPATCH_TO_MUTATE_EXPR(GE)
.DISPATCH_TO_MUTATE_EXPR(And)
.DISPATCH_TO_MUTATE_EXPR(Or)
.DISPATCH_TO_MUTATE_EXPR(Reduce)
.DISPATCH_TO_MUTATE_EXPR(Cast)
.DISPATCH_TO_MUTATE_EXPR(Not)
.DISPATCH_TO_MUTATE_EXPR(Select)
.DISPATCH_TO_MUTATE_EXPR(Ramp)
.DISPATCH_TO_MUTATE_EXPR(Broadcast)
.DISPATCH_TO_MUTATE_EXPR(IntImm)
.DISPATCH_TO_MUTATE_EXPR(UIntImm)
.DISPATCH_TO_MUTATE_EXPR(FloatImm)
.DISPATCH_TO_MUTATE_EXPR(StringImm)
.DISPATCH_TO_MUTATE_EXPR(Shuffle);
} // namespace ir
} // namespace tvm