| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file control_flow_graph.cc |
| * \brief Utilities for building and analyzing the control-flow graph of a TIR statement |
| */ |
| |
| #include "control_flow_graph.h" |
| |
| #include <tvm/ffi/function.h> |
| #include <tvm/tir/analysis.h> |
| #include <tvm/tir/builtin.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/op.h> |
| #include <tvm/tir/stmt_functor.h> |
| |
| #include <algorithm> |
| #include <numeric> |
| #include <optional> |
| #include <queue> |
| #include <set> |
| #include <sstream> |
| #include <unordered_set> |
| |
| #include "../../arith/conjunctive_normal_form.h" |
| #include "../../arith/constraint_extract.h" |
| #include "../../arith/ir_mutator_with_analyzer.h" |
| #include "../../arith/ir_visitor_with_analyzer.h" |
| #include "../../arith/narrow_predicate_expression.h" |
| #include "../../arith/unwrap_vector_expr.h" |
| |
| namespace tvm { |
| namespace tir { |
| |
| using namespace arith; |
| |
| namespace { |
| bool HasBufferLoad(PrimExpr expr) { |
| struct Visitor : public ExprVisitor { |
| void VisitExpr_(const BufferLoadNode* node) override { found_buffer_load = true; } |
| bool found_buffer_load{false}; |
| }; |
| |
| Visitor visitor; |
| visitor(expr); |
| return visitor.found_buffer_load; |
| } |
| |
| ffi::Optional<PrimExpr> SubstituteParamValues(const ffi::Array<Var>& param_vars, |
| const ffi::Array<PrimExpr>& param_values, |
| const PrimExpr& expr) { |
| ICHECK_EQ(param_vars.size(), param_values.size()) |
| << "Expression was defined as having " << param_vars.size() << " parameters, but received " |
| << param_values.size() << " arguments."; |
| |
| ffi::Map<tir::Var, PrimExpr> var_map; |
| for (size_t i = 0; i < param_values.size(); i++) { |
| var_map.Set(param_vars[i], param_values[i]); |
| } |
| |
| return Substitute(expr, var_map); |
| } |
| } // namespace |
| |
| PrimExpr BufferTouch::BeforeLoopIteration() const { |
| PrimExpr loop_predicate = Bool(true); |
| for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { |
| const Var& loop_var = it->first; |
| const PrimExpr& loop_expr = it->second; |
| loop_predicate = (loop_var <= loop_expr) || ((loop_var == loop_expr) && loop_predicate); |
| } |
| return loop_predicate; |
| } |
| |
| PrimExpr BufferTouch::AtLoopIteration() const { |
| PrimExpr loop_predicate = Bool(true); |
| for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { |
| const Var& loop_var = it->first; |
| const PrimExpr& loop_expr = it->second; |
| loop_predicate = (loop_var == loop_expr) && loop_predicate; |
| } |
| return loop_predicate; |
| } |
| |
| PrimExpr BufferTouch::AfterLoopIteration() const { |
| PrimExpr loop_predicate = Bool(true); |
| for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) { |
| const Var& loop_var = it->first; |
| const PrimExpr& loop_expr = it->second; |
| loop_predicate = (loop_var >= loop_expr) || ((loop_var == loop_expr) && loop_predicate); |
| } |
| return loop_predicate; |
| } |
| |
| bool BufferTouch::IsSubsetOf(const BufferTouch& other, Analyzer* analyzer) const { |
| if (this->buffer.same_as(other.buffer)) { |
| With<ConstraintContext> constraint(analyzer, predicate); |
| |
| return analyzer->CanProve(other.predicate); |
| } else { |
| return false; |
| } |
| } |
| |
| bool BufferTouch::IsDistinctFrom(const BufferTouch& other, Analyzer* analyzer) const { |
| if (this->buffer.same_as(other.buffer)) { |
| With<ConstraintContext> constraint(analyzer, predicate); |
| |
| return analyzer->CanProve(!other.predicate); |
| } else { |
| return true; |
| } |
| } |
| |
| std::ostream& operator<<(std::ostream& os, const BufferTouch& tp) { |
| auto touch_type = [&]() { |
| if (tp.touch_type == BufferTouch::AccessType::Read) { |
| return "read"; |
| } else if (tp.touch_type == BufferTouch::AccessType::Write) { |
| return "write"; |
| } else if (tp.touch_type == BufferTouch::AccessType::Assume) { |
| return "assume"; |
| } else { |
| return "???"; |
| } |
| }(); |
| |
| os << "BufferTouch(" << tp.buffer->name << ", " << touch_type << ", " << tp.predicate |
| << ", value = " << tp.value << ")"; |
| return os; |
| } |
| |
| class BufferConstraintApply : public IRMutatorWithAnalyzer { |
| public: |
| using Parent = IRMutatorWithAnalyzer; |
| |
| BufferConstraintApply(const ffi::Map<Buffer, ffi::Array<Var>>& axis_var_lookup, |
| const std::vector<BufferTouch>& knowns, Analyzer* analyzer) |
| : Parent(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {} |
| |
| using Parent::VisitExpr_; |
| |
| PrimExpr VisitExpr_(const BufferLoadNode* op) override { |
| for (const auto& known : knowns_) { |
| if (!op->buffer.same_as(known.buffer)) { |
| continue; |
| } |
| |
| ffi::Optional<Var> lane_var = std::nullopt; |
| IntImm num_lanes; |
| |
| ffi::Array<PrimExpr> indices = op->indices.Map([&](const auto& index) { |
| if (index.dtype().lanes() == 1) { |
| return index; |
| } else { |
| ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; |
| lane_var = Var("lane", index.dtype().element_of()); |
| num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); |
| return UnwrapVectorExpr(index, lane_var.value()); |
| } |
| }); |
| |
| auto axis_vars = axis_var_lookup_.at(op->buffer); |
| PrimExpr predicate = SubstituteParamValues(axis_vars, indices, known.predicate).value(); |
| |
| std::optional<With<ConstraintContext>> context; |
| if (lane_var.defined()) { |
| Var lanes = lane_var.value(); |
| PrimExpr known = (IntImm(lanes.dtype(), 0) <= lanes) && (lanes < num_lanes); |
| context.emplace(analyzer_, known); |
| } |
| |
| if (analyzer_->CanProve(predicate)) { |
| return SubstituteParamValues(axis_vars, op->indices, known.value).value(); |
| } |
| } |
| |
| return ffi::GetRef<PrimExpr>(op); |
| } |
| |
| private: |
| const ffi::Map<Buffer, ffi::Array<Var>>& axis_var_lookup_; |
| const std::vector<BufferTouch>& knowns_; |
| }; |
| |
| /*! \brief Extract the control-flow graph |
| * |
| * Walk through a statement, populating the control-flow graph. |
| */ |
| class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { |
| public: |
| static void Build(ControlFlowGraph* out, const Stmt& stmt) { |
| ControlFlowGraphBuilder extractor(out); |
| extractor.AppendControlBlock(); |
| extractor(stmt); |
| } |
| |
| private: |
| ControlFlowGraphBuilder(ControlFlowGraph* out) : out_(out) {} |
| |
| using Parent = IRVisitorWithAnalyzer; |
| using Parent::VisitExpr_; |
| using Parent::VisitStmt_; |
| |
| void VisitStmt(const Stmt& stmt) override { |
| // Update the lookup table to determine which control-flow block |
| // contains the start of the specified statement. This is used |
| // later to determine which set of known values should be used to |
| // simplify a statement. |
| out_->control_flow_lookup_[stmt.get()] = CurrentControlBlock(); |
| Stmt prev_stmt = current_stmt_; |
| current_stmt_ = stmt; |
| Parent::VisitStmt(stmt); |
| current_stmt_ = prev_stmt; |
| } |
| |
| void VisitStmt_(const EvaluateNode* op) override { |
| if (auto* call = op->value.as<CallNode>()) { |
| if (call->op.same_as(builtin::assume())) { |
| Assume(call->args[0], true); |
| return; |
| } |
| } |
| |
| Parent::VisitStmt_(op); |
| } |
| |
| void Assume(PrimExpr assumption, bool from_assume_statement) { |
| for (const auto& expr : ExtractConstraints(assumption, false)) { |
| AssumeConstraintComponent(expr, from_assume_statement); |
| } |
| } |
| |
| void AssumeConstraintComponent(PrimExpr assumption, bool from_assume_statement) { |
| PrimExpr additional_predicate = Bool(true); |
| |
| std::vector<PrimExpr> buffer_exprs; |
| for (const auto& expr : ExtractComponents(assumption)) { |
| auto side_effect = tir::SideEffect(expr); |
| if (side_effect <= tir::CallEffectKind::kPure) { |
| // Pulling out portions of the assumption that do not depend |
| // on a buffer value allows the following two forms to be |
| // treated identically. |
| // |
| // Option 1: if i < 3: T.assume(buf[i] == value) |
| // Option 2: T.assume(i>=3 or buf[i] == value) |
| additional_predicate = additional_predicate && logical_not(expr); |
| } else if (side_effect == tir::CallEffectKind::kReadState) { |
| buffer_exprs.push_back(expr); |
| } else { |
| LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr |
| << " with side-effect \'" << side_effect << "\'"; |
| } |
| } |
| |
| if (buffer_exprs.empty()) { |
| out_->non_buffer_assumptions_.push_back(!CurrentScopePredicate() || assumption); |
| return; |
| } |
| |
| CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression"; |
| |
| auto* as_equal_node = buffer_exprs[0].as<tir::EQNode>(); |
| CHECK(as_equal_node || !from_assume_statement) |
| << "T.assume buffer constraint must be of the form 'buffer[indices] == " |
| "value', but received " |
| << assumption; |
| if (!as_equal_node) { |
| // This assumption is an inequality on a data-dependent |
| // conditional. Not an error for this to occur, but also not |
| // something that is currently supported. |
| return; |
| } |
| |
| tir::BufferLoad load; |
| PrimExpr value; |
| if (auto opt = as_equal_node->a.as<tir::BufferLoad>()) { |
| load = opt.value(); |
| value = as_equal_node->b; |
| } else if (auto opt = as_equal_node->b.as<tir::BufferLoad>()) { |
| load = opt.value(); |
| value = as_equal_node->a; |
| } else if (!from_assume_statement) { |
| return; |
| } else { |
| LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'"; |
| } |
| |
| auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure; |
| CHECK(!has_side_effect || !from_assume_statement) |
| << "Buffer value in constraint must be pure expression, but was " << value; |
| if (has_side_effect) { |
| return; |
| } |
| |
| { |
| InternalConstraintContext context(this, additional_predicate); |
| VisitAccess(load, BufferTouch::AccessType::Assume, value); |
| } |
| // Appending a control block ensures that all control blocks have |
| // at most one statement that changes the known buffer contents. |
| auto prev_block = CurrentControlBlock(); |
| auto new_block = AppendControlBlock(); |
| MarkControlFlow(prev_block, new_block); |
| } |
| |
| void VisitExpr_(const LetNode* op) override { |
| std::optional<BindLetVar> binding; |
| if (UsesLoopVar(op->value)) { |
| binding.emplace(this, op->var, op->value); |
| } |
| Parent::VisitExpr_(op); |
| } |
| |
| void VisitStmt_(const LetStmtNode* op) override { |
| std::optional<BindLetVar> binding; |
| if (UsesLoopVar(op->value)) { |
| binding.emplace(this, op->var, op->value); |
| } |
| Parent::VisitStmt_(op); |
| } |
| |
| void VisitExpr_(const BufferLoadNode* op) override { |
| Parent::VisitExpr_(op); |
| BufferLoad load = ffi::GetRef<BufferLoad>(op); |
| VisitAccess(load, BufferTouch::AccessType::Read, load); |
| } |
| |
| void VisitStmt_(const BufferStoreNode* op) override { |
| Parent::VisitStmt_(op); |
| VisitAccess(ffi::GetRef<BufferStore>(op), BufferTouch::AccessType::Write, op->value); |
| // Appending a control block ensures that all control blocks have |
| // at most one statement that changes the buffer contents. |
| auto prev_block = CurrentControlBlock(); |
| auto new_block = AppendControlBlock(); |
| MarkControlFlow(prev_block, new_block); |
| } |
| |
| void VisitStmt_(const ForNode* op) override { |
| out_->iterator_ranges_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent)); |
| |
| auto before_loop = CurrentControlBlock(); |
| size_t loop_start = -1; |
| |
| { |
| BindActiveLoopVar binding(this, op->loop_var, op->min, op->extent); |
| loop_start = AppendControlBlock(); |
| Parent::VisitStmt_(op); |
| } |
| |
| auto loop_end = CurrentControlBlock(); |
| auto after_loop = AppendControlBlock(); |
| PrimExpr max_iterator_value = analyzer_.Simplify(op->min + op->extent - 1); |
| { |
| auto [forward, backward] = MarkControlFlow(before_loop, loop_start); |
| backward.post_condition = (op->loop_var == op->min); |
| forward.var_remap = {{op->loop_var, op->min}}; |
| } |
| { |
| auto [forward, backward] = MarkControlFlow(loop_end, after_loop); |
| backward.var_remap = {{op->loop_var, max_iterator_value}}; |
| forward.post_condition = (op->loop_var == max_iterator_value); |
| } |
| { |
| auto [forward, backward] = MarkControlFlow(loop_end, loop_start); |
| backward.var_remap = {{op->loop_var, op->loop_var - 1}}; |
| forward.var_remap = {{op->loop_var, op->loop_var + 1}}; |
| backward.post_condition = (op->loop_var > op->min); |
| forward.post_condition = (op->loop_var < max_iterator_value); |
| } |
| } |
| |
| void VisitStmt_(const IfThenElseNode* op) override { |
| this->VisitExpr(op->condition); |
| |
| PrimExpr real_condition = ExtractRealCondition(op->condition); |
| |
| auto before_branching = CurrentControlBlock(); |
| |
| auto branch_start = AppendControlBlock(); |
| MarkControlFlow(before_branching, branch_start); |
| |
| { |
| InternalConstraintContext context(this, real_condition); |
| auto then_start = AppendControlBlock(); |
| if (context.assume.defined()) { |
| Assume(context.assume.value(), false); |
| } |
| auto [forward, backward] = MarkControlFlow(branch_start, then_start); |
| backward.post_condition = real_condition; |
| forward.post_condition = real_condition; |
| this->VisitStmt(op->then_case); |
| } |
| auto then_end = CurrentControlBlock(); |
| |
| auto negation = analyzer_.rewrite_simplify(!real_condition); |
| { |
| InternalConstraintContext context(this, negation); |
| auto else_start = AppendControlBlock(); |
| if (context.assume.defined()) { |
| Assume(context.assume.value(), false); |
| } |
| auto [forward, backward] = MarkControlFlow(branch_start, else_start); |
| backward.post_condition = negation; |
| forward.post_condition = negation; |
| |
| if (op->else_case.defined()) { |
| this->VisitStmt(op->else_case.value()); |
| } |
| } |
| |
| auto else_end = CurrentControlBlock(); |
| auto after_branching = AppendControlBlock(); |
| |
| if (HasBufferLoad(real_condition)) { |
| // The buffer value may have changed during the body of the |
| // condition, so we can't provide it as a post-condition. |
| MarkControlFlow(then_end, after_branching); |
| MarkControlFlow(else_end, after_branching); |
| } else { |
| { |
| auto [forward, backward] = MarkControlFlow(then_end, after_branching); |
| backward.post_condition = real_condition; |
| forward.post_condition = real_condition; |
| } |
| { |
| auto [forward, backward] = MarkControlFlow(else_end, after_branching); |
| backward.post_condition = negation; |
| forward.post_condition = negation; |
| } |
| } |
| } |
| |
| /*! \brief Internal utility, returns true if the expression depends |
| * on a loop iterator |
| */ |
| bool UsesLoopVar(const PrimExpr& expr) { |
| return UsesVar(expr, [&](const VarNode* expr_var) { |
| return loop_dependent_vars_.find(expr_var) != loop_dependent_vars_.end(); |
| }); |
| } |
| |
| /*! \brief Record the interaction with the buffer. |
| * |
| * \param node The TIR node that accesses the buffer. Should be |
| * either a BufferLoad or BufferStore node. |
| * |
| * \param touch_type The type of buffer access being performed. A |
| * BufferStore should always use AccessType::Write. A BufferLoad |
| * may use either AccessType::Read or AccessType::Assume, depending |
| * on whether the BufferLoad occurs within `builtin::assume`. |
| * |
| * \param known_value_expr The value in the buffer following the access. |
| */ |
| template <typename BufferAccess> |
| void VisitAccess(const BufferAccess& node, BufferTouch::AccessType touch_type, |
| PrimExpr known_value_expr) { |
| auto& current_block = out_->control_flow_.back(); |
| BufferTouch buffer_touch = current_block.MakeBufferTouch(out_, node->buffer, node->indices, |
| touch_type, known_value_expr); |
| current_block.touch_points.push_back(buffer_touch); |
| } |
| |
| /*! \brief Return a predicate for having reached the current |
| * control-flow block |
| * |
| * For example, while inside an IfThenElse, will return the |
| * IfThenElse's condition. |
| */ |
| PrimExpr CurrentScopePredicate() const { |
| PrimExpr predicate = Bool(true); |
| for (const auto& condition : conditions_) { |
| predicate = predicate && condition; |
| } |
| return predicate; |
| } |
| |
| /* \brief Add a new control block, returning its index */ |
| size_t AppendControlBlock() { |
| size_t index = out_->control_flow_.size(); |
| auto& block = out_->control_flow_.emplace_back(); |
| block.active_loop_iterators = active_loop_iterators_; |
| block.let_bindings_using_loop = let_bindings_using_loop_; |
| block.scope_predicate = CurrentScopePredicate(); |
| return index; |
| } |
| |
| /* \brief The index of the current control block */ |
| size_t CurrentControlBlock() { return out_->control_flow_.size() - 1; } |
| |
| /* \brief Mark a possible control from one block to another |
| * |
| * \param from_block The block from which control leaves |
| * |
| * \param to_block The block to which control enters |
| * |
| * \param var_remap Variable replacements that should be made in |
| * known expression while traversing this edge. For example, |
| * replacing `i` with `i-1` when entering the next loop iteration, |
| * or replacing `i` with `n-1` when concluding a loop. |
| */ |
| std::pair<ControlFlowGraph::ControlFlowEdge&, ControlFlowGraph::ControlFlowEdge&> MarkControlFlow( |
| size_t from_block, size_t to_block) { |
| ICHECK_LE(from_block, out_->control_flow_.size()); |
| ICHECK_LE(to_block, out_->control_flow_.size()); |
| |
| auto& forward = out_->control_flow_[from_block].successors.emplace_back( |
| ControlFlowGraph::ControlFlowEdge{to_block, {}, std::nullopt}); |
| auto& backward = out_->control_flow_[to_block].predecessors.emplace_back( |
| ControlFlowGraph::ControlFlowEdge{from_block, {}, std::nullopt}); |
| return {forward, backward}; |
| } |
| |
| // Internal utility, context manager for entering/leaving a scoped constraint |
| struct InternalConstraintContext { |
| InternalConstraintContext(ControlFlowGraphBuilder* self, PrimExpr constraint) |
| : self(self), analyzer_context(&self->analyzer_, constraint) { |
| old_num_constraints = self->conditions_.size(); |
| |
| auto side_effect = tir::SideEffect(constraint); |
| if (side_effect <= tir::CallEffectKind::kPure) { |
| self->conditions_.push_back(constraint); |
| } else if (side_effect <= tir::CallEffectKind::kReadState) { |
| assume = constraint; |
| } |
| |
| new_num_constraints = self->conditions_.size(); |
| } |
| ~InternalConstraintContext() { |
| ICHECK_EQ(self->conditions_.size(), new_num_constraints) |
| << "Internal error: Each condition should only be popped once."; |
| self->conditions_.erase(self->conditions_.begin() + old_num_constraints, |
| self->conditions_.end()); |
| } |
| |
| ControlFlowGraphBuilder* self{nullptr}; |
| With<ConstraintContext> analyzer_context; |
| size_t old_num_constraints{0}; |
| size_t new_num_constraints{0}; |
| ffi::Optional<PrimExpr> assume{std::nullopt}; |
| |
| // Disable default-generated copy/move assignment and constructors |
| InternalConstraintContext(const InternalConstraintContext&) = delete; |
| InternalConstraintContext& operator=(const InternalConstraintContext&) = delete; |
| InternalConstraintContext(InternalConstraintContext&&) = delete; |
| InternalConstraintContext& operator=(InternalConstraintContext&&) = delete; |
| }; |
| |
| // Internal utility, context manager for tracking a loop |
| struct BindActiveLoopVar { |
| BindActiveLoopVar(ControlFlowGraphBuilder* self, Var var, PrimExpr loop_min, |
| PrimExpr loop_extent) |
| : self(self), var(var) { |
| PrimExpr loop_max = loop_min + (loop_extent - 1); |
| auto loop_range = Range::FromMinExtent(loop_min, loop_extent); |
| self->active_loop_iterators_.push_back({var, loop_min, loop_max, loop_range}); |
| self->loop_dependent_vars_.insert(var.get()); |
| } |
| ~BindActiveLoopVar() { self->active_loop_iterators_.pop_back(); } |
| |
| ControlFlowGraphBuilder* self; |
| Var var; |
| |
| // Disable default-generated copy/move assignment and constructors |
| BindActiveLoopVar(const BindActiveLoopVar&) = delete; |
| BindActiveLoopVar& operator=(const BindActiveLoopVar&) = delete; |
| BindActiveLoopVar(BindActiveLoopVar&&) = delete; |
| BindActiveLoopVar& operator=(BindActiveLoopVar&&) = delete; |
| }; |
| |
| // Internal utility, context manager for tracking a variable binding |
| struct BindLetVar { |
| BindLetVar(ControlFlowGraphBuilder* self, Var var, PrimExpr value) : self(self), var(var) { |
| self->let_bindings_using_loop_.Set(var, value); |
| self->loop_dependent_vars_.insert(var.get()); |
| } |
| ~BindLetVar() { |
| self->loop_dependent_vars_.erase(var.get()); |
| self->let_bindings_using_loop_.erase(var); |
| } |
| ControlFlowGraphBuilder* self; |
| Var var; |
| |
| // Disable default-generated copy/move assignment and constructors |
| BindLetVar(const BindLetVar&) = delete; |
| BindLetVar& operator=(const BindLetVar&) = delete; |
| BindLetVar(BindLetVar&&) = delete; |
| BindLetVar& operator=(BindLetVar&&) = delete; |
| }; |
| |
| struct LoopEntry { |
| Var loop_var; |
| PrimExpr loop_min; |
| PrimExpr loop_max; |
| Range loop_range; |
| }; |
| |
| // Track in order to know which Vars to write in terms of the buffer |
| // indices and substitute out of the predicate. |
| std::vector<ControlFlowGraph::ControlFlowBlock::LoopEntry> active_loop_iterators_; |
| |
| // Track all loop iterators, along with values derived from loop iterators. |
| std::unordered_set<const VarNode*> loop_dependent_vars_; |
| |
| // Any let binding that depends, directly or indirectly, on a loop |
| // binding. When making a predicate in terms of the buffer indices, |
| // these need to be substituted out. |
| // std::unordered_map<const VarNode*, PrimExpr> let_bindings_using_loop_; |
| ffi::Map<Var, PrimExpr> let_bindings_using_loop_; |
| |
| // Track in order to know what conditions limit the buffer access |
| std::vector<PrimExpr> conditions_; |
| |
| // Track in order to know what statement initiated the buffer access |
| Stmt current_stmt_; |
| |
| // Output data structure |
| ControlFlowGraph* out_; |
| }; |
| |
| std::pair<BufferTouch, ffi::Map<Var, Range>> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch( |
| const tir::Buffer& buf, ffi::Array<Var> index_variables, ffi::Array<PrimExpr> indices, |
| BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { |
| const auto& current_block = *this; |
| |
| Analyzer local_analyzer; |
| |
| ffi::Optional<Var> lane_var = std::nullopt; |
| IntImm num_lanes; |
| |
| ffi::Array<PrimExpr> index_expressions = indices.Map([&](const auto& index) { |
| if (index.dtype().lanes() == 1) { |
| return index; |
| } else { |
| ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; |
| lane_var = Var("lane", index.dtype().element_of()); |
| num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); |
| return UnwrapVectorExpr(index, lane_var.value()); |
| } |
| }); |
| |
| ffi::Array<Var> loop_vars; |
| |
| ffi::Map<Var, Range> loop_ranges; |
| for (const auto& loop_entry : current_block.active_loop_iterators) { |
| loop_vars.push_back(loop_entry.loop_var); |
| loop_ranges.Set(loop_entry.loop_var, loop_entry.loop_range); |
| } |
| |
| // If the indices contain multiple lanes, treat the lane variable |
| // as an additional loop iterator to be solved for and substituted |
| // out. |
| if (lane_var) { |
| loop_vars.push_back(lane_var.value()); |
| loop_ranges.Set(lane_var.value(), Range::FromMinExtent(0, num_lanes)); |
| } |
| |
| IntConstraintsTransform transform = [&]() { |
| ICHECK_EQ(index_variables.size(), index_expressions.size()); |
| |
| ffi::Array<PrimExpr> relations; |
| |
| for (size_t i = 0; i < index_expressions.size(); i++) { |
| PrimExpr expr = index_expressions[i]; |
| Var var = index_variables[i]; |
| |
| expr = Substitute(expr, current_block.let_bindings_using_loop); |
| relations.push_back(var == expr); |
| } |
| |
| IntConstraints system(loop_vars, loop_ranges, relations); |
| return arith::SolveLinearEquations(system); |
| }(); |
| |
| ffi::Map<Var, PrimExpr> loop_var_to_axis_var = transform->src_to_dst; |
| ffi::Map<Var, Range> free_params = transform->dst->ranges; |
| PrimExpr transform_predicate = |
| std::accumulate(transform->dst->relations.begin(), transform->dst->relations.end(), |
| PrimExpr(Bool(true)), [](PrimExpr a, PrimExpr b) { return a && b; }); |
| |
| transform_predicate = SimplifyAsAndOfOrs(transform_predicate, &local_analyzer); |
| |
| auto find_removable_params = [&]() -> ffi::Map<Var, PrimExpr> { |
| ffi::Map<Var, PrimExpr> removable_params; |
| |
| // The arith::SolveLinearEquations is more general than the |
| // utilities in iter_affine_map.h, but can introduce free |
| // parameters that could later be determined with the known |
| // constraints. This step removes all such free parameters. |
| for (const auto& expr : ExtractConstraints(transform_predicate)) { |
| if (auto* as_equal = expr.as<EQNode>()) { |
| auto check_expr = [&](const PrimExpr& a, const PrimExpr& b) { |
| auto* var_ptr = a.as<VarNode>(); |
| if (!var_ptr) { |
| return; |
| } |
| |
| Var var = ffi::GetRef<Var>(var_ptr); |
| if (free_params.count(var) == 0) { |
| return; |
| } |
| |
| bool uses_free_param = UsesVar( |
| b, [&](const VarNode* v) { return free_params.count(ffi::GetRef<Var>(v)) > 0; }); |
| if (uses_free_param) { |
| return; |
| } |
| removable_params.Set(var, b); |
| }; |
| check_expr(as_equal->a, as_equal->b); |
| check_expr(as_equal->b, as_equal->a); |
| } |
| } |
| |
| // In addition, the arith::SolveLinearEquation can introduce |
| // free parameters with an extent of one. Filtering them out here |
| // avoids needing to track them through later simplifications. |
| for (const auto [var, range] : free_params) { |
| if (is_one(range->extent)) { |
| removable_params.Set(var, range->min); |
| } |
| } |
| |
| return removable_params; |
| }; |
| for (auto removable_params = find_removable_params(); removable_params.size() > 0; |
| removable_params = find_removable_params()) { |
| auto update = [&](const PrimExpr& expr) { |
| return local_analyzer.Simplify(Substitute(expr, removable_params)); |
| }; |
| |
| ffi::Map<Var, PrimExpr> new_map; |
| for (const auto [loop_var, expr] : loop_var_to_axis_var) { |
| static_cast<void>(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 |
| new_map.Set(loop_var, update(expr)); |
| } |
| loop_var_to_axis_var = new_map; |
| |
| transform_predicate = update(transform_predicate); |
| |
| for (const auto [var, expr] : removable_params) { |
| static_cast<void>(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 |
| free_params.erase(var); |
| } |
| } |
| |
| // Normalization function, applied to both the predicate and the |
| // known value. Converts from an expression in terms of loop |
| // iterators to an expression in terms of buffer indices. |
| auto normalize_expr = [&](PrimExpr expr) -> PrimExpr { |
| expr = Substitute(expr, current_block.let_bindings_using_loop); |
| |
| if (lane_var) { |
| expr = UnwrapVectorExpr(expr, lane_var.value()); |
| } |
| expr = Substitute(expr, loop_var_to_axis_var); |
| |
| return expr; |
| }; |
| |
| // Collect the current loop variables, along with an expression for |
| // the loop variables in terms of the buffer axis variables. This |
| // is used during forward/backward propagation to generate predicate |
| // tracking whether a loop iteration has been reached. |
| std::vector<std::pair<Var, PrimExpr>> loop_var_expressions; |
| for (const auto& entry : current_block.active_loop_iterators) { |
| auto expr_it = loop_var_to_axis_var.find(entry.loop_var); |
| ICHECK(expr_it != loop_var_to_axis_var.end()); |
| loop_var_expressions.push_back({entry.loop_var, (*expr_it).second}); |
| } |
| |
| // The full predicate is composed of the values required to reach |
| // the scope of the BufferStore or builtin::assume(), any bounds |
| // implied by solving for the axis variables, and any additional |
| // statements resulting from unpacking the expression contained in |
| // builtin::assume(). |
| PrimExpr scope_predicate = normalize_expr(current_block.scope_predicate); |
| transform_predicate = normalize_expr(transform_predicate); |
| |
| known_value_expr = local_analyzer.Simplify(normalize_expr(known_value_expr)); |
| |
| // Deliberately use an analyzer without scope-based information, |
| // to avoid simplifying `scope_predicate` to True. |
| PrimExpr predicate_expr = local_analyzer.Simplify(transform_predicate && scope_predicate); |
| |
| BufferTouch buffer_touch = {buf, predicate_expr, known_value_expr, loop_var_expressions, |
| touch_type}; |
| |
| return {buffer_touch, free_params}; |
| } |
| |
| BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph* graph, |
| const tir::Buffer& buf, |
| const ffi::Array<PrimExpr>& indices, |
| BufferTouch::AccessType touch_type, |
| PrimExpr known_value_expr) const { |
| ICHECK(graph); |
| auto [buffer_touch, free_params] = MakeBufferTouch(buf, graph->GetIndexVariables(buf, indices), |
| indices, touch_type, known_value_expr); |
| for (const auto& pair : free_params) { |
| graph->free_predicate_parameters_.Set(pair.first, pair.second); |
| } |
| return buffer_touch; |
| } |
| |
| ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int64_t max_simplification_steps, |
| size_t max_revisits) |
| : max_revisits_(max_revisits), max_simplification_steps_(max_simplification_steps) { |
| ControlFlowGraphBuilder::Build(this, stmt); |
| ForwardPropagateKnownValues(); |
| BackwardPropagateUnusedValues(); |
| } |
| |
| void ControlFlowGraph::RemoveStore(const tir::BufferStore& store) { |
| size_t context_index = [&]() { |
| auto it = control_flow_lookup_.find(store.get()); |
| ICHECK(it != control_flow_lookup_.end()) |
| << "BufferStore did not occur in the Stmt provided to BufferTouchPattern's constructor"; |
| return it->second; |
| }(); |
| |
| auto& touch_points = control_flow_[context_index].touch_points; |
| |
| touch_points.erase(std::remove_if(touch_points.begin(), touch_points.end(), |
| [](const BufferTouch& touch) { |
| return touch.touch_type == BufferTouch::AccessType::Write; |
| }), |
| touch_points.end()); |
| ForwardPropagateKnownValues(context_index); |
| BackwardPropagateUnusedValues(context_index); |
| } |
| |
| std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) { |
| os << edge.index; |
| if (edge.var_remap.size()) { |
| os << " with remap " << edge.var_remap; |
| } |
| if (edge.post_condition) { |
| os << " with postcondition " << edge.post_condition; |
| } |
| |
| return os; |
| } |
| |
| std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowBlock& block) { |
| os << "Predecessors: ["; |
| for (size_t i = 0; i < block.predecessors.size(); i++) { |
| if (i) { |
| os << ", "; |
| } |
| os << block.predecessors[i]; |
| } |
| os << "]\n"; |
| |
| os << "Active loop iterators: ["; |
| for (size_t i = 0; i < block.active_loop_iterators.size(); i++) { |
| if (i) { |
| os << ", "; |
| } |
| os << block.active_loop_iterators[i].loop_var; |
| } |
| os << "]\n"; |
| |
| os << "Before block knowns: " << block.known_at_block_start << "\n"; |
| |
| os << "Before block unused: " << block.unused_at_block_start << "\n"; |
| |
| for (size_t i = 0; i < block.touch_points.size(); i++) { |
| os << "Touch[" << i << "] = " << block.touch_points[i] << "\n"; |
| } |
| os << "After block: " << block.known_at_block_end << "\n"; |
| |
| os << "After block unused: " << block.unused_at_block_end << "\n"; |
| |
| os << "Successors: ["; |
| for (size_t i = 0; i < block.successors.size(); i++) { |
| if (i) { |
| os << ", "; |
| } |
| os << block.successors[i]; |
| } |
| os << "]"; |
| return os; |
| } |
| |
| std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern) { |
| os << "Touch pattern contains " << pattern.control_flow_.size() << " control blocks." |
| << (pattern.control_flow_.size() ? "\n" : ""); |
| for (size_t i = 0; i < pattern.control_flow_.size(); i++) { |
| os << "\t" |
| << "ControlBlock[" << i << "] = " << pattern.control_flow_[i] << "\n"; |
| } |
| |
| return os; |
| } |
| |
| bool BufferTouch::IsEquivalentTo(const BufferTouch& other, Analyzer* analyzer) const { |
| // Constraints must apply to the same buffer to be equivalent |
| if (!buffer.same_as(other.buffer) || touch_type != other.touch_type) { |
| return false; |
| } |
| |
| ExprDeepEqual deep_equal; |
| |
| auto implies = [&](const PrimExpr& a, const PrimExpr& b) -> bool { |
| With<ConstraintContext> context(analyzer, a); |
| return analyzer->CanProve(b); |
| }; |
| |
| // Predicates must be equivalent expressions, or must both be undefined |
| bool equivalent_predicates = |
| deep_equal(predicate, other.predicate) || |
| (implies(predicate, other.predicate) && implies(other.predicate, predicate)); |
| if (!equivalent_predicates) { |
| return false; |
| } |
| |
| // The known value must be equal |
| if (!deep_equal(value, other.value) && !analyzer->CanProveEqual(value, other.value)) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
| std::ostream& operator<<(std::ostream& os, const BufferState& state) { |
| for (size_t i = 0; i < state.constraints_.size(); i++) { |
| os << "constraints[" << i << "] = " << state.constraints_[i] |
| << (i + 1 == state.constraints_.size() ? "" : "\n"); |
| } |
| return os; |
| } |
| |
| PrimExpr BufferState::SubstituteKnownBufferValues( |
| PrimExpr expr, const ffi::Map<tir::Buffer, ffi::Array<tir::Var>>& axis_var_lookup, |
| Analyzer* analyzer) const { |
| BufferConstraintApply mutator(axis_var_lookup, constraints_, analyzer); |
| return mutator(std::move(expr)); |
| } |
| |
| void BufferState::AddCondition(const PrimExpr& condition) { |
| for (auto& constraint : constraints_) { |
| constraint.predicate = constraint.predicate && condition; |
| } |
| } |
| |
| void BufferState::Substitute(const ffi::Map<Var, PrimExpr>& var_remap, Analyzer* analyzer) { |
| if (var_remap.size()) { |
| for (auto& prior : constraints_) { |
| PrimExpr updated = tvm::tir::Substitute(prior.predicate, var_remap); |
| if (!updated.same_as(prior.predicate)) { |
| prior.predicate = SimplifyAsAndOfOrs(updated, analyzer); |
| } |
| } |
| } |
| } |
| |
| void BufferState::Simplify(Analyzer* analyzer) { |
| for (auto& constraint : constraints_) { |
| constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate, analyzer); |
| } |
| } |
| |
| void BufferState::Union(const BufferState& b, Analyzer* analyzer) { |
| for (const auto& b_constraint : b.constraints_) { |
| bool used = false; |
| for (auto& a_constraint : constraints_) { |
| if (a_constraint.buffer.same_as(b_constraint.buffer) && |
| analyzer->CanProveEqual(a_constraint.value, b_constraint.value)) { |
| a_constraint.predicate = |
| SimplifyAsAndOfOrs(a_constraint.predicate || b_constraint.predicate, analyzer); |
| used = true; |
| break; |
| } |
| } |
| if (!used) { |
| constraints_.push_back(b_constraint); |
| } |
| } |
| } |
| |
| void BufferState::Intersection(const BufferState& b, Analyzer* analyzer) { |
| // For a constraint to be in the output, it must be present in both |
| // inputs. |
| |
| std::vector<BufferTouch> new_constraints; |
| for (const auto& ai : constraints_) { |
| for (const auto& bi : b.constraints_) { |
| if (ai.buffer.same_as(bi.buffer)) { |
| PrimExpr predicate = SimplifyAsAndOfOrs(ai.predicate && bi.predicate, analyzer); |
| if (!is_zero(predicate)) { |
| With<ConstraintContext> context(analyzer, predicate); |
| PrimExpr known_value_a = ai.value; |
| PrimExpr known_value_b = bi.value; |
| |
| bool is_consistent = analyzer->CanProveEqual(known_value_a, known_value_b); |
| if (is_consistent) { |
| new_constraints.push_back({ai.buffer, predicate, known_value_a}); |
| } |
| } |
| } |
| } |
| } |
| |
| constraints_ = std::move(new_constraints); |
| } |
| |
| class BufferRegionCollector : public ExprVisitor { |
| public: |
| struct Region { |
| PrimExpr region_predicate; |
| std::unordered_map<const BufferLoadNode*, ffi::Optional<PrimExpr>> known_values; |
| }; |
| |
| static std::vector<Region> Collect(const ffi::Map<Buffer, ffi::Array<Var>>& axis_var_lookup, |
| const std::vector<BufferTouch>& knowns, |
| const std::vector<ffi::Optional<PrimExpr>>& exprs, |
| Analyzer* analyzer) { |
| BufferRegionCollector collector(axis_var_lookup, knowns, analyzer); |
| for (const auto& expr : exprs) { |
| if (expr) { |
| collector(expr.value()); |
| } |
| } |
| |
| return collector.regions_; |
| } |
| |
| private: |
| using Parent = ExprVisitor; |
| |
| BufferRegionCollector(const ffi::Map<Buffer, ffi::Array<Var>>& axis_var_lookup, |
| const std::vector<BufferTouch>& knowns, Analyzer* analyzer) |
| : analyzer_(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) { |
| regions_.push_back(Region{Bool(true), {}}); |
| } |
| |
| using Parent::VisitExpr_; |
| |
| void VisitExpr_(const BufferLoadNode* op) override { |
| // Helper struct for the known values of this BufferLoad |
| struct Known { |
| PrimExpr predicate; |
| ffi::Optional<PrimExpr> value; |
| }; |
| |
| std::vector<Known> new_regions; |
| |
| PrimExpr unknown_region = Bool(true); |
| |
| for (const BufferTouch& constraint : knowns_) { |
| if (!op->buffer.same_as(constraint.buffer)) { |
| // This is a different buffer, so continue searching. |
| continue; |
| } |
| |
| auto axis_vars = axis_var_lookup_.at(op->buffer); |
| PrimExpr touch_predicate = |
| SubstituteParamValues(axis_vars, op->indices, constraint.predicate).value(); |
| touch_predicate = SimplifyAsAndOfOrs(touch_predicate, analyzer_); |
| |
| if (!is_zero(touch_predicate)) { |
| ffi::Optional<PrimExpr> known_value = |
| SubstituteParamValues(axis_vars, op->indices, constraint.value); |
| new_regions.push_back(Known{touch_predicate, known_value}); |
| |
| unknown_region = unknown_region && !touch_predicate; |
| unknown_region = SimplifyAsAndOfOrs(unknown_region, analyzer_); |
| } |
| } |
| |
| if (new_regions.size()) { |
| Analyzer local_analyzer; |
| |
| if (!is_zero(unknown_region)) { |
| new_regions.insert(new_regions.begin(), Known{unknown_region, std::nullopt}); |
| } |
| |
| std::vector<Region> updated_regions; |
| for (const auto& prev_region : regions_) { |
| for (const auto& new_region : new_regions) { |
| PrimExpr intersection = |
| SimplifyAsAndOfOrs(prev_region.region_predicate && new_region.predicate, analyzer_); |
| |
| if (!is_zero(intersection)) { |
| Region merged{intersection, prev_region.known_values}; |
| merged.known_values[op] = new_region.value; |
| updated_regions.push_back(std::move(merged)); |
| } |
| } |
| } |
| regions_ = updated_regions; |
| } |
| } |
| |
| Analyzer* analyzer_; |
| std::vector<Region> regions_; |
| const ffi::Map<Buffer, ffi::Array<Var>>& axis_var_lookup_; |
| const std::vector<BufferTouch>& knowns_; |
| }; |
| |
| class BufferRegionValueReplacer : public IRMutatorWithAnalyzer { |
| public: |
| static PrimExpr Apply( |
| const std::unordered_map<const BufferLoadNode*, ffi::Optional<PrimExpr>>& known_values, |
| PrimExpr expr, Analyzer* analyzer) { |
| BufferRegionValueReplacer mutator(known_values, analyzer); |
| PrimExpr result = mutator(expr); |
| // Simplification must occur after the substitution, as known |
| // values may provide enable simplifications. Also, cannot track |
| // whether a BufferLoad was |
| result = analyzer->Simplify(result); |
| return result; |
| } |
| |
| private: |
| using Parent = IRMutatorWithAnalyzer; |
| |
| BufferRegionValueReplacer( |
| const std::unordered_map<const BufferLoadNode*, ffi::Optional<PrimExpr>>& known_values, |
| Analyzer* analyzer) |
| : Parent(analyzer), known_values_(known_values) {} |
| |
| using Parent::VisitExpr_; |
| |
| PrimExpr VisitExpr_(const BufferLoadNode* op) override { |
| auto it = known_values_.find(op); |
| if (it != known_values_.end() && it->second) { |
| return it->second.value(); |
| } else { |
| return ffi::GetRef<PrimExpr>(op); |
| } |
| } |
| |
| const std::unordered_map<const BufferLoadNode*, ffi::Optional<PrimExpr>>& known_values_; |
| }; |
| |
| void BufferState::ApplyTouches(const ffi::Map<Buffer, ffi::Array<Var>>& axis_var_lookup, |
| const std::vector<BufferTouch>& touch_points, Analyzer* analyzer) { |
| std::vector<BufferTouch> new_knowns; |
| ffi::Map<Buffer, PrimExpr> keep_prior_known_at; |
| |
| for (auto& touch : touch_points) { |
| if (touch.touch_type == BufferTouch::AccessType::Read) { |
| continue; |
| } |
| |
| PrimExpr known_value = touch.value; |
| |
| PrimExpr predicate = touch.predicate && touch.AfterLoopIteration(); |
| auto regions = BufferRegionCollector::Collect(axis_var_lookup, constraints_, |
| {predicate, touch.value}, analyzer); |
| |
| for (const auto& region : regions) { |
| PrimExpr updated_predicate = BufferRegionValueReplacer::Apply( |
| region.known_values, region.region_predicate && predicate, analyzer); |
| |
| updated_predicate = SimplifyAsAndOfOrs(updated_predicate, analyzer); |
| PrimExpr updated_value = |
| BufferRegionValueReplacer::Apply(region.known_values, known_value, analyzer); |
| |
| if (!is_zero(updated_predicate)) { |
| if (auto it = keep_prior_known_at.find(touch.buffer); it != keep_prior_known_at.end()) { |
| keep_prior_known_at.Set(touch.buffer, (*it).second && !updated_predicate); |
| } else { |
| keep_prior_known_at.Set(touch.buffer, !updated_predicate); |
| } |
| |
| if (!HasBufferLoad(updated_value)) { |
| BufferTouch new_constraint{touch.buffer, updated_predicate, updated_value}; |
| new_knowns.push_back(new_constraint); |
| } |
| } |
| } |
| } |
| |
| if (keep_prior_known_at.size()) { |
| for (auto& constraint : constraints_) { |
| if (auto it = keep_prior_known_at.find(constraint.buffer); it != keep_prior_known_at.end()) { |
| constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate && (*it).second, analyzer); |
| } |
| } |
| } |
| |
| if (new_knowns.size()) { |
| std::vector<bool> used(new_knowns.size(), false); |
| |
| for (auto& constraint : constraints_) { |
| PrimExpr expand_known_at = Bool(false); |
| |
| PrimExpr prev_value = constraint.value; |
| |
| for (size_t i = 0; i < new_knowns.size(); i++) { |
| if (new_knowns[i].buffer.same_as(constraint.buffer)) { |
| ffi::Optional<PrimExpr> overwritten_with = new_knowns[i].value; |
| if (overwritten_with && analyzer->CanProveEqual(prev_value, overwritten_with.value())) { |
| expand_known_at = |
| SimplifyAsAndOfOrs(expand_known_at || new_knowns[i].predicate, analyzer); |
| used[i] = true; |
| } |
| } |
| } |
| |
| if (!is_zero(expand_known_at)) { |
| constraint.predicate = |
| SimplifyAsAndOfOrs(constraint.predicate || expand_known_at, analyzer); |
| } |
| } |
| |
| for (size_t i = 0; i < new_knowns.size(); i++) { |
| if (!used[i]) { |
| constraints_.push_back(new_knowns[i]); |
| } |
| } |
| } |
| |
| constraints_.erase( |
| std::remove_if(constraints_.begin(), constraints_.end(), |
| [&](const auto& constraint) { return is_zero(constraint.predicate); }), |
| constraints_.end()); |
| } |
| |
| void BufferState::BackpropUnusedIndices(const ffi::Map<Buffer, ffi::Array<Var>>& axis_var_lookup, |
| const std::vector<BufferTouch>& touch_points, |
| Analyzer* analyzer) { |
| std::vector<BufferTouch> new_knowns; |
| ffi::Map<Buffer, PrimExpr> keep_prior_known_at; |
| |
| ffi::Map<Buffer, PrimExpr> regions_written; |
| ffi::Map<Buffer, PrimExpr> regions_read; |
| for (auto it = touch_points.rbegin(); it != touch_points.rend(); it++) { |
| const auto& touch = *it; |
| |
| ffi::Map<Buffer, PrimExpr>* to_update{nullptr}; |
| if (touch.touch_type == BufferTouch::AccessType::Write) { |
| to_update = ®ions_written; |
| |
| } else if (touch.touch_type == BufferTouch::AccessType::Read) { |
| to_update = ®ions_read; |
| } else { |
| continue; |
| } |
| |
| PrimExpr prev = to_update->Get(touch.buffer).value_or(Bool(false)); |
| PrimExpr new_predicate = touch.predicate && touch.BeforeLoopIteration(); |
| to_update->Set(touch.buffer, prev || new_predicate); |
| } |
| |
| auto update_map = [&](auto& map) { |
| ffi::Map<Buffer, PrimExpr> new_map; |
| for (auto [buffer, predicate] : map) { |
| new_map.Set(buffer, SimplifyAsAndOfOrs(predicate, analyzer)); |
| } |
| map = std::move(new_map); |
| }; |
| update_map(regions_written); |
| update_map(regions_read); |
| |
| // If buffer is already in used, widen the predicate |
| for (auto& prev_unused : constraints_) { |
| if (auto opt_predicate = regions_written.Get(prev_unused.buffer)) { |
| PrimExpr new_predicate = prev_unused.predicate || opt_predicate.value(); |
| prev_unused.predicate = SimplifyAsAndOfOrs(new_predicate, analyzer); |
| regions_written.erase(prev_unused.buffer); |
| } |
| } |
| |
| // Otherwise, add new "touch" to represent the unused values |
| for (auto [buffer, predicate] : regions_written) { |
| constraints_.push_back( |
| BufferTouch{buffer, predicate, tir::Call(buffer->dtype, builtin::undef(), {})}); |
| } |
| |
| // If buffer is read out, narrow the predicate |
| for (auto& prev_unused : constraints_) { |
| if (auto opt_pred = regions_read.Get(prev_unused.buffer)) { |
| PrimExpr predicate = opt_pred.value(); |
| prev_unused.predicate = SimplifyAsAndOfOrs(prev_unused.predicate && !predicate, analyzer); |
| } |
| } |
| |
| // Clean-up and remove any empty constraints |
| constraints_.erase( |
| std::remove_if(constraints_.begin(), constraints_.end(), |
| [](const auto& constraint) { return is_zero(constraint.predicate); }), |
| constraints_.end()); |
| } |
| |
| void BufferState::RemoveFreeParameters(const ffi::Map<Var, Range>& free_predicate_parameters, |
| Analyzer* analyzer) { |
| for (auto& known : constraints_) { |
| known.predicate = NarrowPredicateExpression(known.predicate, free_predicate_parameters); |
| known.predicate = SimplifyAsAndOfOrs(known.predicate, analyzer); |
| } |
| } |
| |
| bool BufferState::IsEquivalentTo(const BufferState& other, Analyzer* analyzer) const { |
| if (constraints_.size() != other.constraints_.size()) { |
| return false; |
| } |
| |
| for (size_t i = 0; i < constraints_.size(); i++) { |
| if (!constraints_[i].IsEquivalentTo(other.constraints_[i], analyzer)) { |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| ffi::Optional<ffi::Array<Var>> ControlFlowGraph::GetIndexVariables(const Buffer& buf) const { |
| if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { |
| return (*it).second; |
| } else { |
| return std::nullopt; |
| } |
| } |
| |
| ffi::Array<Var> ControlFlowGraph::GetIndexVariables(const Buffer& buf, |
| const ffi::Array<PrimExpr>& indices) { |
| if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { |
| return (*it).second; |
| } |
| |
| ffi::Array<Var> vars; |
| for (size_t i = 0; i < indices.size(); i++) { |
| std::stringstream ss; |
| ss << buf->name << "_axis_" << i; |
| vars.push_back(Var(ss.str(), indices[i].dtype().element_of())); |
| } |
| |
| axis_var_lookup_.Set(buf, vars); |
| return vars; |
| } |
| |
| void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_from) { |
| // Values to visit when searching. Using a std::set to |
| // preferentially visit nodes near the start of the control flow. |
| std::set<size_t> to_visit; |
| |
| if (flow_from.has_value()) { |
| to_visit.insert(flow_from.value()); |
| } else { |
| // Initiatize the locations to search from, propagating values |
| // forward from all locations that have a known value. |
| for (size_t i = 0; i < control_flow_.size(); i++) { |
| bool has_known_value = false; |
| for (const auto& touch : control_flow_[i].touch_points) { |
| if (!HasBufferLoad(touch.value)) { |
| has_known_value = true; |
| break; |
| } |
| } |
| |
| if (has_known_value) { |
| to_visit.insert(i); |
| } |
| } |
| } |
| |
| // Map from a block's index |
| std::unordered_map<size_t, size_t> visit_count_lookup; |
| |
| Analyzer analyzer; |
| analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_); |
| analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( |
| arith::RewriteSimplifier::kTransitivelyProveInequalities | |
| arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | |
| arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); |
| |
| analyzer.Bind(iterator_ranges_); |
| analyzer.Bind(free_predicate_parameters_); |
| |
| while (to_visit.size()) { |
| size_t visiting = *to_visit.begin(); |
| to_visit.erase(visiting); |
| |
| size_t num_previous_visits = visit_count_lookup[visiting]++; |
| |
| ControlFlowBlock& block = control_flow_[visiting]; |
| |
| // Step 1: Collect known values provided from each predecessor |
| block.known_at_block_start = [&]() -> BufferState { |
| if (num_previous_visits >= max_revisits_) { |
| return BufferState(); |
| } |
| |
| // Validate internal constraint. This should be true by |
| // construction, as ControlFlowGraphBuilder only builds graphs |
| // that have two or fewer predecessors. |
| ICHECK_LE(block.predecessors.size(), 2) |
| << "InternalError: Each block should have at most two predecessors. " |
| << "Graph constructed in ControlFlowGraphBuilder did not satisfy this constraint."; |
| |
| std::vector<BufferState> states; |
| for (const auto& pred : block.predecessors) { |
| const auto& pred_block = control_flow_[pred.index]; |
| BufferState state = pred_block.known_at_block_end; |
| state.Substitute(pred.var_remap, &analyzer); |
| states.push_back(state); |
| } |
| |
| if (std::all_of(block.predecessors.begin(), block.predecessors.end(), |
| [&](const auto& pred) { return visit_count_lookup[pred.index] == 0; })) { |
| // Predecessors, if any, are unvisited. |
| return {}; |
| } else if (block.predecessors.size() == 1) { |
| // Block has only a single predecessor |
| return states[0]; |
| } |
| |
| const auto& pred_a = block.predecessors[0]; |
| const auto& pred_b = block.predecessors[1]; |
| |
| auto& priors_a = states[0]; |
| auto& priors_b = states[1]; |
| |
| // During the first visit of a block, predecessor blocks may be |
| // unvisited, even though we preferentially visit earlier blocks |
| // first. (e.g. During the first visit of the start of a For |
| // loop, the end of the For loop has not yet been visited.) If |
| // this is the case, assume the best-case scenario that all |
| // knowns are consistent, and rely on a later visit to |
| // resolve/remove any conflicts. |
| if (visit_count_lookup[pred_a.index] == 0) { |
| return priors_b; |
| } else if (visit_count_lookup[pred_b.index] == 0) { |
| return priors_a; |
| } |
| |
| if (pred_a.post_condition && pred_b.post_condition) { |
| // The predicate can identify which predecessor block applies |
| // (e.g. i==0 for the first loop iteration, i>0 for remaining |
| // loop iterations). Therefore, we can use all buffer |
| // constraints, conditional on having come from the |
| // predecessor that provides it. |
| priors_a.AddCondition(pred_a.post_condition.value()); |
| priors_b.AddCondition(pred_b.post_condition.value()); |
| priors_a.Union(priors_b, &analyzer); |
| return priors_a; |
| } else { |
| // We don't know which predecessor applies. Therefore, the |
| // only buffer constraints that can be used are those that |
| // appear in both predecessors. |
| priors_a.Intersection(priors_b, &analyzer); |
| return priors_a; |
| } |
| }(); |
| |
| // Step 2: Collect knowns provided as a result of executing this block |
| auto post_state = [&]() { |
| if (num_previous_visits >= max_revisits_) { |
| return BufferState(); |
| } |
| auto post_state = block.known_at_block_start; |
| post_state.ApplyTouches(axis_var_lookup_, block.touch_points, &analyzer); |
| post_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer); |
| return post_state; |
| }(); |
| |
| // Step 3: If any changes are made to the post knowns since the |
| // previous time we visited this block, mark the successor block |
| // as needing to be visited. |
| if (num_previous_visits == 0 || |
| !post_state.IsEquivalentTo(block.known_at_block_end, &analyzer)) { |
| block.known_at_block_end = std::move(post_state); |
| for (const auto& successor : block.successors) { |
| to_visit.insert(successor.index); |
| } |
| } |
| } |
| } |
| |
| void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_from) { |
| // Values to visit when searching. Using a std::set to |
| // preferentially visit nodes near the end of the control flow. |
| std::set<size_t> to_visit; |
| |
| if (flow_from.has_value()) { |
| to_visit.insert(flow_from.value()); |
| } else { |
| // Initiatize the locations to search from, propagating values |
| // backward from anywhere that performs a write. |
| for (size_t i = 0; i < control_flow_.size(); i++) { |
| const auto& touch_points = control_flow_[i].touch_points; |
| bool performs_write = std::any_of( |
| touch_points.begin(), touch_points.end(), |
| [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; }); |
| if (performs_write) { |
| to_visit.insert(i); |
| } |
| } |
| } |
| |
| // Map from a block's index |
| std::unordered_map<size_t, size_t> visit_count_lookup; |
| |
| Analyzer analyzer; |
| analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_); |
| analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( |
| arith::RewriteSimplifier::kTransitivelyProveInequalities | |
| arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | |
| arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); |
| |
| analyzer.Bind(iterator_ranges_); |
| analyzer.Bind(free_predicate_parameters_); |
| |
| while (to_visit.size()) { |
| size_t visiting = *to_visit.rbegin(); |
| to_visit.erase(visiting); |
| |
| size_t num_previous_visits = visit_count_lookup[visiting]++; |
| |
| ControlFlowBlock& block = control_flow_[visiting]; |
| |
| // Step 1: Collect known unused indices provided by each successor |
| block.unused_at_block_end = [&]() -> BufferState { |
| if (num_previous_visits >= max_revisits_) { |
| return BufferState(); |
| } |
| ICHECK_LE(block.successors.size(), 2) |
| << "Each block should have at most two successors, but block " << visiting |
| << " breaks this requirement"; |
| |
| std::vector<BufferState> states; |
| for (const auto& successor : block.successors) { |
| const auto& successor_block = control_flow_[successor.index]; |
| BufferState state = successor_block.unused_at_block_start; |
| state.Substitute(successor.var_remap, &analyzer); |
| states.push_back(state); |
| } |
| |
| if (std::all_of(block.successors.begin(), block.successors.end(), [&](const auto& successor) { |
| return visit_count_lookup[successor.index] == 0; |
| })) { |
| // Successors, if any, are unvisited. |
| return {}; |
| } else if (block.successors.size() == 1) { |
| // Block has only a single successor |
| return states[0]; |
| } |
| |
| const auto& successor_a = block.successors[0]; |
| const auto& successor_b = block.successors[1]; |
| |
| auto& post_a = states[0]; |
| auto& post_b = states[1]; |
| |
| // During the first visit of a block, successor blocks may be |
| // unvisited, even though we preferentially visit later blocks |
| // first. (e.g. During the first visit of the end of a For |
| // loop, the start of the For loop has not yet been visited.) |
| // If this is the case, assume the best-case scenario that all |
| // knowns are consistent, and rely on a later visit to |
| // resolve/remove any conflicts. |
| if (visit_count_lookup[successor_a.index] == 0) { |
| return post_b; |
| } else if (visit_count_lookup[successor_b.index] == 0) { |
| return post_a; |
| } |
| |
| if (successor_a.post_condition && successor_b.post_condition) { |
| // The predicate can identify which successor block applies |
| // (e.g. i==n-1 for the last loop iteration, i<n-1 for earlier |
| // loop iterations). Therefore, we can use all buffer |
| // constraints, conditional on having come from the |
| // successor that provides it. |
| post_a.AddCondition(successor_a.post_condition.value()); |
| post_b.AddCondition(successor_b.post_condition.value()); |
| post_a.Union(post_b, &analyzer); |
| return post_a; |
| } else { |
| // We don't know which successor applies. Therefore, the |
| // only buffer constraints that can be used are those that |
| // appear in both successors. |
| post_a.Intersection(post_b, &analyzer); |
| return post_a; |
| } |
| }(); |
| |
| // Step 2: Collect knowns provided as a result of executing this block |
| auto unused_at_block_start = [&]() { |
| if (num_previous_visits >= max_revisits_) { |
| return BufferState(); |
| } |
| auto prior_state = block.unused_at_block_end; |
| prior_state.BackpropUnusedIndices(axis_var_lookup_, block.touch_points, &analyzer); |
| prior_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer); |
| return prior_state; |
| }(); |
| |
| // Step 3: If any changes are made to the post knowns since the |
| // previous time we visited this block, mark the successor block |
| // as needing to be visited. |
| if (num_previous_visits == 0 || |
| !unused_at_block_start.IsEquivalentTo(block.unused_at_block_start, &analyzer)) { |
| block.unused_at_block_start = std::move(unused_at_block_start); |
| for (const auto& pred : block.predecessors) { |
| to_visit.insert(pred.index); |
| } |
| } |
| } |
| } |
| |
| bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, |
| const Stmt& context) const { |
| ffi::Optional<ffi::Array<Var>> index_variables = GetIndexVariables(store->buffer); |
| if (!index_variables) { |
| return false; |
| } |
| |
| auto it = control_flow_lookup_.find(context.get()); |
| ICHECK(it != control_flow_lookup_.end()) << "Context did not occur within analyzed statement:\n" |
| << context; |
| const auto& context_block = control_flow_[it->second]; |
| |
| auto [store_touch, free_params] = context_block.MakeBufferTouch( |
| store->buffer, index_variables.value(), store->indices, BufferTouch::AccessType::Write, |
| BufferLoad(store->buffer, store->indices)); |
| |
| Analyzer local_analyzer; |
| local_analyzer.Bind(free_predicate_parameters_); |
| local_analyzer.Bind(iterator_ranges_); |
| local_analyzer.Bind(free_params); |
| local_analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( |
| arith::RewriteSimplifier::kTransitivelyProveInequalities | |
| arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | |
| arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches)); |
| |
| PrimExpr predicate = store_touch.predicate && store_touch.AtLoopIteration(); |
| |
| predicate = SimplifyAsAndOfOrs(predicate, &local_analyzer); |
| |
| for (const auto& unused : context_block.unused_at_block_end.constraints_) { |
| if (store_touch.buffer.same_as(unused.buffer)) { |
| PrimExpr difference = SimplifyAsAndOfOrs(predicate && !unused.predicate, &local_analyzer); |
| if (is_zero(difference)) { |
| return true; |
| } |
| } |
| } |
| return false; |
| } |
| |
| PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& context, |
| Analyzer* analyzer) const { |
| size_t context_index = [&]() { |
| auto it = control_flow_lookup_.find(context.get()); |
| ICHECK(it != control_flow_lookup_.end()) |
| << "Context did not occur in the Stmt provided to BufferTouchPattern's constructor"; |
| return it->second; |
| }(); |
| |
| const auto& control_flow_block = control_flow_[context_index]; |
| |
| PrimExpr constraint = Bool(true); |
| for (const auto& known : non_buffer_assumptions_) { |
| constraint = constraint && known; |
| } |
| With<ConstraintContext> constraint_context(analyzer, constraint); |
| With<ConstraintContext> control_flow_scope(analyzer, control_flow_block.scope_predicate); |
| |
| expr = control_flow_block.known_at_block_start.SubstituteKnownBufferValues( |
| std::move(expr), axis_var_lookup_, analyzer); |
| |
| expr = analyzer->Simplify(std::move(expr)); |
| return expr; |
| } |
| |
| } // namespace tir |
| } // namespace tvm |