| /* | 
 |  * Licensed to the Apache Software Foundation (ASF) under one | 
 |  * or more contributor license agreements.  See the NOTICE file | 
 |  * distributed with this work for additional information | 
 |  * regarding copyright ownership.  The ASF licenses this file | 
 |  * to you under the Apache License, Version 2.0 (the | 
 |  * "License"); you may not use this file except in compliance | 
 |  * with the License.  You may obtain a copy of the License at | 
 |  * | 
 |  *   http://www.apache.org/licenses/LICENSE-2.0 | 
 |  * | 
 |  * Unless required by applicable law or agreed to in writing, | 
 |  * software distributed under the License is distributed on an | 
 |  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | 
 |  * KIND, either express or implied.  See the License for the | 
 |  * specific language governing permissions and limitations | 
 |  * under the License. | 
 |  */ | 
 | #include "../utils.h" | 
 |  | 
 | namespace tvm { | 
 | namespace meta_schedule { | 
 |  | 
 | /*! \brief The type of inline to be performed on a specific block */ | 
 | enum class InlineType : int32_t { | 
 |   /*! \brief No inline opportunity */ | 
 |   kNoInline = 0, | 
 |   /*! \brief Inline the block into its consumer */ | 
 |   kInlineIntoConsumer = 1, | 
 |   /*! \brief Inline the block into its producer */ | 
 |   kInlineIntoProducer = 2, | 
 | }; | 
 |  | 
 | /*! \brief The rule that inlines spatial blocks if it satisfies some conditions. */ | 
 | class AutoInlineNode : public ScheduleRuleNode { | 
 |  public: | 
 |   /*! \brief Checks if the specific block should be inlined */ | 
 |   inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv); | 
 |  | 
 |   // Inherited from ScheduleRuleNode | 
 |   void InitializeWithTuneContext(const TuneContext& context) final {} | 
 |  | 
 |   // Inherited from ScheduleRuleNode | 
 |   Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { | 
 |     InlineType inline_type = CheckInline(sch, block_rv); | 
 |     if (inline_type == InlineType::kInlineIntoConsumer) { | 
 |       sch->ComputeInline(block_rv); | 
 |     } else if (inline_type == InlineType::kInlineIntoProducer) { | 
 |       sch->ReverseComputeInline(block_rv); | 
 |     } | 
 |     return {sch}; | 
 |   } | 
 |  | 
 |  public: | 
 |   /*! \brief If allows to inline a block into its producer */ | 
 |   bool into_producer; | 
 |   /*! \brief If allows to inline a block into its consumer */ | 
 |   bool into_consumer; | 
 |   /*! \brief Always inline constant tensors */ | 
 |   bool inline_const_tensor; | 
 |   /*! \brief Always disallow if-then-else-like constructs */ | 
 |   bool disallow_if_then_else; | 
 |   /*! \brief Always require the read-to-write mapping to be injective to do auto inline */ | 
 |   bool require_injective; | 
 |   /*! \brief Always require the read-to-write mapping to be ordered to do auto inline */ | 
 |   bool require_ordered; | 
 |   /*! \brief The operators that are disallowed in auto inline */ | 
 |   Array<Op> disallow_op; | 
 |  | 
 |   void VisitAttrs(tvm::AttrVisitor* v) { | 
 |     v->Visit("into_producer", &into_producer); | 
 |     v->Visit("into_consumer", &into_consumer); | 
 |     v->Visit("inline_const_tensor", &inline_const_tensor); | 
 |     v->Visit("disallow_if_then_else", &disallow_if_then_else); | 
 |     v->Visit("require_injective", &require_injective); | 
 |     v->Visit("require_ordered", &require_ordered); | 
 |     v->Visit("disallow_op", &disallow_op); | 
 |   } | 
 |  | 
 |   static constexpr const char* _type_key = "meta_schedule.AutoInline"; | 
 |   TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode); | 
 | }; | 
 |  | 
 | inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, | 
 |                                               const tir::BlockRV& block_rv) { | 
 |   using namespace tvm::tir; | 
 |   StmtSRef block_sref = sch->GetSRef(block_rv); | 
 |   ScheduleState state = sch->state(); | 
 |   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); | 
 |   BlockRealize realize = GetBlockRealize(state, block_sref); | 
 |   // Cond 1. The block has only one write buffer | 
 |   if (block->writes.size() != 1) { | 
 |     return InlineType::kNoInline; | 
 |   } | 
 |   // Cond 2. For a block that generates a constant tensor, ignore all other conditions | 
 |   if (inline_const_tensor && block->reads.empty()) { | 
 |     return InlineType::kInlineIntoConsumer; | 
 |   } | 
 |   // Cond 3. The block doesn't contain any disallowed operators | 
 |   if (!disallow_op.empty() && HasOp(realize, disallow_op)) { | 
 |     return InlineType::kNoInline; | 
 |   } | 
 |   // Cond 4. The block doesn't have any if-then-else-like constructs | 
 |   if (disallow_if_then_else && HasIfThenElse(realize)) { | 
 |     return InlineType::kNoInline; | 
 |   } | 
 |   // Cond 5. The mapping from read indices to write indices are injective and ordered | 
 |   if (require_injective || require_ordered) { | 
 |     const BufferRegion& write_region = block->writes[0]; | 
 |     for (const BufferRegion& read_region : block->reads) { | 
 |       bool injective, ordered; | 
 |       auto _ = std::ignore; | 
 |       std::tie(/*exists=*/_, /*surjective=*/_, injective, ordered, /*no_const_read=*/_, | 
 |                /*no_shift_read=*/_) = AnalyzeReadWritePattern(read_region, write_region); | 
 |       if (require_injective && injective == false) { | 
 |         return InlineType::kNoInline; | 
 |       } | 
 |       if (require_ordered && ordered == false) { | 
 |         return InlineType::kNoInline; | 
 |       } | 
 |     } | 
 |   } | 
 |   // Last cond: Check inline into the consumers or the spatial producer | 
 |   tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref, | 
 |                                                 /*require_stage_pipeline=*/false); | 
 |   if (into_consumer) { | 
 |     Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref); | 
 |     if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) { | 
 |       return InlineType::kInlineIntoConsumer; | 
 |     } | 
 |   } | 
 |   if (into_producer) { | 
 |     Array<tir::StmtSRef> producer_srefs = GetProducers(state, block_sref); | 
 |     if (producer_srefs.size() == 1 && | 
 |         tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) && | 
 |         CanReverseComputeInline(state, block_sref)) { | 
 |       return InlineType::kInlineIntoProducer; | 
 |     } | 
 |   } | 
 |   return InlineType::kNoInline; | 
 | } | 
 |  | 
 | ScheduleRule ScheduleRule::AutoInline(bool into_producer,          // | 
 |                                       bool into_consumer,          // | 
 |                                       bool inline_const_tensor,    // | 
 |                                       bool disallow_if_then_else,  // | 
 |                                       bool require_injective,      // | 
 |                                       bool require_ordered,        // | 
 |                                       Optional<Array<String>> disallow_op) { | 
 |   ObjectPtr<AutoInlineNode> n = make_object<AutoInlineNode>(); | 
 |   n->into_producer = into_producer; | 
 |   n->into_consumer = into_consumer; | 
 |   n->inline_const_tensor = inline_const_tensor; | 
 |   n->disallow_if_then_else = disallow_if_then_else; | 
 |   n->require_injective = require_injective; | 
 |   n->require_ordered = require_ordered; | 
 |   n->disallow_op.clear(); | 
 |   if (disallow_op.defined()) { | 
 |     Array<String> op_names = disallow_op.value(); | 
 |     n->disallow_op.reserve(op_names.size()); | 
 |     for (const String& op_name : op_names) { | 
 |       n->disallow_op.push_back(Op::Get(op_name)); | 
 |     } | 
 |   } | 
 |   return ScheduleRule(n); | 
 | } | 
 |  | 
// Register AutoInlineNode as a node type (the macro is defined outside this file).
TVM_REGISTER_NODE_TYPE(AutoInlineNode);
// Make ScheduleRule::AutoInline available as the global function
// "meta_schedule.ScheduleRuleAutoInline".
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline")
    .set_body_typed(ScheduleRule::AutoInline);
 |  | 
 | }  // namespace meta_schedule | 
 | }  // namespace tvm |