blob: c3c764fc8e2becd254a153e77990a5d090772cbe [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/loop_state.cc
* \brief A lightweight IR (intermediate representation) for loop structures.
* see auto_scheduler/loop_state.h for more explanation.
*/
#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <utility>
#include "utils.h"
namespace tvm {
namespace auto_scheduler {
// Register the node types used by the loop state IR through TVM's registration macros.
// StepNode goes through the object-type macro; the remaining types go through the node-type macro.
TVM_REGISTER_OBJECT_TYPE(StepNode);
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(StateNode);
TVM_REGISTER_NODE_TYPE(IteratorNode);
/********** Iterator **********/
/*!
 * \brief Build an Iterator backed by a freshly allocated IteratorNode.
 * \param name The iterator name (moved into the node).
 * \param range The iterator range (moved into the node).
 * \param iter_kind The kind stored on the node.
 * \param annotation The annotation stored on the node.
 * \param orig_iters Optional pointer to a list of iterators; when non-null its contents are
 *  copied into the node's `orig_iters`.
 */
Iterator::Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation,
                   const std::vector<Iterator>* orig_iters) {
  auto iter_node = make_object<IteratorNode>();
  iter_node->annotation = annotation;
  iter_node->iter_kind = iter_kind;
  iter_node->range = std::move(range);
  iter_node->name = std::move(name);
  // The list of original iterators is optional; leave the node's default when it is absent.
  if (orig_iters) {
    iter_node->orig_iters = *orig_iters;
  }
  data_ = std::move(iter_node);
}
/********** Stage **********/
/*!
 * \brief Build the initial Stage for an operation.
 *
 * A compute operation gets one spatial iterator per axis followed by one reduction iterator per
 * reduce axis; a placeholder operation gets no iterators. Any other operation type is a fatal
 * error. The stage starts with `compute_at` set to kRoot and both stage attributes set to 0.
 *
 * \param op The operation this stage wraps (moved into the node).
 */
Stage::Stage(te::Operation op) {
  auto node = make_object<StageNode>();
  if (op->IsInstance<te::ComputeOpNode>()) {
    node->op_type = StageKind::kCompute;
    auto* pop = op.as<te::ComputeOpNode>();
    for (const auto& axis : pop->axis) {
      node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom,
                                     IteratorKind::kSpatial, IteratorAnnotation::kNone));
    }
    for (const auto& axis : pop->reduce_axis) {
      node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom,
                                     IteratorKind::kReduction, IteratorAnnotation::kNone));
    }
  } else if (op->IsInstance<te::PlaceholderOpNode>()) {
    node->op_type = StageKind::kPlaceholder;
  } else {
    // Use the runtime type key: `op->_type_key` is the static key of the base node class and
    // would not identify the operation that was actually passed in.
    LOG(FATAL) << "Unsupported operator type: " << op->GetTypeKey();
  }
  node->compute_at = ComputeAtKind::kRoot;
  node->op = std::move(op);
  node->attrs.auto_unroll_max_step = 0;
  node->attrs.storage_offset = 0;
  data_ = std::move(node);
}
/*!
 * \brief Build a Stage directly from fully specified fields.
 * \param op The operation of the stage (moved into the node).
 * \param op_type The stage kind.
 * \param iters The iterators of the stage.
 * \param compute_at The compute-at kind of the stage.
 * \param attrs The stage attributes.
 */
Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
             ComputeAtKind compute_at, StageAttributes attrs) {
  auto stage_node = make_object<StageNode>();
  stage_node->attrs = attrs;
  stage_node->compute_at = compute_at;
  stage_node->iters = iters;
  stage_node->op_type = op_type;
  stage_node->op = std::move(op);
  data_ = std::move(stage_node);
}
/********** AttachMap **********/
void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) {
AttachMapNode* pnode = CopyOnWrite();
// Delete the current entry of this stage
DeleteStageEntry(pnode, stage_id);
// Store the new stage/iterator relations to map
IterKey iter_key(target_stage_id, target_iter_id);
pnode->stage_to_attach_iter[stage_id] = iter_key;
pnode->iter_to_attached_stages[iter_key].push_back(stage_id);
}
/*!
 * \brief Remove the attach relation recorded for a stage, if there is one.
 * \param stage_id The stage whose entry is removed.
 */
