| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file auto_scheduler/search_policy/sketch_policy_rules.h |
 * \brief Rules used by SketchPolicy to generate the sketches, sample the initial population, and
 * mutate states during the evolutionary search.
| */ |
| |
| #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_ |
| #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_ |
| |
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/search_task.h>

#include <random>
#include <string>
#include <utility>
#include <vector>

#include "utils.h"
| |
| namespace tvm { |
| namespace auto_scheduler { |
| |
| class SketchPolicyNode; |
| |
| /********** Sketch Generation Rule **********/ |
| |
| /*! \brief The base class for derivation rules used in the sketch generation. */ |
| class SketchGenerationRule { |
| public: |
| /*! \brief Result enumeration of the condition function. */ |
| enum class ConditionKind : int { |
| /*! \brief Skip this rule and continue to try the next rules. */ |
| kSkip = 0, |
| /*! \brief Apply this rule and continue to try the next rules. */ |
| kApply = 1, |
| /*! \brief Apply this rule and skip the rest rules. */ |
| kApplyAndSkipRest = 2 |
| }; |
| |
| /*! |
| * \brief Condition check function of this rule. |
| * \param policy The SketchPolicyNode of this rule, some information may be used during |
| * the condition checking. |
| * \param state The original state to be checked. |
| * \param stage_id The index of the stage to process this condition check. |
| * \return The condition check result of this rule. |
| */ |
| virtual ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state, |
| int stage_id) const = 0; |
| |
| /*! |
| * \brief Apply function of this rule. |
| * \param policy The SketchPolicyNode of this rule, some information may be used during |
| * the rule applying. |
| * \param state The original state to apply this rule. |
| * \param stage_id The index of the next stage to apply this rule. |
| * \return The state after applying this rule, and index of the next stage. |
| */ |
| virtual std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, |
| const State& state, int stage_id) const = 0; |
| |
| /*! |
| * \brief Get the name of this rule. |
| * \return A string of the rule name. |
| */ |
| virtual std::string GetRuleName() const = 0; |
| }; |
| |
// Declares a concrete SketchGenerationRule subclass called `rule_name`.
// The generated class reports its own name (spelled exactly like the macro argument) from
// GetRuleName, and only declares MeetCondition/Apply -- their definitions are not in this header.
#define DEFINE_SKETCH_GENERATION_RULE(rule_name)                                             \
  class rule_name : public SketchGenerationRule {                                            \
   public:                                                                                   \
    std::string GetRuleName() const final { return std::string(#rule_name); }                \
    SketchGenerationRule::ConditionKind MeetCondition(const SketchPolicyNode& policy,        \
                                                      const State& state,                    \
                                                      int stage_id) const final;             \
    std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy,                 \
                                             const State& state, int stage_id) const final;  \
  };
| |
| /*! \brief The rule that simply skips the current stage. It returns an unchanged state and move to |
| * the next stage. */ |
| DEFINE_SKETCH_GENERATION_RULE(RuleSkipStage); |
| |
| /*! \brief The rule that inlines simple elementwise ops. |
| * \note This rule only inlines the strictly inlineable stages. Stages marked as not strictly |
| * inlineable will have a chance to try different compute at location in InitPopulation later. |
| */ |
| DEFINE_SKETCH_GENERATION_RULE(RuleAlwaysInline); |
| |
| /*! \brief The rule that performs multi-level tiling. */ |
| DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTiling); |
| |
| /*! \brief The rule that performs multi-level tiling and fuses later consumers. */ |
| DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTilingWithFusion); |
| |
| /*! \brief The rule that adds a cache read stage. Mainly used for GPU cooperative fetching, |
| * Currently only support 1 to 1 match cache read. */ |
| DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheRead); |
| |
| /*! \brief The rule that adds a cache write stage. */ |
| DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheWrite); |
| |
| /*! \brief The rule that adds rfactor stage. */ |
| DEFINE_SKETCH_GENERATION_RULE(RuleAddRfactor); |
| |
| /*! \brief The rule that deals with compute ops that perform "fake reduction" with const tensors. |
| * This kind of op comes from winograd transformation. */ |
| DEFINE_SKETCH_GENERATION_RULE(RuleSimplifyComputeWithConstTensor); |
| |
| /*! \brief The rule that use cross thread reduction for GPU. */ |
| DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction); |
| |
| /*! \brief Handle special cases in Winograd transformation for GPU. We need to change the compute |
| * location of the producers of compute ops that perform "fake reduction" with const tensors. */ |
| DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU); |
| |
| /********** Init Population **********/ |
| |
| /*! \brief The base class for rules used to annotate the sketches to get the initial population. */ |
| class PopulationGenerationRule { |
| public: |
| /*! \brief Result enumeration of the apply function. */ |
| enum class ResultKind : int { kValid = 0, kInvalid = 1 }; |
| |
| /*! |
| * \brief Apply function of this rule. |
| * \param policy The SketchPolicyNode of this rule, some member may get changed during the |
| * rule applying. (e.g. random number generator) |
| * \param state The state to apply this rule, update inplace. |
| * \return The result of this rule, indicate if there's any valid state generated. |
| */ |
| virtual ResultKind Apply(SketchPolicyNode* policy, State* state, |
| std::mt19937* rand_gen) const = 0; |
| |
| /*! \brief The deconstructor */ |
| virtual ~PopulationGenerationRule() = default; |
| }; |
| |
// Declares a population initialization rule subclass called `rule_name`.
// Only the Apply override is declared here; its definition is not in this header.
#define DEFINE_INIT_POPULATION_RULE(rule_name)                                          \
  class rule_name : public PopulationGenerationRule {                                   \
   public:                                                                              \
    PopulationGenerationRule::ResultKind Apply(SketchPolicyNode* policy, State* state,  \
                                               std::mt19937* rand_gen) const final;     \
  };
