blob: 2fc1352677cb97c96d2502d5088d63e9c9fa2156 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"
namespace tvm {
namespace meta_schedule {
/*!
 * \brief A schedule rule that applies rfactor to reduction blocks whose parallelism is
 * insufficient for the target, splitting the fused reduction loop so part of the
 * reduction can be parallelized.
 */
class AddRFactorNode : public ScheduleRuleNode {
 public:
  // Inherited from ScheduleRuleNode
  void InitializeWithTuneContext(const TuneContext& context) final {
    ICHECK(context->target.defined());
    Target target = context->target.value();
    // Number of cores reported by the target; the base unit of parallelism.
    this->max_parallel_basic_ = GetTargetNumCores(target);
    if (this->max_jobs_per_core != -1) {
      // Upper limit of parallelism = num_cores * max_jobs_per_core.
      // When max_jobs_per_core == -1, max_parallel_extent_ keeps the value set at
      // construction time (-1, i.e. parallelism disabled).
      this->max_parallel_extent_ = max_parallel_basic_ * max_jobs_per_core;
    }
  }
  // Inherited from ScheduleRuleNode; defined below.
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv);
  // Inherited from ScheduleRuleNode
  ScheduleRule Clone() const final {
    // Shallow-copies all fields of this node into a fresh object.
    ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>(*this);
    return ScheduleRule(n);
  }

 public:
  /*!
   * \brief The maximum number of jobs to be launched per core.
   * It sets the upper limit of parallelism, i.e. `num_cores * max_jobs_per_core`.
   * Use -1 to disable parallelism.
   */
  int max_jobs_per_core;
  /*! \brief The maximum size of the innermost factor (-1 means no limit). */
  int max_innermost_factor;
  /*! \brief The upper limit of parallelism; derived from the target at initialization. */
  int max_parallel_extent_;
  /*! \brief The number of cores on the target. */
  int max_parallel_basic_;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("max_jobs_per_core", &max_jobs_per_core);
    v->Visit("max_innermost_factor", &max_innermost_factor);
    // `max_parallel_extent_` is not visited
    // `max_parallel_basic_` is not visited
    // (both are target-derived caches, re-computed in InitializeWithTuneContext)
  }

  static constexpr const char* _type_key = "meta_schedule.AddRFactor";
  TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode);
};
/*!
 * \brief Construct the AddRFactor schedule rule.
 * \param max_jobs_per_core Maximum jobs per core; -1 disables parallelism.
 * \param max_innermost_factor Cap on the innermost split factor; NullOpt means unlimited.
 * \return The constructed ScheduleRule.
 */
ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
                                      Optional<Integer> max_innermost_factor) {
  ObjectPtr<AddRFactorNode> node = make_object<AddRFactorNode>();
  node->max_jobs_per_core = max_jobs_per_core;
  // An absent max_innermost_factor is stored as -1 (no limit).
  node->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
  // The two derived fields stay -1 until InitializeWithTuneContext fills them in
  // from the target.
  node->max_parallel_basic_ = -1;
  node->max_parallel_extent_ = -1;
  return ScheduleRule(node);
}
/*!
 * \brief Apply the rfactor rule to the given block.
 *
 * If the block does not need rfactor (or cross-thread reduction) under the configured
 * parallelism limits, the schedule is returned unchanged. Otherwise the reduction loops
 * are reordered innermost, fused, and split in two; for each of the two split loops an
 * rfactor candidate schedule is produced (when legal), and the untouched original
 * schedule is always appended as a fallback.
 *
 * \param sch The schedule to transform (mutated in place for the split; candidates are copies).
 * \param block_rv The reduction block under consideration.
 * \return The candidate schedules, ending with the unmodified original.
 */
Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch,
                                           const tir::BlockRV& block_rv) {
  tir::StmtSRef block_sref = sch->GetSRef(block_rv);
  bool needs_rfactor = NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref,
                                                          max_parallel_extent_,
                                                          max_parallel_basic_);
  if (!needs_rfactor) {
    return {sch};
  }

  // Preserve an untouched copy (with a forked RNG seed) before mutating `sch`.
  tir::Schedule original = sch->Copy();
  original->Seed(sch->ForkSeed());

  // Move reduction loops innermost if necessary, then fuse them into one loop.
  size_t n_spatial_loops;
  tir::LoopRV fused_loop;
  ReorderAndFuseReductionLoops(sch, block_rv, &fused_loop, &n_spatial_loops);

  // Split the fused reduction loop into two with sampled perfect-tile factors.
  Array<tir::ExprRV> tile_factors =
      sch->SamplePerfectTile(fused_loop, 2, max_innermost_factor);
  Array<tir::LoopRV> split_loops =
      sch->Split(fused_loop, {tile_factors.begin(), tile_factors.end()});

  Array<tir::Schedule> results;
  for (const tir::LoopRV& loop : split_loops) {
    // Each candidate works on its own copy with its own forked seed.
    tir::Schedule candidate = sch->Copy();
    candidate->Seed(sch->ForkSeed());
    try {
      tir::BlockRV rf_block = candidate->RFactor(loop, n_spatial_loops);
      Array<tir::LoopRV> rf_loops = candidate->GetLoops(rf_block);
      ICHECK_GT(rf_loops.size(), n_spatial_loops);
      // Annotate that the rfactor block, which is now the producer of the original block,
      // needs to be considered by the rule Random-Compute-Location.
      candidate->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer,
                          Integer(1));
      results.push_back(candidate);
    } catch (const tvm::runtime::Error&) {
      // RFactor is not applicable on this loop; skip the candidate.
    }
  }
  // Always keep the original schedule as a fallback candidate.
  results.push_back(original);
  return results;
}
// Register the node type with TVM's object/reflection system.
TVM_REGISTER_NODE_TYPE(AddRFactorNode);
// Expose the constructor through the FFI under the given global name.
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor")
    .set_body_typed(ScheduleRule::AddRFactor);
} // namespace meta_schedule
} // namespace tvm