blob: 4e7589c3a795328857202b310901b7a2f3455314 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file hoist_if_then_else.cc
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include "../../arith/interval_set.h"
#include "../../runtime/thread_storage_scope.h"
#include "ir_util.h"
namespace tvm {
namespace tir {
/*!
 * \brief Attribute node holding the options of the HoistIfThenElse pass.
 */
struct HoistIfThenElseConfigNode : public tvm::AttrsNode<HoistIfThenElseConfigNode> {
// Whether if conditions that use block scope variables may be hoisted.
// NOTE: the spelling "hosting" is part of the attribute name and is kept as is.
bool support_block_scope_hosting;
TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig") {
// Block scope hoisting is opt-in: the field defaults to false.
TVM_ATTR_FIELD(support_block_scope_hosting)
.describe("Hoist if cond with block scope variables")
.set_default(false);
}
};
/*!
 * \brief Non-nullable reference class for HoistIfThenElseConfigNode.
 */
class HoistIfThenElseConfig : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs,
HoistIfThenElseConfigNode);
};
// Register the config node type and make it available as the
// "tir.HoistIfThenElse" pass config option.
TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig);
// Maps a loop variable to the For node that binds it.
using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
// Result of a candidate search: (search complete flag, target For node, target IfThenElse node).
using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
/*
* This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
* For example, given the following block:
* for (i = 0; i < 3; i++)
* for (j = 0; j < 4; j++)
* for (k = 0; k < 5; k++)
* if (likely(i*2 < 4))
* A[3*i+2j+k] = B[7*i+3j+k]
*
* We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
* Then we hoist IfThenElse stmt by one For stmt each step:
*
* Step 1:
* for (i = 0; i < 3; i++)
* for (j = 0; j < 4; j++)
* if (likely(i*2 < 4))
* for (k = 0; k < 5; k++)
* A[3*i+2j+k] = B[7*i+3j+k]
*
* Step 2:
* for (i = 0; i < 3; i++)
* if (likely(i*2 < 4))
* for (j = 0; j < 4; j++)
* for (k = 0; k < 5; k++)
* A[3*i+2j+k] = B[7*i+3j+k]
*
* In this pass, we only continue detecting possible hoisting chance when visiting For,
* IfThenElse or AttrStmt Node. For example, for the following block:
* for (i = 0; i < 3; i++)
* for (j = 0; j < 4; j++)
* A[i + j] = A[i + j] - 1
* for (k = 0; k < 5; k++)
* if (likely(i*2 < 4))
* A[3*i+2j+k] = B[7*i+3j+k]
*
* Only the For with k variable will be considered and the resulting stmt would be:
* for (i = 0; i < 3; i++)
* for (j = 0; j < 4; j++)
* A[i + j] = A[i + j] - 1
* if (likely(i*2 < 4))
* for (k = 0; k < 5; k++)
* A[3*i+2j+k] = B[7*i+3j+k]
*
* This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
* block won't be optimized:
* for (i = 0; i < 3; i++)
* for (j = 0; j < 4; j++)
* for (k = 0; k < 5; k++)
* if (likely(i*2 < 4))
* A[3*i+2j+k] = B[7*i+3j+k]
* if (likely(j > 2))
* A[i+j+k] = B[i+j+k]
*
*
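* This pass also hoists conditions on block scope variables (variables bound by an
* AttrStmt) when the support_block_scope_hosting config option is enabled.
* For example, the following block: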
* Attr(IterVar: threadIdx.x)
* for (i = 0; i < 3; i++)
* for (j = 0; j < 4; j++)
* for (k = 0; k < 5; k++)
* if (likely(threadIdx.x < 3))
* A[3*i+2j+k] = B[7*i+3j+k]
*
* is transformed into:
* Attr(IterVar: threadIdx.x)
* if (likely(threadIdx.x < 3))
* for (i = 0; i < 3; i++)
* for (j = 0; j < 4; j++)
* for (k = 0; k < 5; k++)
* A[3*i+2j+k] = B[7*i+3j+k]
*
*/
// Select potential candidate IRs that can be hoisted.
/*!
 * \brief Visitor that searches a statement for one (For, IfThenElse) pair to hoist.
 *
 * While walking the statement the selector keeps an ordered record of the enclosing For
 * loops and, when block scope hoisting is enabled, of the variables bound by enclosing
 * AttrStmt nodes. When an IfThenElse is reached while recording, the variables used in
 * its condition are matched against that record. The first For loop recorded after the
 * innermost matching record becomes the hoisting target. The outcome is published
 * through hoist_for_if_recorder.
 */
class HoistCandidateSelector final : public StmtExprVisitor {
 public:
  /*!
   * \param support_block_scope_hosting Whether conditions using variables bound by an
   *        AttrStmt may be hoisted.
   */
  explicit HoistCandidateSelector(bool support_block_scope_hosting)
      : support_block_scope_hosting_(support_block_scope_hosting) {
    InitRecorder();
  }

  /*! \brief Construct a selector with block scope hoisting disabled. */
  HoistCandidateSelector() { InitRecorder(); }

  void VisitStmt_(const ForNode* op) final {
    // Once a candidate has been found there is nothing more to trace.
    if (RecordingComplete()) {
      return;
    }
    // The first For loop starts the recorder, nested ones are appended to it.
    StartOrAddRecord(GetRef<ObjectRef>(op));
    StmtExprVisitor::VisitStmt_(op);
    RemoveRecord(GetRef<ObjectRef>(op));
  }

  void VisitStmt_(const SeqStmtNode* op) final {
    // A SeqStmt in the middle of a recording cannot be hoisted across,
    // so everything recorded so far is purged.
    if (IsRecordingOn()) {
      ResetRecorderInternal();
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode* op) final {
    // With block scope hoisting enabled, the AttrStmt variable is recorded like a loop
    // so that a condition using it is hoisted right below the AttrStmt.
    //
    // If hoisting runs between passes that depend on the positioning of if nodes
    // relative to scope variables, it is better to keep this section disabled.
    if (support_block_scope_hosting_) {
      if (IsRecordingOn()) {
        StartOrAddRecord(GetRef<ObjectRef>(op));
        StmtExprVisitor::VisitStmt_(op);
        RemoveRecord(GetRef<ObjectRef>(op));
        return;
      } else {
        return StmtExprVisitor::VisitStmt_(op);
      }
    }
    // Otherwise keep track of the AttrStmt variables in scope, so that hoisting is
    // refused for any condition that uses one of them.
    UpdateAttrVarList(op);
    StmtExprVisitor::VisitStmt_(op);
    RemoveAttrVarList(op);
  }

  void VisitStmt_(const IfThenElseNode* op) final {
    if (!IsRecordingOn()) {
      StmtExprVisitor::VisitStmt_(op);
      return;
    }

    // Collect the variables used by the condition (see VisitExpr_(const VarNode*)).
    is_if_cond_ = true;
    StmtExprVisitor::VisitExpr(op->condition);
    is_if_cond_ = false;

    if (CheckValidIf()) {
      // Find the innermost record (For loop or AttrStmt variable) the condition depends on.
      int match_for_loop_pos = -1;
      const int num_records = static_cast<int>(ordered_list_.size());
      for (const VarNode* var : if_var_list_) {
        // Pure lookup: find() is used instead of operator[] so that variables which are
        // not bound by a recorded loop do not leave null entries in var_for_map_. Such
        // entries would prevent StartOrAddRecord from registering a later For loop that
        // binds the same variable.
        auto bound_it = var_for_map_.find(var);
        const ForNode* bound_for = (bound_it == var_for_map_.end()) ? nullptr : bound_it->second;
        for (int i = 0; i < num_records; ++i) {
          const Object* record = ordered_list_[i];
          if ((record == bound_for) || (record == var)) {
            if (match_for_loop_pos < i) {
              match_for_loop_pos = i;
            }
          }
        }
      }
      // If no record matches (position -1), the if node is hoisted on top of the
      // outermost recorded For loop.
      int target_for_pos = GetNextLoopPos(match_for_loop_pos);
      // A negative position means there is no For loop below the matching record.
      if (target_for_pos >= 0) {
        StopAndAddRecord(static_cast<const ForNode*>(ordered_list_[target_for_pos]), op);
        if_var_list_.clear();
        return;
      }
    }

    if_var_list_.clear();
    StmtExprVisitor::VisitStmt_(op);
    StopRecording();
  }

  void VisitExpr_(const VarNode* op) final {
    // Variables are only collected while the condition of an IfThenElse is visited.
    if (is_if_cond_) {
      if_var_list_.emplace_back(op);
    }
  }

  /*! \brief Search result: (search complete flag, target For, target IfThenElse). */
  HoistForIfTuple hoist_for_if_recorder;

  /*! \brief Clear all recorded state, including the tracked AttrStmt variables. */
  void ResetRecorder() {
    ResetRecorderInternal();
    // Reset Block scope vars also here
    attr_var_list_.clear();
  }

  /*! \return Whether a hoisting candidate has been found. */
  bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); }

  /*! \return The For loop the selected if node should be hoisted above (may be null). */
  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }

  /*! \return The if node selected for hoisting (may be null). */
  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }

 private:
  /*! \brief Clear the record list, the loop variable map and the search result. */
  void ResetRecorderInternal() {
    if (is_recorder_on_) {
      CHECK_GT(ordered_list_.size(), 0);
      is_recorder_on_ = false;
    }
    ordered_list_.clear();
    var_for_map_.clear();
    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
  }

  /*!
   * \brief Check whether the last visited condition is a hoisting candidate.
   *
   * If the condition uses no variables, or uses a tracked AttrStmt variable,
   * hoisting is refused.
   */
  bool CheckValidIf() {
    return ((!if_var_list_.empty()) && (!CheckAttrVar()));
  }

  /*!
   * \brief Find the first For record located after position \p cur_pos.
   * \param cur_pos Position to start after; -1 searches from the beginning.
   * \return The position of that For record, or -1 if there is none.
   */
  int GetNextLoopPos(int cur_pos) {
    for (size_t i = cur_pos + 1; i < ordered_list_.size(); ++i) {
      if (ordered_list_[i]->IsInstance<ForNode>()) {
        return static_cast<int>(i);
      }
    }
    return -1;
  }

  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }

  void StopRecording() { is_recorder_on_ = false; }

  bool IsRecordingOn() { return is_recorder_on_; }

  /*!
   * \brief Resolve the variable bound by an AttrStmt.
   * \return The variable if the AttrStmt node is an IterVar or a Var, nullptr otherwise.
   */
  static const VarNode* GetAttrBoundVar(const AttrStmtNode* op) {
    if (const auto* iter_var = op->node.as<IterVarNode>()) {
      return iter_var->var.get();
    }
    return op->node.as<VarNode>();
  }

  /*!
   * \brief Turn the recorder on and append a record for \p op.
   *
   * A For node is recorded as itself, an AttrStmt as the variable it binds.
   * An AttrStmt that binds no variable adds no record.
   */
  void StartOrAddRecord(const ObjectRef& op) {
    is_recorder_on_ = true;
    if (const auto* for_node = op.as<ForNode>()) {
      if (!var_for_map_.count(for_node->loop_var.get()))
        var_for_map_.insert({for_node->loop_var.get(), for_node});
      ordered_list_.emplace_back(op.get());
    } else if (const auto* attr_node = op.as<AttrStmtNode>()) {
      if (const VarNode* bound_var = GetAttrBoundVar(attr_node)) {
        ordered_list_.emplace_back(bound_var);
      }
    }
  }

  /*!
   * \brief Turn the recorder off and drop the record that StartOrAddRecord added for \p op.
   */
  void RemoveRecord(const ObjectRef& op) {
    StopRecording();
    if (const auto* for_node = op.as<ForNode>()) {
      var_for_map_.erase(for_node->loop_var.get());
    } else if (const auto* attr_node = op.as<AttrStmtNode>()) {
      // No record was pushed for an AttrStmt that binds no variable. Popping here would
      // discard the record of an enclosing For loop and unbalance the list.
      if (GetAttrBoundVar(attr_node) == nullptr) return;
    }
    // The list may already have been purged by a SeqStmt in the meantime.
    if (!ordered_list_.empty()) ordered_list_.pop_back();
  }

  /*! \brief Publish the found candidate and stop recording. */
  void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
    hoist_for_if_recorder = std::make_tuple(true, for_node, if_node);
    StopRecording();
  }

  /*! \brief Start tracking the variable bound by \p op, if any. */
  void UpdateAttrVarList(const AttrStmtNode* op) {
    if (const VarNode* bound_var = GetAttrBoundVar(op)) {
      attr_var_list_.insert(bound_var);
    }
  }

  /*! \brief Stop tracking the variable bound by \p op, if any. */
  void RemoveAttrVarList(const AttrStmtNode* op) {
    if (const VarNode* bound_var = GetAttrBoundVar(op)) {
      attr_var_list_.erase(bound_var);
    }
  }

  /*! \return Whether the last visited condition uses a tracked AttrStmt variable. */
  bool CheckAttrVar() {
    for (const VarNode* var : if_var_list_) {
      if (attr_var_list_.count(var)) {
        return true;
      }
    }
    return false;
  }

  // Ordered List maintains all ForNodes & AttrStmt variables encountered in sequence
  std::vector<const Object*> ordered_list_;
  // Variables used by the condition of the if node currently being examined
  std::vector<const VarNode*> if_var_list_;
  // Variables bound by the enclosing AttrStmt nodes (block scope hoisting disabled)
  std::unordered_set<const VarNode*> attr_var_list_;
  // Loop variable -> recorded For node binding it
  VarForMap var_for_map_;
  bool is_if_cond_{false};
  bool is_recorder_on_{false};
  bool support_block_scope_hosting_{false};
};
/*!
 * \brief Mutator that hoists the candidates reported by HoistCandidateSelector.
 *
 * Each round hoists one IfThenElse above one For loop; rounds are repeated until the
 * selector finds no further candidate.
 */
class IfThenElseHoister : public StmtMutator {
 public:
  /*! \brief Construct a hoister with block scope hoisting disabled. */
  IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {}

  /*!
   * \param support_block_scope_hosting Whether conditions using block scope variables
   *        may be hoisted.
   */
  explicit IfThenElseHoister(bool support_block_scope_hosting)
      : hoist_selector_(HoistCandidateSelector(support_block_scope_hosting)) {}

  /*!
   * \brief Repeatedly select and hoist candidates in \p stmt.
   * \param stmt The statement to transform.
   * \return The transformed statement, passed through ConvertSSA.
   */
  Stmt VisitAndMutate(Stmt stmt) {
    hoist_selector_(stmt);
    Stmt stmt_copy = std::move(stmt);
    while (hoist_selector_.RecordingComplete()) {
      target_for_ = hoist_selector_.GetTargetForNode();
      target_if_ = hoist_selector_.GetTargetIfNode();
      stmt_copy = operator()(stmt_copy);
      // Search the rewritten statement for the next candidate.
      hoist_selector_.ResetRecorder();
      hoist_selector_(stmt_copy);
    }
    // Support SSA Form
    stmt_copy = ConvertSSA(stmt_copy);
    return stmt_copy;
  }

  Stmt VisitStmt_(const ForNode* op) final {
    if ((!is_updating_) && (target_for_ == op)) {
      is_updating_ = true;
      // First copy of the loop: the target if node is replaced by its then case.
      is_then_case_ = true;
      Stmt then_case = StmtMutator::VisitStmt_(op);
      is_then_case_ = false;
      // Second copy of the loop, only needed when the target if node has an else case.
      Stmt else_case = Stmt();
      if (target_if_->else_case.defined()) {
        else_case = StmtMutator::VisitStmt_(op);
      }
      is_updating_ = false;
      return IfThenElse(target_if_->condition, then_case, else_case);
    }
    return StmtMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const IfThenElseNode* op) final {
    // Inside the target loop the target if node collapses to the branch being built.
    if (is_updating_ && (target_if_ == op)) {
      if (is_then_case_) {
        return StmtMutator::VisitStmt(op->then_case);
      } else if (op->else_case.defined()) {
        return StmtMutator::VisitStmt(op->else_case);
      }
    }
    return StmtMutator::VisitStmt_(op);
  }

 private:
  // True while the body of the target For loop is being rewritten
  bool is_updating_{false};
  // True while the then-case copy of the target For loop is being built
  bool is_then_case_{false};
  HoistCandidateSelector hoist_selector_;
  // Current hoisting targets. They are initialised to null so that a mutation started
  // before any candidate was selected never compares against an indeterminate pointer.
  const ForNode* target_for_{nullptr};
  const IfThenElseNode* target_if_{nullptr};
};
/*!
 * \brief Hoist IfThenElse statements out of For loops in \p stmt.
 * \param stmt The statement to transform.
 * \param support_block_scope_hosting Whether conditions using block scope variables
 *        may be hoisted.
 * \return The transformed statement.
 */
Stmt HoistIfThenElse(Stmt stmt, bool support_block_scope_hosting) {
  IfThenElseHoister hoister(support_block_scope_hosting);
  return hoister.VisitAndMutate(stmt);
}
/*!
 * \brief Hoist IfThenElse statements out of For loops in \p stmt, using a
 *        default-constructed hoister.
 * \param stmt The statement to transform.
 * \return The transformed statement.
 */
Stmt HoistIfThenElse(Stmt stmt) {
  IfThenElseHoister hoister;
  return hoister.VisitAndMutate(stmt);
}
namespace transform {
/*!
 * \brief Create the "tir.HoistIfThenElse" pass.
 *
 * The pass reads the "tir.HoistIfThenElse" config from the pass context and falls back
 * to the default attribute values when the config is not set.
 */
Pass HoistIfThenElse() {
  auto pass_func = [=](PrimFunc func, IRModule mod, PassContext pass_ctx) {
    auto config = pass_ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");
    if (!config.defined()) {
      config = AttrsWithDefaultValues<HoistIfThenElseConfig>();
    }
    const bool block_scope_hoisting = config.value()->support_block_scope_hosting;

    auto* func_node = func.CopyOnWrite();
    func_node->body = HoistIfThenElse(std::move(func_node->body), block_scope_hoisting);
    return func;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {});
}
/*!
 * \brief Create the "tir.HoistIfThenElseBasic" pass.
 *
 * Unlike HoistIfThenElse, this variant does not consult the pass context config.
 */
Pass HoistIfThenElseBasic() {
  auto pass_func = [=](PrimFunc func, IRModule mod, PassContext pass_ctx) {
    auto* func_node = func.CopyOnWrite();
    Stmt hoisted_body = HoistIfThenElse(std::move(func_node->body));
    func_node->body = hoisted_body;
    return func;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElseBasic", {});
}
// Register both pass constructors as global functions under the "tir.transform." prefix.
TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse);
TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic);
} // namespace transform
} // namespace tir
} // namespace tvm