| /*! |
| * Copyright (c) 2016 by Contributors |
| * \file expr.cc |
| */ |
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <ir/IRPrinter.h>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
| |
| namespace tvm { |
| |
| using HalideIR::IR::RangeNode; |
| |
| Range::Range(Expr begin, Expr end) |
| : Range(make_node<RangeNode>( |
| begin, |
| is_zero(begin) ? end : (end - begin))) { |
| } |
| |
| Range Range::make_by_min_extent(Expr min, Expr extent) { |
| return Range(make_node<HalideIR::IR::RangeNode>(min, extent)); |
| } |
| |
| IterVar IterVarNode::make(Range dom, Var var, |
| IterVarType t, std::string thread_tag) { |
| NodePtr<IterVarNode> n = make_node<IterVarNode>(); |
| n->dom = dom; |
| n->var = var; |
| n->iter_type = t; |
| n->thread_tag = thread_tag; |
| return IterVar(n); |
| } |
| |
| IterVar thread_axis(Range dom, std::string tag) { |
| return IterVarNode::make( |
| dom, Var(tag), kThreadIndex, tag); |
| } |
| |
| IterVar reduce_axis(Range dom, std::string name) { |
| return IterVarNode::make( |
| dom, Var(name), kCommReduce); |
| } |
| |
| std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*) |
| IRPrinter(os).print(n); |
| return os; |
| } |
| |
| void Dump(const NodeRef& n) { |
| std::cerr << n << "\n"; |
| } |
| |
| Var var(const std::string& name_hint, Type t) { |
| return Var(name_hint, t); |
| } |
| |
| TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) |
| .set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) { |
| p->stream << "iter_var("; |
| if (op->var->name_hint.length() != 0) { |
| p->stream << op->var->name_hint << ", "; |
| } |
| if (op->dom.defined()) { |
| p->stream << op->dom; |
| } |
| if (op->thread_tag.length() != 0) { |
| p->stream << ", " << op->thread_tag; |
| } |
| p->stream << ")"; |
| }); |
| |
| TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) |
| .set_dispatch<RangeNode>([](const HalideIR::IR::RangeNode *op, IRPrinter *p) { |
| p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; |
| }); |
| |
| |
// Register these node types through TVM_REGISTER_NODE_TYPE. The macro
// expansions are namespace-scope registrations, so they run at static
// initialization time. NOTE(review): their relative order may matter if the
// registry assigns type indices at registration — confirm before reordering.
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_REGISTER_NODE_TYPE(IterVarNode);
| |
| } // namespace tvm |