/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "trace_apply.h"
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../tir/schedule/analysis.h"
#include "utils.h"
namespace tvm {
namespace meta_schedule {
using namespace tir;
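
// This file implements applying an "anchor trace", a schedule trace tuned for an anchor
// workload (typically a module dominated by a single complicated block such as conv2d or
// dense), to a target schedule whose module may contain additional fused spatial blocks
// (e.g., conv2d -> bias add -> relu). The trace is replayed as far as it applies, and the
// blocks it does not cover are inlined or given a simple default schedule.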

// Returns true if b1 is an ancestor of b2 in the producer-consumer graph.
// A block is considered to be its own ancestor.
bool IsAncestor(BlockRV b1, BlockRV b2, Schedule sch) {
  if (sch->Get(b1)->name_hint == sch->Get(b2)->name_hint) {
    return true;
  }
  for (auto prod : sch->GetProducers(b2)) {
    if (IsAncestor(b1, prod, sch)) return true;
  }
  return false;
}

// Inline or reverse compute inline spatial blocks that come after the anchor block.
void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) {
  static auto kind_get_block = InstructionKind::Get("GetBlock");

  // We let blocks whose names are referenced in the anchor trace be scheduled by the anchor trace.
  // We record such block names to avoid inlining them here.
  std::unordered_set<std::string> get_block_names;
  for (const auto& inst : anchor_trace->insts) {
    if (inst->kind.same_as(kind_get_block)) {
      auto block_name = Downcast<String>(inst->attrs[0]);
      ICHECK(block_name.defined());
      get_block_names.insert(block_name);
    }
  }
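
  // The anchor block is typically the dominant reduction block of the module (e.g., a conv2d or
  // dense block), i.e. the block the anchor trace was originally tuned for.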
  auto anchor_block = FindAnchorBlock(sch->mod());

  std::vector<std::string> inline_todos;
  std::optional<int> last_block_idx{std::nullopt};
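
  // Collect the names of spatial blocks that the anchor trace does not schedule. The output
  // block, if any, must be handled by reverse compute inline, so its position is remembered and
  // it is moved to the end of the inlining order below.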
  for (auto name : GetBlockNames(sch->mod())) {
    auto block = sch->GetBlock(name);
    if (anchor_block) {
      auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint);
      if (IsAncestor(block, anchor_block_rv, sch)) continue;
    }
    // Spatial blocks which are not referenced in the anchor trace will be inlined here.
    auto block_sref = sch->GetSRef(block);
    if (IsSpatial(block_sref) && !get_block_names.count(name)) {
      StmtSRef scope_root =
          (name != "root") ? GetScopeRoot(sch->state(), block_sref, false) : block_sref;
      if (IsOutputBlock(sch->state(), block_sref, scope_root)) {
        last_block_idx = inline_todos.size();
      }
      inline_todos.push_back(name);
    }
  }

  if (last_block_idx) {
    // The last block can only be reverse compute inlined. We make sure to inline all
    // producer blocks of the last block beforehand so that reverse compute inline can succeed.
    std::swap(inline_todos[*last_block_idx], inline_todos.back());
  }
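
  // Apply the target-specific AutoInline rule to the collected blocks; it chooses between
  // compute_inline and reverse_compute_inline for each block.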
  auto inline_rule = GetDefaultAutoInline(target->kind->name);

  for (auto name : inline_todos) {
    inline_rule->Apply(sch, sch->GetBlock(name));
  }
}

// Applies instructions from the anchor trace to the target schedule and returns the blocks
// that remain unscheduled.
std::vector<BlockRV> ApplyAnchorTrace(Schedule sch, Trace anchor_trace) {
  static auto kind_get_child_blocks = InstructionKind::Get("GetChildBlocks");
  static auto kind_get_block = InstructionKind::Get("GetBlock");
  static auto kind_compute_inline = InstructionKind::Get("ComputeInline");
  static auto kind_reverse_compute_inline = InstructionKind::Get("ReverseComputeInline");

  const auto block_names_orig = GetBlockNames(sch->mod());
  const auto sch_orig = sch->Copy();
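
  // rv_map maps random variables (blocks, loops, etc.) recorded in the anchor trace to the
  // corresponding variables created while replaying the trace on this schedule.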
  std::unordered_map<const Object*, const Object*> rv_map;
  // Blocks and loops that appear in the anchor trace but are not part of the target schedule.
  std::unordered_set<BlockRV, ObjectHash, ObjectEqual> foreign_blocks;
  std::unordered_set<LoopRV, ObjectHash, ObjectEqual> foreign_loops;

  // Instructions in the anchor trace can be applied only if all inputs are part of the target
  // schedule.
  auto is_inst_applicable = [&foreign_blocks, &foreign_loops](Instruction inst) {
    for (auto input : inst->inputs) {
      if (!input.defined()) continue;
      if ((input->IsInstance<BlockRVNode>() && foreign_blocks.count(Downcast<BlockRV>(input))) ||
          (input->IsInstance<LoopRVNode>() && foreign_loops.count(Downcast<LoopRV>(input)))) {
        return false;
      }
    }
    return true;
  };

  for (const auto& inst : anchor_trace->insts) {
    if (!is_inst_applicable(inst)) {
      // If we find an instruction that is not applicable, its outputs are recorded as "foreign"
      // to the target schedule, so that any later instruction consuming them is skipped as well.
      for (auto output : inst->outputs) {
        if (output->IsInstance<BlockRVNode>()) {
          foreign_blocks.insert(Downcast<BlockRV>(output));
        } else if (output->IsInstance<LoopRVNode>()) {
          foreign_loops.insert(Downcast<LoopRV>(output));
        }
      }
      continue;
    }
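
    // Replace the trace's input RVs with the equivalent RVs created on this schedule so far.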
    Array<ObjectRef> inputs = TranslateInputRVs(inst->inputs, rv_map);

    if (inst->kind.same_as(kind_get_block) && !HasBlock(sch, Downcast<String>(inst->attrs[0]))) {
      // The anchor trace does get_block on a block that is not part of the target schedule.
      auto block = Downcast<BlockRV>(inst->outputs[0]);
      foreign_blocks.insert(block);
      continue;
    } else if (inst->kind.same_as(kind_reverse_compute_inline)) {
      // The anchor trace does reverse_compute_inline on a block, but the block with the same name
      // in the target schedule cannot be reverse compute inlined.
      // In such cases, it should be possible to apply compute_inline instead.
      auto block = Downcast<BlockRV>(inputs[0]);
      auto block_sref = sch->GetSRef(block);
      if (!CanReverseComputeInline(sch->state(), block_sref)) {
        ICHECK(CanComputeInline(sch->state(), block_sref));
        sch->ComputeInline(block);
        continue;
      }
    } else if (inst->kind.same_as(kind_compute_inline)) {
      // Similar to the reverse_compute_inline case above.
      auto block = Downcast<BlockRV>(inputs[0]);
      auto block_sref = sch->GetSRef(block);
      auto state = sch->state();
      if (!CanComputeInline(state, block_sref)) {
        ICHECK(IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)))
            << "If a spatial block cannot be inlined, it should be the output block";
        if (CanReverseComputeInline(sch->state(), block_sref)) {
          sch->ReverseComputeInline(block);
        }
        continue;
      }
    }
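
    // Replay the instruction on this schedule, reusing the sampling decision recorded in the
    // anchor trace (if any) so that, e.g., the same tile sizes are chosen.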
    Optional<ObjectRef> decision = anchor_trace->GetDecision(inst);
    Array<ObjectRef> outputs = inst->kind->f_apply_to_schedule(sch, inputs, inst->attrs, decision);

    if (inst->kind.same_as(kind_get_child_blocks)) {
      // We want to allow a trace generated for a single conv2d block to be applied to
      // conv2d -> elemwise blocks, where the two conv2d blocks are the same workload.
      // GetChildBlocks returns a different number of blocks for the two cases above, which
      // violates the assumption made by TranslateAddOutputRVs: old_outputs.size() ==
      // new_outputs.size(). We work around this problem by assuming that the prefix of the "new"
      // outputs matches the "old" outputs, and truncating the new outputs accordingly.
      ICHECK(inst->outputs.size() <= outputs.size());
      TranslateAddOutputRVs(
          inst->outputs, Array<ObjectRef>(outputs.begin(), outputs.begin() + inst->outputs.size()),
          &rv_map);
    } else {
      TranslateAddOutputRVs(inst->outputs, outputs, &rv_map);
    }
  }
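
  // A block counts as scheduled if replaying the trace changed its loop nest: either the number
  // of loops differs from the original copy of the schedule, or a loop kind changed (e.g., a loop
  // was parallelized, vectorized, unrolled, or bound to a thread axis).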
  auto is_scheduled = [=](const std::string& block_name) {
    auto loops = sch->GetLoops(sch->GetBlock(block_name));
    auto loops_orig = sch_orig->GetLoops(sch_orig->GetBlock(block_name));
    if (loops.size() != loops_orig.size()) {
      return true;
    }
    for (size_t i = 0; i < loops.size(); ++i) {
      auto loop = sch->Get(loops[i]);
      auto loop_orig = sch_orig->Get(loops_orig[i]);
      if (loop->kind != loop_orig->kind) {
        return true;
      }
    }
    return false;
  };

  const auto block_names_now = GetBlockNames(sch->mod());
  std::vector<BlockRV> unscheduled_blocks;

  for (auto name : block_names_orig) {
    if (block_names_now.count(name) && name != "root" && !is_scheduled(name)) {
      unscheduled_blocks.push_back(sch->GetBlock(name));
    }
  }

  return unscheduled_blocks;
}
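
// Entry point, registered below as "meta_schedule.ScheduleUsingAnchorTrace": inline the spatial
// blocks that the anchor trace does not cover, replay the anchor trace, and give any remaining
// fused spatial block a simple default schedule for the target.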
void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm::Target& target) {
  InlinePostBlocks(sch, anchor_trace, target);

  auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace);
  ICHECK(unscheduled_blocks.size() <= 1)
      << "All blocks should have been scheduled, or at most one (fused) spatial block may remain "
         "unscheduled at this point.";

  if (unscheduled_blocks.empty()) {
    // All blocks have already been scheduled.
    return;
  }

  auto last_block = unscheduled_blocks[0];
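
  // Give the leftover (fused) spatial block a reasonable default schedule: reverse inline it into
  // a cache write stage when possible, parallelize the fused loop on CPU targets, or bind GPU
  // threads with the AutoBind rule.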
  auto last_block_producers = sch->GetProducers(last_block);

  if (last_block_producers.size() == 1 && IsSpatial(sch->GetSRef(last_block_producers[0]))) {
    // Inline into the cache write stage
    sch->ReverseComputeInline(last_block);
  } else if (target->kind->name == "llvm" || target->kind->name == "hexagon") {
    sch->Parallel(sch->Fuse(sch->GetLoops(last_block)));
  } else if (IsGPUTarget(target->kind->name)) {
    auto max_threads_per_block = target->GetAttr<Integer>("max_threads_per_block");
    ICHECK(max_threads_per_block.defined())
        << "ValueError: missing attribute `max_threads_per_block` in the target";
    auto auto_bind_rule =
        ScheduleRule::AutoBind(/*max_threadblocks=*/256,
                               /*thread_extents=*/Array<Integer>{32, 64, 128, 256, 512, 1024},
                               max_threads_per_block.value()->value);
    auto_bind_rule->Apply(sch, last_block);
  }
}

TVM_REGISTER_GLOBAL("meta_schedule.ScheduleUsingAnchorTrace")
    .set_body_typed(ScheduleUsingAnchorTrace);

}  // namespace meta_schedule
}  // namespace tvm