| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \brief Hybrid computation rule. |
| * \file hybrid_op.cc |
| */ |
#include "hybrid_op.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "op_util.h"
| |
| namespace tvm { |
| namespace te { |
| using namespace tir; |
// HybridOpNode
// Pretty-printer: renders a hybrid op as "hybrid(<name>, <node pointer>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<HybridOpNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const HybridOpNode*>(node.get());
      p->stream << "hybrid(" << op->name << ", " << op << ")";
    });

// Register HybridOpNode with the TVM object system (reflection/FFI).
TVM_REGISTER_NODE_TYPE(HybridOpNode);
| |
// Number of output tensors this hybrid op produces.
int HybridOpNode::num_outputs() const { return static_cast<int>(outputs.size()); }
| |
// Root iteration variables: the loop vars gathered from the hybrid body (see GatherLoopVars).
Array<IterVar> HybridOpNode::root_iter_vars() const { return this->axis; }
| |
// Data type of the i-th output tensor.
DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; }
| |
// Shape of the i-th output tensor.
Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; }
| |
| HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, |
| Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) { |
| if (!attrs.defined()) { |
| attrs = Map<String, ObjectRef>(); |
| } |
| auto n = make_object<HybridOpNode>(); |
| n->name = std::move(name); |
| n->tag = std::move(tag); |
| n->attrs = std::move(attrs); |
| n->inputs = std::move(inputs); |
| n->outputs = std::move(outputs); |
| n->axis = te::GatherLoopVars(body); |
| n->body = std::move(body); |
| data_ = std::move(n); |
| } |
| |
// FFI entry point exposing HybridOp construction (used by the Python frontend).
TVM_REGISTER_GLOBAL("te.HybridOp")
    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
                       Array<Tensor> inputs, Array<Tensor> outputs,
                       Stmt body) { return HybridOp(name, tag, attrs, inputs, outputs, body); });
