/*!
* Copyright (c) 2016 by Contributors
* \file bound.cc
* \brief The bound inference logic.
*/
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include <tvm/operation.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include <unordered_set>
#include "graph.h"
#include "message_passing.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace schedule {

using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;

/*! \brief The graph context used during bound inference. */
struct GraphContext {
  /*! \brief The feed graph */
  FeedGraph feed_graph;
  /*! \brief Attachment path */
  AttachPath attach_path;
  /*! \brief The bind map */
  std::unordered_map<IterVar, IterVar> bind_map;
  /*! \brief map from op to stage */
  std::unordered_map<const Node*, Stage> op2stage_;
};
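
/*!
 * \brief Whether the bound of the iter var iv, seen from a consumer's
 *        loop nest, must be relaxed to its full range when inferring
 *        the producer's bound.
 */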
bool NeedRelax(const IterVar& iv,
               bool found_attach,
               const std::unordered_map<IterVar, IterVar>& bind_map,
               const runtime::StorageScope& scope) {
  auto it = bind_map.find(iv);
  const std::string& tag = (
      it != bind_map.end() ? it->second->thread_tag : iv->thread_tag);
  if (tag.length() == 0 || tag == "pipeline") {
    return !found_attach;
  }
  ThreadScope ts = ThreadScope::make(tag);
  // When there is warp memory,
  // threadIdx.x must be set to the warp index.
  if (scope.rank == StorageRank::kWarp &&
      ts.rank == 1 &&
      ts.dim_index == 0) {
    return true;
  }
  return static_cast<int>(scope.rank) <= ts.rank;
}

// Infer the storage scope of the stage if it is not explicitly set:
// pick the default storage rank for the deepest thread rank found
// along the stage's attach path.
StorageScope InferStorageScope(
    const Stage& stage, const GraphContext& ctx) {
  if (stage->scope.length() != 0) {
    return StorageScope::make(stage->scope);
  }
  int max_rank = -1;
  for (IterVar iv : ctx.attach_path.at(stage->op)) {
    auto it = ctx.bind_map.find(iv);
    const std::string& tag = (
        it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
    if (tag != "pipeline" && tag.length() != 0) {
      max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
    }
  }
  StorageScope s;
  s.rank = runtime::DefaultStorageRank(max_rank);
  return s;
}
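
/*!
 * \brief Infer the ranges of the root iter vars of stage->op and store
 *        them in rmap, by propagating the domains required by each of
 *        the stage's consumers back to the producer.
 */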
void InferRootBound(const Stage& stage,
                    const GraphContext& ctx,
                    std::unordered_map<IterVar, Range>* rmap) {
  CHECK_NE(stage->attach_type, kInline)
      << "call schedule.normalize before scheduleops";
  if (stage->attach_type == kInlinedAlready) return;
  if (stage->is_output) {
    // verify correctness.
    CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
        << "Output must be attached at root";
  }
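  // Outputs and placeholders are bound directly by their own root domains.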
  if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
    for (auto iv : stage->op->root_iter_vars()) {
      CHECK(iv->dom.defined());
      CHECK(!rmap->count(iv));
      (*rmap)[iv] = iv->dom;
    }
    return;
  }
  // The tensor domain.
  std::unordered_map<Tensor, TensorDom> tmap;
  // The consumers of the op.
  std::unordered_set<Operation> consumers;
  for (int i = 0; i < stage->op->num_outputs(); ++i) {
    Tensor t = stage->op.output(i);
    tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
    auto it = ctx.feed_graph.find(t);
    if (it != ctx.feed_graph.end()) {
      for (const Operation& op : it->second) {
        consumers.insert(op);
      }
    } else {
      LOG(INFO) << "not in feed graph consumer = " << stage->op;
    }
  }
  // Storage scope.
  runtime::StorageScope scope = InferStorageScope(stage, ctx);
  // Bound propagation from the consumers.
  // - Compute bounds by the relaxation rules in NeedRelax:
  //   - For a normal index, use the relative location in the loop nest.
  //   - For a thread index, use the thread scope.
  Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
  // The parent set.
  for (const Operation& op : consumers) {
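    // relax_set: loop vars whose entire range must be considered.
    // up_state: the set of values each consumer leaf iter var can take,
    // as seen from the producer's attach point.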
    std::unordered_map<const Variable*, IntSet> relax_set;
    std::unordered_map<IterVar, IntSet> up_state;
    bool found_attach = false;
    CHECK(ctx.op2stage_.count(op.get()));
    const Stage& op_stage = ctx.op2stage_.at(op.get());
    // Consumer nest
    for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
      IterVar iv = op_stage->leaf_iter_vars[i - 1];
      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
        found_attach = true;
      }
      auto it = rmap->find(iv);
      CHECK(it != rmap->end());
      const Range& vrange = it->second;
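      // Three cases:
      // - An extent-1 loop pins the value to vrange->min.
      // - A loop that needs no relaxation stays a symbolic single point
      //   (iv->var, or the bound thread var), since the producer sits
      //   inside that loop.
      // - Everything else is relaxed to its full range.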
      if (is_one(vrange->extent)) {
        up_state[iv] = IntSet::single_point(vrange->min);
      } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
        CHECK(is_zero(vrange->min))
            << "InferBound requires every leaf iter var's min equals 0, "
            << "call schedule.normalize to achieve this.";
        if (ctx.bind_map.count(iv)) {
          up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
        } else {
          up_state[iv] = IntSet::single_point(iv->var);
        }
      } else {
        up_state[iv] = IntSet::range(vrange);
      }
    }
    // Consumer's attach nest
    for (IterVar iv : ctx.attach_path.at(op)) {
      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
        found_attach = true;
      }
      Range vrange = rmap->at(iv);
      CHECK(is_zero(vrange->min))
          << "InferBound requires every leaf iter var's min equals 0, "
          << "call schedule.normalize to achieve this.";
      if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
        relax_set[iv->var.get()] = IntSet::range(vrange);
        if (ctx.bind_map.count(iv)) {
          relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
        }
      }
    }
    CHECK(found_attach || stage_attach.size() == 0)
        << "Invalid Schedule, cannot find the producer " << stage->op
        << " along the loop nest specified by compute_at of consumer " << op;
    // Get the domain of the consumer
    PassUpDomain(op_stage, *rmap, &up_state);
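    // After PassUpDomain, up_state also covers the consumer's root iter vars.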
    // Relax if needed.
    std::unordered_map<const Variable*, IntSet> dom_map;
    for (auto iv : op->root_iter_vars()) {
      Range r;
      if (up_state.count(iv)) {
        r = up_state.at(iv).cover_range(iv->dom);
      } else {
        r = iv->dom;
      }
      if (relax_set.size() != 0) {
        dom_map[iv->var.get()] = EvalSet(r, relax_set);
      } else {
        dom_map[iv->var.get()] = IntSet::range(r);
      }
    }
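    // Push the consumer's domain requirements onto the tensors it reads.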
    op->PropBoundToInputs(op, dom_map, &tmap);
  }
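  // Merge the requirements of all consumers into the producer's root ranges.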
  stage->op->GatherBound(stage->op, tmap, rmap);
}
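
/*!
 * \brief Infer the bounds of all iteration variables in the schedule.
 *
 * A minimal usage sketch, assuming tensors declared elsewhere (the
 * placeholder name C below is hypothetical):
 *
 * \code
 *   Schedule sch = create_schedule({C->op});
 *   sch = sch.normalize();  // required before bound inference
 *   Map<IterVar, Range> bounds = schedule::InferBound(sch);
 *   Stmt stmt = schedule::ScheduleOps(sch, bounds);
 * \endcode
 */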
Map<IterVar, Range> InferBound(const Schedule& sch) {
  // Prepare context
  GraphContext ctx;
  Array<Operation> roots;
  for (Operation op : sch->outputs) {
    roots.push_back(sch->stage_map[op]->op);
  }
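  // The feed graph maps each tensor to the operations that consume it.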
  ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));
  for (Stage stage : sch->stages) {
    for (auto kv : stage->iter_var_attrs) {
      if (kv.second->bind_thread.defined()) {
        CHECK(!ctx.bind_map.count(kv.first));
        ctx.bind_map[kv.first] = kv.second->bind_thread;
      }
    }
    ctx.op2stage_[stage->op.get()] = stage;
  }
  ctx.attach_path = CreateAttachPath(sch);
  // Run inference.
  std::unordered_map<IterVar, Range> ret;
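  // Visit stages in reverse order so every consumer is processed
  // before the stages that feed it.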
  for (size_t i = sch->stages.size(); i != 0; --i) {
    const Stage& stage = sch->stages[i - 1];
    InferRootBound(stage, ctx, &ret);
    // pass down to get bound of all iter vars.
    PassDownDomain(stage, &ret);
    for (IterVar iv : stage->env_threads) {
      CHECK(iv->dom.defined());
      ret[iv] = iv->dom;
    }
  }
  for (auto& p : ret) {
    ret[p.first] = Range::make_by_min_extent(ir::Simplify(p.second->min),
                                             ir::Simplify(p.second->extent));
  }
  return Map<IterVar, Range>(ret.begin(), ret.end());
}
}  // namespace schedule
}  // namespace tvm