blob: 875258540584cb65849006db077f89ca5930529d [file] [log] [blame]
/*!
* Copyright (c) 2016 by Contributors
* \file ir.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <ir/IR.h>
#include <ir/IRPrinter.h>
#include <memory>
#include "../pass/ir_util.h"
namespace HalideIR {
namespace Internal {
using tvm::ir::CommReducerNode;
using tvm::ir::Reduce;
using tvm::ir::AttrStmt;
/*!
 * \brief Specialization of the legacy visitor entry point for Reduce.
 *
 * Both arguments are ignored: the body unconditionally raises a fatal
 * log message telling the caller to use an IRFunctor style visitor.
 */
template<>
void ExprNode<Reduce>::accept(IRVisitor*, const Expr&) const {
  LOG(FATAL)
      << "Reduce do not work with old Visitor, use IRFunctor style visitor";
}
// Printer for Reduce nodes. Emits
//   reduce(combiner=..., source=..., axis=..., where=..., value_index=...)
// to the printer's stream, with the fields in that order.
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *reduce, IRPrinter *printer) {
    auto& out = printer->stream;
    out << "reduce(combiner=" << reduce->combiner
        << ", source=" << reduce->source
        << ", axis=" << reduce->axis
        << ", where=" << reduce->condition
        << ", value_index=" << reduce->value_index
        << ")";
  });
// Printer for CommReducerNode. Emits
//   comm_reducer(result=..., lhs=..., rhs=..., identity_element=...)
// to the printer's stream, with the fields in that order.
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CommReducerNode>([](const CommReducerNode *reducer, IRPrinter *printer) {
    auto& out = printer->stream;
    out << "comm_reducer(result=" << reducer->result;
    out << ", lhs=" << reducer->lhs;
    out << ", rhs=" << reducer->rhs;
    out << ", identity_element=" << reducer->identity_element;
    out << ")";
  });
} // namespace Internal
} // namespace HalideIR
namespace tvm {
namespace ir {
/*!
 * \brief Build a CommReducer from its four component arrays.
 *
 * \param lhs Stored in the node's lhs field.
 * \param rhs Stored in the node's rhs field.
 * \param result Stored in the node's result field.
 * \param identity_element Stored in the node's identity_element field.
 * \return A CommReducer handle wrapping the freshly created node.
 */
CommReducer CommReducerNode::make(Array<Var> lhs,
                                  Array<Var> rhs,
                                  Array<Expr> result,
                                  Array<Expr> identity_element) {
  auto reducer = make_node<CommReducerNode>();
  // The arguments are by-value copies owned by this call, so hand them
  // straight to the node.
  reducer->identity_element = std::move(identity_element);
  reducer->result = std::move(result);
  reducer->rhs = std::move(rhs);
  reducer->lhs = std::move(lhs);
  return CommReducer(reducer);
}
/*!
 * \brief Apply the reducer to two operand lists.
 *
 * Binds lhs[i] to a[i] and rhs[i] to b[i], then runs Substitute with
 * those bindings over the entries of result via UpdateArray.
 *
 * \param a Values bound to the lhs variables; must match lhs in size.
 * \param b Values bound to the rhs variables; must match rhs in size.
 * \return The array produced by UpdateArray over result.
 */
Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
  // All four arrays must have the same length.
  CHECK_EQ(a.size(), b.size());
  CHECK_EQ(lhs.size(), a.size());
  CHECK_EQ(rhs.size(), b.size());

  const size_t num_operands = a.size();
  Map<Var, Expr> bindings;
  for (size_t idx = 0; idx != num_operands; ++idx) {
    bindings.Set(lhs[idx], a[idx]);
    bindings.Set(rhs[idx], b[idx]);
  }