| |
| /*! \brief The rule that fills the incomplete SplitSteps. */ |
| DEFINE_INIT_POPULATION_RULE(InitFillTileSize); |
| |
| /*! \brief The rule that randomly changes the computation location for some stages that do not |
| * need tiling and are not strictly inlineable(e.g. data padding). */ |
| DEFINE_INIT_POPULATION_RULE(InitChangeComputeLocation); |
| |
| /*! \brief The rule that annotates parallel for CPU. */ |
| DEFINE_INIT_POPULATION_RULE(InitParallel); |
| |
| /*! \brief The rule that annotates unroll. */ |
| DEFINE_INIT_POPULATION_RULE(InitUnroll); |
| |
| /*! \brief The rule that annotates vectorization. */ |
| DEFINE_INIT_POPULATION_RULE(InitVectorization); |
| |
| /*! \brief The rule that annotates thread binding for GPU. */ |
| DEFINE_INIT_POPULATION_RULE(InitThreadBind); |
| |
| /********** Mutation **********/ |
| |
| /*! \brief The base class for mutation rules used in the evolutionary search. */ |
| class PopulationMutationRule : public PopulationGenerationRule { |
| public: |
| /* \brief The constructor |
| * \param selection_weight the probabiliy of applying this rule is |
| * proportional to this weight |
| */ |
| explicit PopulationMutationRule(double selection_weight) : weight(selection_weight) {} |
| |
| /* \brief The weight of this rule */ |
| double weight; |
| }; |
| |
// Declares a mutation rule subclass called `rule_name` for the evolutionary search.
// The generated constructor forwards the selection weight to PopulationMutationRule; only the
// Apply override is declared here, its definition is not in this header.
#define DEFINE_MUTATE_POPULATION_RULE(rule_name)                                              \
  class rule_name : public PopulationMutationRule {                                           \
   public:                                                                                    \
    PopulationGenerationRule::ResultKind Apply(SketchPolicyNode* policy, State* state,        \
                                               std::mt19937* rand_gen) const final;           \
    explicit rule_name(double selection_weight) : PopulationMutationRule(selection_weight) {} \
  };
| |
| /*! \brief The rule that mutates tile size by randomly dividing a tile size by a factor |
| and multipling it to another tile size. */ |
| DEFINE_MUTATE_POPULATION_RULE(MutateTileSize); |
| |
| /*! \brief The rule that mutates the number of fused outer iterators annotated by parallel. */ |
| DEFINE_MUTATE_POPULATION_RULE(MutateParallel); |
| |
| /*! \brief The rule that randomly changes the computation location for some stages that do not |
| * need tiling and are not strictly inlineable(e.g. data padding). */ |
| DEFINE_MUTATE_POPULATION_RULE(MutateComputeLocation); |
| |
| /*! \brief The rule that mutates the value of a randomly selected auto unroll pragma step. */ |
| DEFINE_MUTATE_POPULATION_RULE(MutateAutoUnroll); |
| |
| } // namespace auto_scheduler |
| } // namespace tvm |
| |
| #endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_ |