void AttachMap::DeleteStage(int stage_id) {
  // The removal itself is delegated to DeleteStageEntry on the writable node.
  DeleteStageEntry(CopyOnWrite(), stage_id);
}
/*!
 * \brief Re-target attach relations from a list of iterators to a list of new iterators.
 *
 * `original_iters[i]` is replaced by `new_iters[i]`. Stages attached to an original iterator are
 * moved to the corresponding new iterator in both `stage_to_attach_iter` and
 * `iter_to_attached_stages`. Original iterators without attached stages are skipped.
 *
 * Several original iterators may map to the same new iterator, and the new iterator may already
 * have attached stages of its own. In both cases the stage lists are merged, so that the two maps
 * keep describing the same set of relations.
 *
 * \param original_iters The iterators to be replaced.
 * \param new_iters The replacement iterators; must have the same size as `original_iters`.
 */
void AttachMap::UpdateIters(const std::vector<IterKey>& original_iters,
                            const std::vector<IterKey>& new_iters) {
  CHECK_EQ(original_iters.size(), new_iters.size());
  AttachMapNode* pnode = CopyOnWrite();
  // Moved entries are staged in a separate map first, so that a new key which is also a
  // not-yet-processed original key is not clobbered or picked up again by a later iteration.
  std::unordered_map<IterKey, std::vector<StageKey>> new_iter_to_attached_stages;
  for (size_t i = 0; i < original_iters.size(); ++i) {
    auto entry = pnode->iter_to_attached_stages.find(original_iters[i]);
    // We get <IterKey, std::vector<StageKey>> from this map
    if (entry == pnode->iter_to_attached_stages.end()) {
      // Skip if this iterator does not have any attach relations
      continue;
    }
    // Update the attaching target of a stage to the new iter in `stage_to_attach_iter`
    for (const auto& s : entry->second) {
      pnode->stage_to_attach_iter[s] = new_iters[i];
    }
    // Remove the original iterator relation from `iter_to_attached_stages` and stage its stages
    // under the new iterator. Append rather than assign: another original iterator may already
    // have been staged under the same new iterator.
    auto attached_stages = std::move(entry->second);
    pnode->iter_to_attached_stages.erase(entry);
    auto& staged = new_iter_to_attached_stages[new_iters[i]];
    staged.insert(staged.end(), attached_stages.begin(), attached_stages.end());
  }
  // Merge the staged entries back. Append rather than assign: the new iterator may already own
  // attached stages that were not part of `original_iters`.
  for (auto& it : new_iter_to_attached_stages) {
    auto& target = pnode->iter_to_attached_stages[it.first];
    target.insert(target.end(), it.second.begin(), it.second.end());
  }
}
/*!
 * \brief Remove a stage from both directions of the attach mapping.
 *
 * Does nothing when the stage has no entry in `stage_to_attach_iter`.
 *
 * \param pnode The attach map node to modify.
 * \param stage_id The stage whose relation is removed.
 */
void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) {
  auto old_entry = pnode->stage_to_attach_iter.find(stage_id);
  // We get <StageKey, IterKey> from this map
  if (old_entry != pnode->stage_to_attach_iter.end()) {
    // Delete the stage in `iter_to_attached_stages`, if the corresponding iterator does not have
    // any attached stage, delete this item too
    auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second);
    // We get <IterKey, std::vector<StageKey>> from this map.
    // Guard the lookup: dereferencing the end iterator is undefined behavior, and the reverse
    // entry is all that would be removed here anyway.
    if (entry2 != pnode->iter_to_attached_stages.end()) {
      FindAndDeleteItem(&entry2->second, stage_id);
      if (entry2->second.empty()) {
        pnode->iter_to_attached_stages.erase(entry2);
      }
    }
    // Delete the stage in `stage_to_attach_iter`
    pnode->stage_to_attach_iter.erase(old_entry);
  }
}
/*!
 * \brief Produce a new AttachMap in which stage ids at or above a threshold are shifted.
 *
 * Every stage id `>= start_id` is increased by `offset`, wherever it appears: as a key or inside
 * an iterator key of `stage_to_attach_iter`, and as part of an iterator key or inside a stage
 * list of `iter_to_attached_stages`. This map itself is left untouched.
 *
 * \param start_id The first stage id that gets shifted.
 * \param offset The amount added to each shifted stage id.
 * \return The new AttachMap holding the shifted relations.
 */
AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const {
  AttachMap shifted_map = AttachMap(make_object<AttachMapNode>());
  auto shifted_node = shifted_map.CopyOnWrite();
  // Shift a single stage id when it is at or beyond the threshold.
  const auto shift = [start_id, offset](int id) { return id >= start_id ? id + offset : id; };

  for (const auto& relation : operator->()->stage_to_attach_iter) {
    auto attach_iter = relation.second;
    attach_iter.first = shift(attach_iter.first);
    shifted_node->stage_to_attach_iter.insert(std::make_pair(shift(relation.first), attach_iter));
  }

  for (const auto& relation : operator->()->iter_to_attached_stages) {
    auto iter_key = relation.first;
    iter_key.first = shift(iter_key.first);
    auto attached = relation.second;
    for (auto& attached_id : attached) {
      attached_id = shift(attached_id);
    }
    shifted_node->iter_to_attached_stages.insert(std::make_pair(iter_key, attached));
  }
  return shifted_map;
}
/********** State **********/
/*!
 * \brief Build the initial State for a list of operations.
 *
 * One Stage is created per operation, in order. The state starts with an empty attach map and
 * with `concrete` set to true.
 *
 * \param ops The operations to create stages for.
 */
State::State(const Array<te::Operation>& ops) {
  auto state_node = make_object<StateNode>();
  state_node->concrete = true;
  state_node->attach_map = AttachMap(make_object<AttachMapNode>());
  for (const auto& op : ops) {
    state_node->stages.push_back(Stage(op));
  }
  data_ = std::move(state_node);
}
/********** Schedule primitives apis for state **********/
/*!
 * \brief Record an AnnotationStep that applies a thread annotation to an iterator, then apply it.
 * \param stage_id The index of the stage that owns the iterator.
 * \param it The iterator to annotate.
 * \param thread_type The annotation to apply; values outside [kVThread, kThreadZ] are a fatal
 *  error.
 * \return The result of applying the step to this state.
 */
Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) {
  const Stage& stage = operator->()->stages[stage_id];
  const bool is_thread_annotation =
      thread_type >= IteratorAnnotation::kVThread && thread_type <= IteratorAnnotation::kThreadZ;
  if (!is_thread_annotation) {
    LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, "
               << "kThreadX, kThreadY, kBlockZ, kThreadZ";
  }
  const auto iter_id = GetIndex(stage->iters, it);
  AnnotationStep bind_step(stage_id, iter_id, thread_type);
  CopyOnWrite()->transform_steps.push_back(bind_step);
  return bind_step->ApplyToState(this);
}
/*!
 * \brief Record an AnnotationStep that marks an iterator as kParallel, then apply it.
 * \param stage_id The index of the stage that owns the iterator.
 * \param it The iterator to annotate.
 * \return The result of applying the step to this state.
 */
Iterator State::parallel(int stage_id, const Iterator& it) {
  const Stage& owner = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(owner->iters, it);
  AnnotationStep parallel_step(stage_id, iter_id, IteratorAnnotation::kParallel);
  CopyOnWrite()->transform_steps.push_back(parallel_step);
  return parallel_step->ApplyToState(this);
}
/*!
 * \brief Record an AnnotationStep that marks an iterator as kUnroll, then apply it.
 *
 * When `max_unroll` is not -1 and the iterator has a defined range whose extent is an integer
 * immediate larger than `max_unroll`, nothing is recorded and `it` is returned unchanged.
 *
 * \param stage_id The index of the stage that owns the iterator.
 * \param it The iterator to annotate.
 * \param max_unroll The largest extent that is still unrolled; -1 disables the limit.
 * \return The result of applying the step, or `it` itself when the limit is exceeded.
 */
Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) {
  const Stage& owner = operator->()->stages[stage_id];
  // Respect the unroll limit when the extent is a known constant.
  const bool limit_enabled = max_unroll != -1;
  if (limit_enabled && it->range.defined()) {
    const auto* const_extent = it->range->extent.as<IntImmNode>();
    if (const_extent != nullptr && const_extent->value > max_unroll) {
      return it;
    }
  }
  const auto iter_id = GetIndex(owner->iters, it);
  AnnotationStep unroll_step(stage_id, iter_id, IteratorAnnotation::kUnroll);
  CopyOnWrite()->transform_steps.push_back(unroll_step);
  return unroll_step->ApplyToState(this);
}
/*!
 * \brief Record an AnnotationStep that marks an iterator as kVectorize, then apply it.
 * \param stage_id The index of the stage that owns the iterator.
 * \param it The iterator to annotate.
 * \return The result of applying the step to this state.
 */
