| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file ir.cc |
| */ |
| #include <tvm/base.h> |
| #include <tvm/expr.h> |
| #include <tvm/ir.h> |
| #include <tvm/ir_pass.h> |
| #include <ir/IR.h> |
| #include <ir/IRPrinter.h> |
| #include <memory> |
| #include "../pass/ir_util.h" |
| |
| namespace HalideIR { |
| namespace Internal { |
| |
| using tvm::ir::CommReducerNode; |
| using tvm::ir::Reduce; |
| using tvm::ir::AttrStmt; |
| |
| template<> |
| void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const { |
| LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor"; |
| } |
| |
| TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) |
| .set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) { |
| p->stream << "reduce(combiner=" |
| << op->combiner; |
| p->stream << ", source=" << op->source; |
| p->stream << ", axis=" << op->axis; |
| p->stream << ", where=" << op->condition; |
| p->stream << ", value_index=" << op->value_index; |
| p->stream << ")"; |
| }); |
| |
| TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) |
| .set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) { |
| p->stream << "comm_reducer(result=" << op->result |
| << ", lhs=" << op->lhs |
| << ", rhs=" << op->rhs |
| << ", identity_element=" << op->identity_element |
| << ")"; |
| }); |
| } // namespace Internal |
| } // namespace HalideIR |
| |
| namespace tvm { |
| namespace ir { |
| |
| CommReducer CommReducerNode::make(Array<Var> lhs, |
| Array<Var> rhs, |
| Array<Expr> result, |
| Array<Expr> identity_element) { |
| auto node = make_node<CommReducerNode>(); |
| node->lhs = lhs; |
| node->rhs = rhs; |
| node->result = result; |
| node->identity_element = identity_element; |
| return CommReducer(node); |
| } |
| |
| Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const { |
| CHECK_EQ(a.size(), b.size()); |
| CHECK_EQ(lhs.size(), a.size()); |
| CHECK_EQ(rhs.size(), b.size()); |
| Map<Var, Expr> value_map; |
| for (size_t i = 0; i < a.size(); ++i) { |
| value_map.Set(lhs[i], a[i]); |
| value_map.Set(rhs[i], b[i]); |
| } |
| return UpdateArray(result, [&value_map] (const Expr& e) { |
| return Substitute(e, value_map); |
| }); |
| } |
| |
| Expr Reduce::make(CommReducer combiner, Array<Expr> source, |
| Array<IterVar> axis, Expr condition, int value_index) { |
| for (size_t i = 0; i < axis.size(); ++i) { |
| CHECK_EQ(axis[i]->iter_type, kCommReduce) |
| << "Can only take axis created by reduce_axis"; |
| } |
| if (!condition.defined()) { |
| condition = const_true(); |
| } |
| auto n = make_node<Reduce>(); |
| CHECK(source.defined()); |
| for (size_t i = 0; i < axis.size(); ++i) { |
| CHECK(axis[i].defined()); |
| } |
| n->type = source[value_index].type(); |
| n->combiner = std::move(combiner); |
| n->source = std::move(source); |
| n->axis = std::move(axis); |
| n->condition = condition; |
| n->value_index = value_index; |
| return Expr(n); |
| } |
| |
| TVM_REGISTER_NODE_TYPE(CommReducerNode); |
| TVM_REGISTER_NODE_TYPE(Reduce); |
| TVM_REGISTER_NODE_TYPE(AttrStmt); |
| |
| TVM_REGISTER_NODE_TYPE(FloatImm); |
| TVM_REGISTER_NODE_TYPE(IntImm); |
| TVM_REGISTER_NODE_TYPE(UIntImm); |
| TVM_REGISTER_NODE_TYPE(StringImm); |
| TVM_REGISTER_NODE_TYPE(Cast); |
| TVM_REGISTER_NODE_TYPE(Variable); |
| TVM_REGISTER_NODE_TYPE(Add); |
| TVM_REGISTER_NODE_TYPE(Sub); |
| TVM_REGISTER_NODE_TYPE(Mul); |
| TVM_REGISTER_NODE_TYPE(Div); |
| TVM_REGISTER_NODE_TYPE(Mod); |
| TVM_REGISTER_NODE_TYPE(Min); |
| TVM_REGISTER_NODE_TYPE(Max); |
| TVM_REGISTER_NODE_TYPE(EQ); |
| TVM_REGISTER_NODE_TYPE(NE); |
| TVM_REGISTER_NODE_TYPE(LT); |
| TVM_REGISTER_NODE_TYPE(LE); |
| TVM_REGISTER_NODE_TYPE(GT); |
| TVM_REGISTER_NODE_TYPE(GE); |
| TVM_REGISTER_NODE_TYPE(And); |
| TVM_REGISTER_NODE_TYPE(Or); |
| TVM_REGISTER_NODE_TYPE(Not); |
| TVM_REGISTER_NODE_TYPE(Select); |
| TVM_REGISTER_NODE_TYPE(Load); |
| TVM_REGISTER_NODE_TYPE(Ramp); |
| TVM_REGISTER_NODE_TYPE(Broadcast); |
| TVM_REGISTER_NODE_TYPE(Call); |
| TVM_REGISTER_NODE_TYPE(Let); |
| TVM_REGISTER_NODE_TYPE(LetStmt); |
| TVM_REGISTER_NODE_TYPE(AssertStmt); |
| TVM_REGISTER_NODE_TYPE(ProducerConsumer); |
| TVM_REGISTER_NODE_TYPE(For); |
| TVM_REGISTER_NODE_TYPE(Store); |
| TVM_REGISTER_NODE_TYPE(Provide); |
| TVM_REGISTER_NODE_TYPE(Allocate); |
| TVM_REGISTER_NODE_TYPE(Free); |
| TVM_REGISTER_NODE_TYPE(Realize); |
| TVM_REGISTER_NODE_TYPE(Block); |
| TVM_REGISTER_NODE_TYPE(IfThenElse); |
| TVM_REGISTER_NODE_TYPE(Evaluate); |
| |
| } // namespace ir |
| } // namespace tvm |