blob: 9a03b143d0f9acc3c4805c29b6927edf50a8fde6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file reduce_branching_through_overcompute.cc
*
* \brief Attempt to remove conditional statements by introducing
* extra computations that do not impact the final results.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <optional>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../analysis/control_flow_graph.h"
#include "remove_no_op.h"
#include "simplify.h"
namespace tvm {
namespace tir {
struct ReduceBranchingThroughOvercomputeConfigNode
: public AttrsNodeReflAdapter<ReduceBranchingThroughOvercomputeConfigNode> {
bool use_dataflow_analysis;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ReduceBranchingThroughOvercomputeConfigNode>().def_ro(
"use_dataflow_analysis",
&ReduceBranchingThroughOvercomputeConfigNode::use_dataflow_analysis,
"If true, known buffer values are propagated and used "
"to statically prove that overcompute is valid.",
refl::DefaultValue(false));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.ReduceBranchingThroughOvercomputeConfig",
ReduceBranchingThroughOvercomputeConfigNode, BaseAttrsNode);
};
class ReduceBranchingThroughOvercomputeConfig : public Attrs {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ReduceBranchingThroughOvercomputeConfig, Attrs,
ReduceBranchingThroughOvercomputeConfigNode);
};
TVM_FFI_STATIC_INIT_BLOCK() { ReduceBranchingThroughOvercomputeConfigNode::RegisterReflection(); }
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ReduceBranchingThroughOvercompute",
ReduceBranchingThroughOvercomputeConfig);
struct ElseBranchFiller : StmtExprMutator {
Stmt VisitStmt_(const IfThenElseNode* op) override {
IfThenElse ret = Downcast<IfThenElse>(StmtExprMutator::VisitStmt_(op));
if (ret->else_case.defined()) {
return ret;
} else {
auto new_else_clause = Evaluate(0);
new_else_clauses.insert(new_else_clause);
return IfThenElse(ret->condition, ret->then_case, new_else_clause);
}
}
std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual> new_else_clauses;
};
class ElseBranchStripper : public StmtExprMutator {
public:
ElseBranchStripper(
const std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual>& new_else_clauses)
: new_else_clauses_(new_else_clauses) {}
private:
Stmt VisitStmt_(const IfThenElseNode* op) override {
IfThenElse ret = Downcast<IfThenElse>(StmtExprMutator::VisitStmt_(op));
if (auto as_eval = ret->else_case.as<Evaluate>();
as_eval && new_else_clauses_.count(as_eval.value())) {
return IfThenElse(ret->condition, ret->then_case);
} else {
return ret;
}
}
const std::unordered_set<Evaluate, ObjectPtrHash, ObjectPtrEqual>& new_else_clauses_;
};
class BranchReducer : public arith::IRMutatorWithAnalyzer {
public:
static Stmt Apply(Stmt stmt, const std::optional<ControlFlowGraph>& touch_pattern) {
arith::Analyzer analyzer;
BranchReducer visitor(&analyzer, touch_pattern);
return visitor(std::move(stmt));
}
private:
using Parent = IRMutatorWithAnalyzer;
using Parent::VisitStmt;
using Parent::VisitStmt_;
BranchReducer(arith::Analyzer* analyzer, const std::optional<ControlFlowGraph>& touch_pattern)
: Parent(analyzer), touch_pattern_(touch_pattern) {}
Stmt VisitStmt_(const IfThenElseNode* op) final {
IfThenElse cond = Downcast<IfThenElse>(Parent::VisitStmt_(op));
auto is_special_case = [&](PrimExpr condition, Stmt general_case, Stmt special_case) -> bool {
condition = analyzer_->rewrite_simplify(condition);
With<arith::ConstraintContext> constraint(analyzer_, condition);
Stmt stmt = RemoveNoOp(general_case, analyzer_, touch_pattern_, special_case.get());
return StructuralEqual()(stmt, special_case);
};
ICHECK(cond->else_case.defined() || !touch_pattern_.has_value())
<< "Temp assert, should be true whenever touch pattern is available";
Stmt else_case = cond->else_case.value_or(Evaluate(0));
if (is_special_case(cond->condition, else_case, cond->then_case)) {
return else_case;
} else if (is_special_case(!cond->condition, cond->then_case, else_case)) {
return cond->then_case;
} else {
return cond;
}
}
private:
const std::optional<ControlFlowGraph>& touch_pattern_;
};
namespace transform {
Pass ReduceBranchingThroughOvercompute() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
arith::Analyzer analyzer;
ReduceBranchingThroughOvercomputeConfig config =
ctx->GetConfig<ReduceBranchingThroughOvercomputeConfig>(
"tir.ReduceBranchingThroughOvercompute")
.value_or(AttrsWithDefaultValues<ReduceBranchingThroughOvercomputeConfig>());
auto* n = f.CopyOnWrite();
std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
ElseBranchFiller else_branch_filler;
if (config->use_dataflow_analysis) {
n->body = else_branch_filler(std::move(n->body));
touch_pattern.emplace(n->body);
}
n->body = BranchReducer::Apply(std::move(n->body), touch_pattern);
if (config->use_dataflow_analysis) {
n->body = ElseBranchStripper(else_branch_filler.new_else_clauses)(std::move(n->body));
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.ReduceBranchingThroughOvercompute", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tir.transform.ReduceBranchingThroughOvercompute",
ReduceBranchingThroughOvercompute);
}
} // namespace transform
} // namespace tir
} // namespace tvm