Iterator State::vectorize(int stage_id, const Iterator& it) {
  const Stage& owner = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(owner->iters, it);
  AnnotationStep vectorize_step(stage_id, iter_id, IteratorAnnotation::kVectorize);
  CopyOnWrite()->transform_steps.push_back(vectorize_step);
  return vectorize_step->ApplyToState(this);
}
/*!
 * \brief Record a FuseStep over the given iterators of a stage, then apply it.
 * \param stage_id The index of the stage that owns the iterators.
 * \param iters The iterators to fuse; they are translated to indices via GetIndices.
 * \return The result of applying the step to this state.
 */
Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
  const Stage& owner = operator->()->stages[stage_id];
  Array<Integer> fused_ids;
  GetIndices(owner->iters, iters, &fused_ids);
  FuseStep fuse_step(stage_id, fused_ids);
  CopyOnWrite()->transform_steps.push_back(fuse_step);
  return fuse_step->ApplyToState(this);
}
void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) {
const Stage& stage = operator->()->stages[stage_id];
PragmaStep step = PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this);
}
void State::reorder(int stage_id, const Array<Iterator>& order) {
const Stage& stage = operator->()->stages[stage_id];
CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
<< "should be specified";
Array<Integer> after_ids;
GetIndices(stage->iters, order, &after_ids);
ReorderStep step = ReorderStep(stage_id, after_ids);
CopyOnWrite()->transform_steps.push_back(step);
step->ApplyToState(this);
}
/*!
 * \brief Record a SplitStep for an iterator of a stage, then apply it.
 *
 * The step receives the extent of the iterator's range when the range is defined, and a
 * default-constructed PrimExpr otherwise.
 *
 * \param stage_id The index of the stage that owns the iterator.
 * \param it The iterator to split.
 * \param lengths The split lengths passed on to the step.
 * \param inner_to_outer Direction flag passed on to the step.
 * \return The result of applying the step to this state.
 */
Array<Iterator> State::split(int stage_id, const Iterator& it,
                             const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
  const Stage& owner = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(owner->iters, it);
  SplitStep split_step(stage_id, iter_id,
                       it->range.defined() ? it->range->extent : PrimExpr(), lengths,
                       inner_to_outer);
  CopyOnWrite()->transform_steps.push_back(split_step);
  return split_step->ApplyToState(this);
}
/*!
 * \brief Record a FollowSplitStep for an iterator of a stage, then apply it.
 * \param stage_id The index of the stage that owns the iterator.
 * \param it The iterator to split.
 * \param src_step_id The index of the source step passed on to the step.
 * \param n_split The split count passed on to the step.
 * \return The result of applying the step to this state.
 */
Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
                                    int n_split) {
  const Stage& owner = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(owner->iters, it);
  FollowSplitStep follow_step(stage_id, iter_id, src_step_id, n_split);
  CopyOnWrite()->transform_steps.push_back(follow_step);
  return follow_step->ApplyToState(this);
}
/*!
 * \brief Record a FollowFusedSplitStep for an iterator of a stage, then apply it.
 * \param stage_id The index of the stage that owns the iterator.
 * \param it The iterator to split.
 * \param src_step_ids The indices of the source steps passed on to the step.
 * \param level The level passed on to the step.
 * \param factor_or_nparts Mode flag passed on to the step.
 * \return The result of applying the step to this state.
 */
Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
                                          const Array<Integer>& src_step_ids, int level,
                                          bool factor_or_nparts) {
  const Stage& owner = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(owner->iters, it);
  FollowFusedSplitStep follow_step(stage_id, iter_id, src_step_ids, level, factor_or_nparts);
  CopyOnWrite()->transform_steps.push_back(follow_step);
  return follow_step->ApplyToState(this);
}
void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) {
const Stage& stage = operator->()->stages[stage_id];
StorageAlignStep step = StorageAlignStep(stage_id, GetIndex(stage->iters, it), factor, offset);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this);
}
void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
const Stage& target_stage = operator->()->stages[target_stage_id];
ComputeAtStep step =
ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter));
CopyOnWrite()->transform_steps.push_back(step);
step->ApplyToState(this);
}
void State::compute_inline(int stage_id) {
ComputeInlineStep step = ComputeInlineStep(stage_id);
CopyOnWrite()->transform_steps.push_back(step);
step->ApplyToState(this);
}
void State::compute_root(int stage_id) {
ComputeRootStep step = ComputeRootStep(stage_id);
CopyOnWrite()->transform_steps.push_back(step);
step->ApplyToState(this);
}
int State::cache_read(int stage_id, const String& scope_name,
const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this, dag);
}
int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) {
CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this, dag);
}
int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) {
const Stage& stage = operator->()->stages[stage_id];
RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), factor_iter_id);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this, dag);
}
/*!
 * \brief Print one stage to a stream, recursing into the stages attached to its iterators.
 *
 * Non-zero `auto_unroll_max_step` and `storage_offset` attributes are printed first. Then one
 * line is printed per iterator, with the indentation growing by two spaces per printed iterator.
 * After each iterator, every stage attached to it in the state's attach map is printed at the
 * current indentation. The final line is "<op name> = ...".
 *
 * \param os The output stream.
 * \param stage_id The index of the stage to print.
 * \param state The state that owns the stage.
 * \param base_indent The number of spaces to indent this stage by.
 * \param delete_trivial_loop When true, iterators with a defined range of extent one are not
 *  printed (stages attached to them still are).
 */