| |
| Array<Tensor> HybridOpNode::InputTensors() const { |
| // Because input tensors could be potentially inlined into hybrid scripts, |
| // we need to check if all input tensors are used in the body. |
| std::unordered_set<Tensor> orig_inputs; |
| for (auto t : inputs) { |
| orig_inputs.insert(t); |
| } |
| std::unordered_set<Tensor> visited; |
| Array<Tensor> curr_inputs; |
| tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) { |
| if (auto* pload = n.as<tir::ProducerLoadNode>()) { |
| Tensor t = Downcast<Tensor>(pload->producer); |
| if (orig_inputs.count(t) && !visited.count(t)) { |
| curr_inputs.push_back(t); |
| visited.insert(t); |
| } |
| } |
| }); |
| return curr_inputs; |
| } |
| |
| Operation HybridOpNode::ReplaceInputs(const Operation& self, |
| const std::unordered_map<Tensor, Tensor>& rmap) const { |
| CHECK_EQ(self.operator->(), this); |
| auto n = make_object<HybridOpNode>(*this); |
| n->body = te::ReplaceTensor(this->body, rmap); |
| for (size_t i = 0; i < n->inputs.size(); ++i) { |
| Tensor t = n->inputs[i]; |
| if (rmap.count(t)) { |
| n->inputs.Set(i, rmap.at(t)); |
| } |
| } |
| |
| if (body.same_as(n->body) && inputs.same_as(n->inputs)) { |
| return self; |
| } else { |
| return Operation(n); |
| } |
| } |
| |
| void HybridOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, |
| const std::unordered_map<const VarNode*, IntSet>& dom_map, |
| std::unordered_map<Tensor, TensorDom>* out_dom_map) const { |
| auto curr_inputs = InputTensors(); |
| for (Tensor t : curr_inputs) { |
| auto it = out_dom_map->find(t); |
| if (it == out_dom_map->end()) continue; |
| TensorDom& dom = it->second; |
| for (size_t i = 0; i < t->shape.size(); ++i) { |
| dom.data[i].emplace_back( |
| IntSet::FromRange(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i]))); |
| } |
| } |
| } |
| |
| void HybridOpNode::GatherBound(const Operation& self, |
| const std::unordered_map<Tensor, TensorDom>& tensor_dom, |
| std::unordered_map<IterVar, Range>* out_dom_map) const { |
| for (auto iter_var : axis) { |
| CHECK(!out_dom_map->count(iter_var)); |
| out_dom_map->operator[](iter_var) = iter_var->dom; |
| } |
| } |
| |
| Stmt HybridOpNode::BuildRealize(const Stage& stage, |
| const std::unordered_map<IterVar, Range>& realize_map, |
| const Stmt& body) const { |
| // TODO(@were): Add attribute inject here and remove it from hybrid parser. |
| CHECK_EQ(stage->op.get(), this); |
| Stmt realize_body = body; |
| for (int k = 0; k < num_outputs(); ++k) { |
| Tensor t = stage->op.output(k); |
| Region bounds; |
| for (size_t i = 0; i < t->shape.size(); ++i) { |
| bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); |
| } |
| realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); |
| } |
| return realize_body; |
| } |
| |
| Stmt HybridOpNode::BuildProvide(const Stage& stage, |
| const std::unordered_map<IterVar, Range>& dom_map, |
| bool debug_keep_trivial_loop) const { |
| CHECK_EQ(stage->op.operator->(), this); |
| Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); |
| std::unordered_map<Tensor, Tensor> rmap; |
| for (int i = 0; i < this->num_outputs(); ++i) { |
| rmap[outputs[i]] = stage->op.output(i); |
| } |
| auto n = make_object<HybridOpNode>(*this); |
| /* This is a story little bit complicated. |
| * The following two lines of codes replace output tensors' usage. |
| * This is the simplest way I (@were) can come up with to glue |
| * hybrid operation node to TVM op system. |
| * In hybrid script all the tensors, especially the output tensors, |
| * have their own names defined by the users. However, In TVM |
| * conventional ops: |
| * 1. Output tensors refer the corresponding op node so that the output |
| * tensors have the same names as the operation produces them. |
| * 2. Once OpNode is wrapped up by an Operation node, it is finalized. |
| * Later access will be from a const OpNode*. |
| * This is a chicken-egg paradox. It is impossible to put the output |
| * tensors into the function body without forming the op node. The |
| * function body is immutable after the node is formed. |
| * |
| * Finally, I decided to resolve this issue "lazily". During the |
| * pipeline of compilation, this stage is a very preliminary stage. |
| * Technically, it is before Phase 0. The actual tensors will be replaced |
| * here. |
| * Thus, the operation body is slightly different from the Phase 0 body. |
| * This is a major difference that HybridOpNode is NOT the same as |
| * ExternOpNode. |
| * */ |
| ret = te::ReplaceTensor(ret, rmap); |
| ret = te::ReplaceProvideTensor(ret, rmap); |
| |
| ret = te::ApplySchedule(stage, dom_map, ret); |
| return ret; |
| } |
| |
// Rewrites the loop nest in `stmt` so it matches the split/fuse relations
// recorded on `stage`. Each SplitNode turns one loop into an outer/inner pair
// (with a likely() guard for tail iterations); each FuseNode collapses an
// outer/inner pair into one fused loop. Every relation must find its loop,
// otherwise the CHECKs at the bottom fire.
Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     Stmt stmt) {
  // Mutator that splits the loop over `split->parent` into outer/inner loops.
  class LoopSpliter : public StmtExprMutator {
    PrimExpr factor;        // split factor (extent of the inner loop)
    const VarNode* parent;  // loop var of the loop being split
    IterVar inner, outer;   // new iter vars carrying the domains from dom_map

   public:
    bool splitted;  // set once the parent loop has been found and rewritten
    LoopSpliter(const SplitNode* split, const std::unordered_map<IterVar, Range>& dom_map)
        : factor(split->factor), splitted(false) {
      parent = split->parent->var.get();

      // Both new loops are required to start at zero.
      auto& inner_ = split->inner;
      CHECK(dom_map.count(inner_));
      auto& inner_dom = dom_map.find(inner_)->second;
      CHECK(is_const_int(inner_dom->min, 0));

      auto& outer_ = split->outer;
      CHECK(dom_map.count(outer_));
      auto& outer_dom = dom_map.find(outer_)->second;
      CHECK(is_const_int(outer_dom->min, 0));

      inner = IterVar(inner_dom, inner_->var, inner_->iter_type);
      outer = IterVar(outer_dom, outer_->var, outer_->iter_type);
    }

    Stmt VisitStmt_(const ForNode* op) final {
      if (op->loop_var.get() == parent) {
        // Substitute parent := inner + outer * factor, then guard against
        // running past the original extent (tail iterations).
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = inner + outer * factor;
        Stmt ret = tir::Substitute(op->body, rmap);
        PrimExpr cond = likely(outer * factor < (op->extent - inner));
        ret = IfThenElse(cond, ret);
        // Wrap in the new inner then outer loops, both zero-based.
        ret = For(inner->var, PrimExpr(0), inner->dom->extent,
                  IterVarTypeToForType(inner->iter_type), op->device_api, ret);
        ret = For(outer->var, PrimExpr(0), outer->dom->extent,
                  IterVarTypeToForType(outer->iter_type), op->device_api, ret);
        splitted = true;
        return ret;
      }
      return StmtExprMutator::VisitStmt_(op);
    }
  };

  // Mutator that fuses the loops over `fuse_->outer`/`fuse_->inner` into a
  // single loop over `fuse_->fused`.
  class LoopFuser : public StmtExprMutator {
    const IterVar& parent;  // the fused iter var
    const VarNode* inner;
    const VarNode* outer;
    bool under_outer;   // true while visiting statements inside the outer loop
    PrimExpr extent;    // accumulated extent of the loops absorbed so far

   public:
    bool fused;  // set once the inner loop has been rewritten
    explicit LoopFuser(const FuseNode* fuse_)
        : parent(fuse_->fused),
          inner(fuse_->inner->var.get()),
          outer(fuse_->outer->var.get()),
          under_outer(false),
          extent(0),
          fused(false) {}

    // TODO(@were): Handle imperfect loops
    Stmt VisitStmt_(const ForNode* op) final {
      if (op->loop_var.get() == inner) {
        CHECK(under_outer);
        // inner index = fused % inner_extent
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = indexmod(parent, op->extent);
        extent = op->extent;
        fused = true;
        return tir::Substitute(op->body, rmap);
      } else if (op->loop_var.get() == outer) {
        under_outer = true;
        Stmt body = this->VisitStmt(op->body);
        // outer index = fused / (product of absorbed extents)
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = indexdiv(parent, extent);
        body = tir::Substitute(body, rmap);
        under_outer = false;
        // Emit the single fused loop covering the combined extent.
        return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api,
                   body);
      } else if (under_outer) {
        // A loop sandwiched between outer and inner is absorbed into the
        // fused index as well: idx = (fused / extent) % this_extent.
        Stmt body = this->VisitStmt(op->body);
        std::unordered_map<const VarNode*, PrimExpr> rmap;
        rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
        body = tir::Substitute(body, rmap);
        extent = extent * op->extent;
        return body;
      }
      return StmtExprMutator::VisitStmt_(op);
    }
  };

  // Apply every split/fuse relation in the order it was declared on the stage.
  for (auto& rel : stage->relations) {
    if (const SplitNode* split = rel.as<SplitNode>()) {
      LoopSpliter Spliter(split, dom_map);
      stmt = Spliter(stmt);
      CHECK(Spliter.splitted);
    } else if (const FuseNode* fuse = rel.as<FuseNode>()) {
      LoopFuser Fuser(fuse);
      stmt = Fuser(stmt);
      CHECK(Fuser.fused);
    }
  }

  return stmt;
}
| |
// Applies per-itervar annotations from `stage` onto the matching loops in
// `stmt`: either a thread binding (loop replaced by an AttrStmt carrying the
// thread extent) or a change of ForType (e.g. serial -> unrolled/vectorized).
// `rebased` maps a rebased leaf iter var back to the parent var that actually
// appears in the loop nest.
Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar, IterVar>& rebased,
                          Stmt stmt) {
  // Mutator that rewrites the single For loop whose var is `var_` per `attr_`.
  class LoopAnnotator : public StmtMutator {
    const VarNode* var;
    const IterVarAttr& attr;

   public:
    LoopAnnotator(const VarNode* var_, const IterVarAttr& attr_) : var(var_), attr(attr_) {}

    Stmt VisitStmt_(const ForNode* op) final {
      tir::ExprDeepEqual expr_equal;

      if (op->loop_var.get() == var) {
        if (attr->bind_thread.defined()) {
          // Thread binding: substitute the loop var with the bound thread
          // iter var and emit a thread_extent attribute instead of the loop.
          const auto& iter_var = attr->bind_thread;
          if (iter_var->dom.defined()) {
            CHECK(is_const_int(iter_var->dom->min, 0));
            CHECK(expr_equal(iter_var->dom->extent, op->extent))
                << "Thread extent and loop extent mismatch!\n";
          }
          std::unordered_map<const VarNode*, PrimExpr> rmap;
          rmap[op->loop_var.get()] = iter_var;
          Stmt body = tir::Substitute(op->body, rmap);
          return AttrStmt(iter_var, "thread_extent", op->extent, body);
        } else {
          // Only the loop kind changes; bounds and body stay as-is.
          return For(op->loop_var, op->min, op->extent, IterVarTypeToForType(attr->iter_type),
                     op->device_api, op->body);
        }
      }
      return StmtMutator::VisitStmt_(op);
    }
  };

  for (auto& iter_var : stage->leaf_iter_vars) {
    bool need_change = false;
    int found = 0;

    // A rebased iter var shows up in the body under its parent's var.
    const IterVar& actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    const VarNode* var = actual->var.get();
    ForType expected = IterVarTypeToForType(iter_var->iter_type);
    IterVarAttr attr;
    if (stage->iter_var_attrs.count(iter_var)) {
      attr = stage->iter_var_attrs[iter_var];
      expected = IterVarTypeToForType(attr->iter_type);
    }

    // Scan the body: the loop must occur exactly once, and only loops whose
    // ForType differs (or that need thread binding) are rewritten.
    PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
      if (const ForNode* op = node.as<ForNode>()) {
        if (op->loop_var.get() == var) {
          ++found;
          need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
        }
      }
    });

    CHECK_EQ(found, 1) << " iter var should be found exactly once!";
    if (need_change) {
      stmt = LoopAnnotator(var, attr)(std::move(stmt));
    }
  }
  return stmt;
}
| |
// Reorders the loop nest in `stmt` to match the order of
// `stage->leaf_iter_vars`. The loop structure is kept; each For is rebuilt
// with the iter var that should occupy its nesting position.
Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt) {
  // Collect the current loop vars; post-order visits innermost first, so
  // reverse to get outermost-first order.
  std::vector<const VarNode*> current_order;
  PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
    if (const ForNode* op = node.as<ForNode>()) current_order.push_back(op->loop_var.get());
  });
  std::reverse(current_order.begin(), current_order.end());
  auto& required_ord = stage->leaf_iter_vars;
  CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
  // Map each existing loop var to the iter var that must sit at its position.
  std::unordered_map<const VarNode*, IterVar> reorder;
  bool need_reorder = false;
  for (size_t i = 0; i < current_order.size(); ++i) {
    auto& current = current_order[i];
    const IterVar& iter_var = required_ord[i];
    // Rebased leaf vars are looked up via their parent.
    const IterVar& required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
    CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
    reorder[current] = required;
    if (current != required->var.get()) {
      need_reorder = true;
    }
  }

  // Mutator that rebuilds every For with its mapped target iter var, pulling
  // the loop kind from the stage's attrs and the range from dom/dom_map.
  class LoopReorder : public StmtMutator {
    const Stage& stage;
    const std::unordered_map<IterVar, Range>& dom_map;
    const std::unordered_map<const VarNode*, IterVar>& reorder;

   public:
    LoopReorder(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                const std::unordered_map<const VarNode*, IterVar>& reorder)
        : stage(stage), dom_map(dom_map), reorder(reorder) {}

    Stmt VisitStmt_(const ForNode* op) final {
      // Reorder from in to out
      Stmt body_ = this->VisitStmt(op->body);
      CHECK(reorder.count(op->loop_var.get()));
      auto target = reorder.find(op->loop_var.get())->second;
      // Unchanged position with unchanged body: keep the original node.
      if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
        return GetRef<Stmt>(op);
      const Stmt& body = op->body.same_as(body_) ? op->body : body_;
      ForType for_type = IterVarTypeToForType(target->iter_type);
      if (stage->iter_var_attrs.count(target)) {
        for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
      }
      const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
      return For(target->var, range->min, range->extent, for_type, DeviceAPI::None, body);
    }
  };

  if (need_reorder) return LoopReorder(stage, dom_map, reorder)(stmt);

  return stmt;
}
| |
// Applies all schedule primitives recorded on `stage` to the hybrid op's loop
// nest: first loop-shape changes (split/fuse), then loop reordering, and
// finally loop annotations (unroll/vectorize/thread binding).
Stmt ApplySchedule(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                   Stmt stmt) {
  // TODO(@were): Eliminate loop rebase in script parser and move the burden here
  // Gather rebased variables: maps each rebased iter var to its parent, which
  // is the var that actually appears in the loop nest.
  std::unordered_map<IterVar, IterVar> rebased;
  for (auto rel : stage->relations) {
    if (const auto* rebase = rel.as<RebaseNode>()) {
      rebased[rebase->rebased] = rebase->parent;
      CHECK(rebase->parent->dom.defined());
      CHECK(dom_map.count(rebase->rebased));
    }
  }
  stmt = ApplyLoopShapes(stage, dom_map, stmt);
  stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
  stmt = ApplyLoopAnnotations(stage, rebased, stmt);
  return stmt;
}
| |
| std::vector<IterVar> GatherLoopVars(Stmt stmt) { |
| // TODO(@were): Write a comprehensive pass to analyze iter var types |
| std::vector<IterVar> res_; |
| PostOrderVisit(stmt, [&res_](const ObjectRef& node) { |
| if (const ForNode* op = node.as<ForNode>()) { |
| Var loop_var(op->loop_var); |
| Range dom = Range::FromMinExtent(op->min, op->extent); |
| res_.push_back(IterVar(dom, loop_var, ForTypeToIterVarType(op->for_type))); |
| } |
| }); |
| std::reverse(res_.begin(), res_.end()); |
| return res_; |
| } |
| |
| // replacer to replace tensors' usage in Provide |
| class ProviderReplacer : public tir::StmtMutator { |
| public: |
| explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap) : vmap_(vmap) {} |
| |
| Stmt VisitStmt_(const tir::ProducerStoreNode* op) final { |
| Tensor t = Downcast<Tensor>(op->producer); |
| auto it = vmap_.find(t); |
| if (it != vmap_.end()) { |
| Stmt ret = tir::ProducerStore(it->second, op->value, op->indices); |
| found = true; |
| return this->VisitStmt(ret); |
| } |
| return StmtMutator::VisitStmt_(op); |
| } |
| |
| // whether it is found. |
| bool found{false}; |
| |
| private: |
| const std::unordered_map<Tensor, Tensor>& vmap_; |
| }; |
| |
| Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace) { |
| ProviderReplacer repl(replace); |
| Stmt ret = repl(stmt); |
| return repl.found ? ret : stmt; |
| } |
| } // namespace te |
| } // namespace tvm |