| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * Loop unrolling as in Halide pipeline. |
| * \file unroll_loop.cc |
| */ |
| // Unrolls the loop as in Halide pipeline. |
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "ir_util.h"
| |
| namespace tvm { |
| namespace tir { |
| |
/*!
 * \brief Configuration of the UnrollLoop pass.
 *
 *  Supplied through the "tir.UnrollLoop" PassContext config option; when that
 *  option is absent the pass uses the defaults declared below.
 */
struct UnrollLoopConfigNode : public tvm::AttrsNode<UnrollLoopConfigNode> {
  /*! \brief Step budget: a serial loop may be auto-unrolled while extent * (steps in body) <= this. */
  int auto_max_step;
  /*! \brief Upper bound on the number of already-unrolled loops nested inside an auto-unrolled loop. */
  int auto_max_depth;
  /*! \brief Loops whose constant extent is <= this are auto-unroll candidates regardless of steps. */
  int auto_max_extent;
  // Boolean flag stored as int: non-zero replicates loop bodies, zero only marks loops as unrolled.
  int explicit_unroll;

  TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") {
    TVM_ATTR_FIELD(auto_max_step)
        .describe("Threshold of number of steps in the loop to be automatically unrolled")
        .set_default(0);
    TVM_ATTR_FIELD(auto_max_depth)
        .describe("The maximum nested level of loops that can be automatically unrolled.")
        .set_default(8);
    TVM_ATTR_FIELD(auto_max_extent)
        .describe("The maximum extent of loop that will be unrolled.")
        .set_default(0);
    TVM_ATTR_FIELD(explicit_unroll)
        .describe("Whether to explicitly unroll the loop instead of setting a pragma")
        .set_default(true);
  }
};
| |
/*!
 * \brief Reference type for UnrollLoopConfigNode.
 *
 *  Declared with the NOTNULLABLE object-ref macro, so a constructed reference
 *  is expected to point at a node.
 */
class UnrollLoopConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnrollLoopConfig, Attrs, UnrollLoopConfigNode);
};
| |
// Register the config node type, and expose it as the "tir.UnrollLoop" PassContext
// option that transform::UnrollLoop() reads.
TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
| |
| class LoopUnroller : public StmtExprMutator { |
| public: |
| explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, |
| bool explicit_unroll) |
| : auto_max_step_(auto_max_step), |
| auto_max_depth_(auto_max_depth), |
| auto_max_extent_(auto_max_extent), |
| explicit_unroll_(explicit_unroll) {} |
| |
| Stmt VisitStmt_(const AttrStmtNode* op) final { |
| if (op->attr_key == "pragma_auto_unroll_max_step") { |
| int value = static_cast<int>(Downcast<Integer>(op->value)->value); |
| std::swap(value, auto_max_step_); |
| Stmt ret = this->VisitStmt(op->body); |
| std::swap(value, auto_max_step_); |
| return ret; |
| } else if (op->attr_key == "pragma_unroll_explicit") { |
| bool explicit_unroll = Downcast<Integer>(op->value)->value; |
| std::swap(explicit_unroll, explicit_unroll_); |
| Stmt ret = this->VisitStmt(op->body); |
| std::swap(explicit_unroll, explicit_unroll_); |
| return ret; |
| } else { |
| return StmtExprMutator::VisitStmt_(op); |
| } |
| } |
| |
| Stmt VisitStmt_(const ForNode* op) { |
| Stmt stmt = StmtExprMutator::VisitStmt_(op); |
| op = stmt.as<ForNode>(); |
| int value = GetExtent(op); |
| // condition for auto unroll |
| bool auto_unroll = (op->for_type == ForType::Serial && value >= 0 && normal_loop_depth_ == 0 && |
| unroll_depth_ <= auto_max_depth_); |
| |
| auto_unroll = |
| auto_unroll && (value * step_count_ <= auto_max_step_ || value <= auto_max_extent_); |
| |
| if (op->for_type == ForType::Unrolled) { |
| CHECK_GE(value, 0) << "Cannot unroll non-constant loop"; |
| auto_unroll = true; |
| } |
| |
| if (auto_unroll) { |
| step_count_ *= value; |
| unroll_depth_ += 1; |
| } else { |
| normal_loop_depth_ += 1; |
| } |
| |
| if ((auto_unroll && explicit_unroll_) || |
| // unroll loops with extent = 1, no matter how many steps in body |
| (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) { |
| return Unroll(op); |
| } else { |
| if (auto_unroll) { |
| if (op->for_type != ForType::Unrolled) { |
| return For(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api, |
| op->body); |
| } |
| } |
| return stmt; |
| } |
| } |
| |
| Stmt VisitStmt_(const StoreNode* op) final { |
| ++step_count_; |
| return StmtExprMutator::VisitStmt_(op); |
| } |
| |
| Stmt VisitStmt_(const EvaluateNode* op) final { |
| ++step_count_; |
| return StmtExprMutator::VisitStmt_(op); |
| } |
| |
| Stmt VisitStmt_(const SeqStmtNode* op) final { |
| auto fmutate = [this](const Stmt& s) { |
| int step_count = step_count_; |
| int unroll_depth = unroll_depth_; |
| int normal_loop_depth = normal_loop_depth_; |
| step_count_ = 0; |
| unroll_depth_ = 0; |
| normal_loop_depth_ = 0; |
| Stmt ret = this->VisitStmt(s); |
| step_count_ += step_count; |
| normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_); |
| unroll_depth_ = std::max(unroll_depth_, unroll_depth); |
| return ret; |
| }; |
| return StmtMutator::VisitSeqStmt_(op, false, fmutate); |
| } |
| |
| Stmt Unroll(const ForNode* op) { |
| int value = GetExtent(op); |
| // For loop must have a constant integer extent |
| CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; |
| if (value == 0) return Evaluate(0); |
| Stmt body = op->body; |
| Map<Var, PrimExpr> vmap; |
| Array<Stmt> unrolled; |
| for (int i = 0; i < value; ++i) { |
| vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i)); |
| Stmt step = Substitute(body, vmap); |
| unrolled.push_back(step); |
| } |
| return SeqStmt::Flatten(unrolled); |
| } |
| |
| private: |
| // returns the extent of the loop if it's a constant integer, otherwise return -1 |
| int GetExtent(const ForNode* op) { |
| // constant folding. |
| PrimExpr extent = analyzer_.Simplify(op->extent); |
| const IntImmNode* v1 = extent.as<IntImmNode>(); |
| int value = -1; |
| // integers that do not fit in int32_t are treated as symbolic, |
| // as it's impossible to unroll such large loops |
| if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) { |
| value = static_cast<int>(v1->value); |
| } |
| return value; |
| } |
| |
| // maximum number of step to perform auto unroll. |
| int auto_max_step_; |
| int auto_max_depth_; |
| // max extent of loop to auto unroll |
| // this not not count the total steps, only count the number of loops |
| int auto_max_extent_; |
| bool explicit_unroll_; |
| // Number of normal loops in scope |
| int normal_loop_depth_{0}; |
| // number of unrolled cases in current scope. |
| int unroll_depth_{0}; |
| // Number of total steps unrolled |
| int step_count_{0}; |
| // analyzer |
| arith::Analyzer analyzer_; |
| }; |
| |
| Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { |
| Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent, |
| cfg->explicit_unroll)(stmt); |
| if (!ret.same_as(stmt)) { |
| return ConvertSSA(ret); |
| } else { |
| return ret; |
| } |
| } |
| |
| namespace transform { |
| |
| Pass UnrollLoop() { |
| auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
| auto* n = f.CopyOnWrite(); |
| auto cfg = ctx->GetConfig<UnrollLoopConfig>("tir.UnrollLoop"); |
| if (!cfg.defined()) { |
| cfg = AttrsWithDefaultValues<UnrollLoopConfig>(); |
| } |
| n->body = UnrollLoop(std::move(f->body), cfg.value()); |
| return f; |
| }; |
| return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {}); |
| } |
| |
// Expose the pass constructor transform::UnrollLoop() as the global function
// "tir.transform.UnrollLoop".
TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop);
| |
| } // namespace transform |
| |
| } // namespace tir |
| } // namespace tvm |