  auto apply_bindings = [&bindings](const Expr& expr) {
    return Substitute(expr, bindings);
  };
  return UpdateArray(result, apply_bindings);
}
/*!
 * \brief Construct a Reduce expression node.
 *
 * \param combiner The reducer stored on the node.
 * \param source The source expressions; must be defined and must contain
 *        an entry at value_index.
 * \param axis The reduction axes; every entry must be defined and have
 *        iter_type equal to kCommReduce.
 * \param condition Predicate stored on the node; replaced by const_true()
 *        when it is undefined.
 * \param value_index Index into source. The type of source[value_index]
 *        becomes the type of the resulting expression.
 * \return The new Reduce node wrapped in an Expr.
 *
 * Any violated precondition triggers a fatal CHECK failure.
 */
Expr Reduce::make(CommReducer combiner, Array<Expr> source,
                  Array<IterVar> axis, Expr condition, int value_index) {
  for (size_t i = 0; i < axis.size(); ++i) {
    // Verify the handle is non-null *before* dereferencing it, so that a
    // null axis is reported by the CHECK instead of being dereferenced.
    CHECK(axis[i].defined());
    CHECK_EQ(axis[i]->iter_type, kCommReduce)
        << "Can only take axis created by reduce_axis";
  }

  // source is indexed below, so it must exist and value_index must be a
  // valid position in it.
  CHECK(source.defined());
  CHECK_GE(value_index, 0)
      << "value_index must be non-negative";
  CHECK_LT(static_cast<size_t>(value_index), source.size())
      << "value_index is out of range of source";

  if (!condition.defined()) {
    condition = const_true();
  }

  auto n = make_node<Reduce>();
  // Read the type before source is moved into the node.
  n->type = source[value_index].type();
  n->combiner = std::move(combiner);
  n->source = std::move(source);
  n->axis = std::move(axis);
  n->condition = std::move(condition);
  n->value_index = value_index;
  return Expr(n);
}
// One TVM_REGISTER_NODE_TYPE invocation per IR node class.
// NOTE(review): the macro is defined outside this file; presumably it
// registers each class with TVM's node type registry -- confirm against
// its definition before reordering or removing entries.

// Names that the top of this file imports from tvm::ir via
// using-declarations.
TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(AttrStmt);

// Remaining node classes; their declarations are not visible in this file.
TVM_REGISTER_NODE_TYPE(FloatImm);
TVM_REGISTER_NODE_TYPE(IntImm);
TVM_REGISTER_NODE_TYPE(UIntImm);
TVM_REGISTER_NODE_TYPE(StringImm);
TVM_REGISTER_NODE_TYPE(Cast);
TVM_REGISTER_NODE_TYPE(Variable);
TVM_REGISTER_NODE_TYPE(Add);
TVM_REGISTER_NODE_TYPE(Sub);
TVM_REGISTER_NODE_TYPE(Mul);
TVM_REGISTER_NODE_TYPE(Div);
TVM_REGISTER_NODE_TYPE(Mod);
TVM_REGISTER_NODE_TYPE(Min);
TVM_REGISTER_NODE_TYPE(Max);
TVM_REGISTER_NODE_TYPE(EQ);
TVM_REGISTER_NODE_TYPE(NE);
TVM_REGISTER_NODE_TYPE(LT);
TVM_REGISTER_NODE_TYPE(LE);
TVM_REGISTER_NODE_TYPE(GT);
TVM_REGISTER_NODE_TYPE(GE);
TVM_REGISTER_NODE_TYPE(And);
TVM_REGISTER_NODE_TYPE(Or);
TVM_REGISTER_NODE_TYPE(Not);
TVM_REGISTER_NODE_TYPE(Select);
TVM_REGISTER_NODE_TYPE(Load);
TVM_REGISTER_NODE_TYPE(Ramp);
TVM_REGISTER_NODE_TYPE(Broadcast);
TVM_REGISTER_NODE_TYPE(Call);
TVM_REGISTER_NODE_TYPE(Let);
TVM_REGISTER_NODE_TYPE(LetStmt);
TVM_REGISTER_NODE_TYPE(AssertStmt);
TVM_REGISTER_NODE_TYPE(ProducerConsumer);
TVM_REGISTER_NODE_TYPE(For);
TVM_REGISTER_NODE_TYPE(Store);
TVM_REGISTER_NODE_TYPE(Provide);
TVM_REGISTER_NODE_TYPE(Allocate);
TVM_REGISTER_NODE_TYPE(Free);
TVM_REGISTER_NODE_TYPE(Realize);
TVM_REGISTER_NODE_TYPE(Block);
TVM_REGISTER_NODE_TYPE(IfThenElse);
TVM_REGISTER_NODE_TYPE(Evaluate);
} // namespace ir
} // namespace tvm