void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent,
                bool delete_trivial_loop) {
  const Stage& stage = state->stages[stage_id];
  // Write `count` single spaces to the stream.
  const auto put_spaces = [os](size_t count) {
    for (size_t k = 0; k < count; ++k) {
      *os << " ";
    }
  };

  if (stage->attrs.auto_unroll_max_step != 0) {
    put_spaces(base_indent);
    *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n";
  }
  if (stage->attrs.storage_offset != 0) {
    put_spaces(base_indent);
    *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n";
  }

  size_t loop_indent = 0;
  for (size_t i = 0; i < stage->iters.size(); ++i) {
    const Iterator& iter = stage->iters[i];
    const bool hide_loop =
        delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent);
    if (!hide_loop) {
      put_spaces(base_indent + loop_indent);
      *os << IteratorAnnotationString[static_cast<int>(iter->annotation)] << " ";
      *os << iter->name;
      if (iter->range.defined()) {
        *os << " (" << iter->range->min << "," << iter->range->extent << ")";
      } else {
        *os << " (None)";
      }
      *os << "\n";
      loop_indent += 2;
    }

    if (!state.defined()) {
      continue;
    }
    // Print the stages attached to this iterator, nested under the current indentation.
    IterKey iter_key(stage_id, i);
    const auto& attached_map = state->attach_map->iter_to_attached_stages;
    const auto found = attached_map.find(iter_key);
    if (found == attached_map.end()) {
      continue;
    }
    for (const auto& attach_stage_id : found->second) {
      PrintStage(os, attach_stage_id, state, base_indent + loop_indent, delete_trivial_loop);
    }
  }

  put_spaces(base_indent + loop_indent);
  *os << stage->op->name << " = ...\n";
}
// Print state to ostream
void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) {
// Gather placeholders
Array<String> placeholders;
for (const auto& stage : state->stages) {
if (stage->op_type == StageKind::kPlaceholder) {
placeholders.push_back(stage->op->name);
}
}
*os << "Placeholder: ";
for (size_t i = 0; i < placeholders.size(); ++i) {
*os << placeholders[i];
if (i != placeholders.size() - 1) {
*os << ", ";
}
}
*os << "\n";
// Print all stages
for (size_t i = 0; i < state->stages.size(); ++i) {
const Stage& stage = state->stages[i];
if (stage->op_type == StageKind::kPlaceholder) {
continue;
} else if (stage->op_type == StageKind::kCompute) {
if (stage->compute_at == ComputeAtKind::kRoot) {
PrintStage(os, i, state, 0, delete_trivial_loop);
}
} else {
LOG(FATAL) << "Invalid op type";
}
}
}
/*!
 * \brief Render this state as a string using PrintState.
 * \param delete_trivial_loop Forwarded to PrintState.
 * \return The printed text.
 */
String State::ToStr(bool delete_trivial_loop) const {
  std::ostringstream buffer;
  PrintState(&buffer, *this, delete_trivial_loop);
  return buffer.str();
}
// ReprPrinter dispatch for StateNode: prints the state through PrintState with
// delete_trivial_loop set to true.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<StateNode>([](const ObjectRef& node_ref, ReprPrinter* printer) {
      PrintState(&printer->stream, tvm::Downcast<State>(node_ref), true);
    });
/********** State interface API for ffi **********/
// Each global below wraps one State member function. The state is taken by value, the primitive
// is invoked on that copy, and the copy is returned: either alone, or as the first element of an
// Array<ObjectRef> whose second element is the primitive's result.

// Annotation primitives: return {state, resulting iterator}.
TVM_REGISTER_GLOBAL("auto_scheduler.StateBind")
    .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) {
      const auto bound_iter = state.bind(stage_id, it, IteratorAnnotation(thread_type));
      return Array<ObjectRef>{state, bound_iter};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel")
    .set_body_typed([](State state, int stage_id, const Iterator& it) {
      const auto parallel_iter = state.parallel(stage_id, it);
      return Array<ObjectRef>{state, parallel_iter};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll")
    .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) {
      const auto unrolled_iter = state.unroll(stage_id, it, max_unroll);
      return Array<ObjectRef>{state, unrolled_iter};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize")
    .set_body_typed([](State state, int stage_id, const Iterator& it) {
      const auto vectorized_iter = state.vectorize(stage_id, it);
      return Array<ObjectRef>{state, vectorized_iter};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
    .set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) {
      const auto fused_iter = state.fuse(stage_id, iters);
      return Array<ObjectRef>{state, fused_iter};
    });

// Primitives without a result: return the updated state only.
TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma")
    .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) {
      state.pragma(stage_id, it, pragma_type);
      return state;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder")
    .set_body_typed([](State state, int stage_id, const Array<Iterator>& order) {
      state.reorder(stage_id, order);
      return state;
    });

// Split primitives: return {state, resulting iterators}.
TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
    .set_body_typed([](State state, int stage_id, const Iterator& it,
                       const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
      const auto split_iters = state.split(stage_id, it, lengths, inner_to_outer);
      return Array<ObjectRef>{state, split_iters};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
    .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
                       int n_split) {
      const auto split_iters = state.follow_split(stage_id, it, src_step_id, n_split);
      return Array<ObjectRef>{state, Array<Iterator>(split_iters)};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
    .set_body_typed([](State state, int stage_id, const Iterator& it,
                       const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) {
      const auto split_iters =
          state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts);
      return Array<ObjectRef>{state, Array<Iterator>(split_iters)};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign")
    .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) {
      state.storage_align(stage_id, it, factor, offset);
      return state;
    });

// Compute-location primitives: return the updated state only.
TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
    .set_body_typed([](State state, int stage_id, int target_stage_id,
                       const Iterator& target_iter) {
      state.compute_at(stage_id, target_stage_id, target_iter);
      return state;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline")
    .set_body_typed([](State state, int stage_id) {
      state.compute_inline(stage_id);
      return state;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot")
    .set_body_typed([](State state, int stage_id) {
      state.compute_root(stage_id);
      return state;
    });

// Primitives that need the ComputeDAG: return {state, integer result wrapped in Integer}.
TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead")
    .set_body_typed([](State state, int stage_id, const String& scope_name,
                       const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
      const int step_result = state.cache_read(stage_id, scope_name, reader_stage_ids, dag);
      return Array<ObjectRef>{state, Integer(step_result)};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite")
    .set_body_typed([](State state, int stage_id, const String& scope_name,
                       const ComputeDAG& task_dag) {
      const int step_result = state.cache_write(stage_id, scope_name, task_dag);
      return Array<ObjectRef>{state, Integer(step_result)};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor")
    .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id,
                       const ComputeDAG& dag) {
      const int step_result = state.rfactor(stage_id, it, factor_iter_id, dag);
      return Array<ObjectRef>{state, Integer(step_result)};
    });

// Compare two states through the std::equal_to<State> functor.
TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
  return std::equal_to<State>{}(state1, state2);
});
} // namespace auto_scheduler
} // namespace tvm