/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file schedule_dataflow_rewrite.cc
*/
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_set>
#include "../../tir/transforms/ir_util.h"
#include "message_passing.h"
#include "operation_inline.h"
namespace tvm {
namespace te {
// Find the first occurrence of v in the array; returns the array size if not found.
template <typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
const Object* n = v.get();
for (size_t i = 0; i < array_node->size(); ++i) {
if (array_node->at(i).get() == n) return i;
}
return array_node->size();
}
// Variable substitution mutator used by the cache and rfactor rewrites.
class VarReplacer : public tir::StmtExprMutator {
public:
explicit VarReplacer(const std::unordered_map<const VarNode*, PrimExpr>& vsub) : vsub_(vsub) {}
PrimExpr VisitExpr_(const VarNode* op) final {
auto it = vsub_.find(op);
if (it != vsub_.end()) return it->second;
return GetRef<PrimExpr>(op);
}
tir::CommReducer MutateCommReducer(tir::CommReducer combiner) {
// Replace free variables in combiner
auto new_identity = tir::UpdateArray(combiner->identity_element,
[this](const PrimExpr& e) { return this->VisitExpr(e); });
auto new_result = tir::UpdateArray(combiner->result,
[this](const PrimExpr& e) { return this->VisitExpr(e); });
if (combiner->identity_element.same_as(new_identity) &&
combiner->result.same_as(new_result)) {
return combiner;
} else {
return tir::CommReducer(combiner->lhs, combiner->rhs, new_result, new_identity);
}
}
PrimExpr VisitExpr_(const tir::ReduceNode* op) final {
PrimExpr new_e = StmtExprMutator::VisitExpr_(op);
const tir::ReduceNode* new_reduce = new_e.as<tir::ReduceNode>();
tir::CommReducer new_combiner = MutateCommReducer(op->combiner);
if (op->combiner.same_as(new_combiner)) {
return new_e;
} else {
return tir::Reduce(new_combiner, new_reduce->source, new_reduce->axis, new_reduce->condition,
new_reduce->value_index, new_reduce->init);
}
}
private:
const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
};
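// Inject the predicates into the body expression. For a Reduce body the
// predicates are folded into the reduction condition; otherwise the body is
// wrapped in Select(pred, body, 0).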
PrimExpr InjectPredicate(const Array<PrimExpr>& predicates, PrimExpr body) {
using tir::ReduceNode;
using tir::SelectNode;
if (predicates.size() == 0) return body;
const ReduceNode* reduce = body.as<ReduceNode>();
auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
if (reduce) {
auto n = make_object<ReduceNode>(*reduce);
n->condition = foldl(fand, n->condition, predicates);
return PrimExpr(n);
}
return Select(foldl(fand, const_true(1), predicates), body, make_zero(body.dtype()));
}
// Replace the dataflow appearing in all stages given the tensor change.
// Also update vmap if subsequent dataflow needs to be replaced.
// A reverse map (rvmap) keeps the transitive-closure property of vmap up to date.
void ReplaceDataFlow(const Array<Stage>& stages, std::unordered_map<Tensor, Tensor>* vmap,
std::unordered_map<Tensor, Tensor>* rvmap) {
for (Stage s : stages) {
Operation op = s->op->ReplaceInputs(s->op, *vmap);
if (!op.same_as(s->op)) {
for (int i = 0; i < op->num_outputs(); ++i) {
auto it = rvmap->find(s->op.output(i));
if (it != rvmap->end()) {
(*vmap)[it->second] = op.output(i);
} else {
(*vmap)[s->op.output(i)] = op.output(i);
(*rvmap)[op.output(i)] = s->op.output(i);
}
}
s->op = op;
}
}
}
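// Two Reduce nodes are considered equal when they share the same combiner,
// source, axis, condition, and init (or both inits are empty); only
// value_index may differ.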
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
(a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
}
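// Create a cache stage that reads `tensor` in memory scope `scope` on behalf of
// the given readers, then rewire the readers and all downstream stages to
// consume the cache instead of the original tensor.
// Illustrative usage (tensor names are hypothetical):
//   Tensor AA = s.cache_read(A, "shared", {B->op});  // B now reads A through AA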
Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope,
const Array<Operation>& readers) {
(*this)->InvalidateCache();
// create identity mapping.
std::ostringstream os;
os << tensor->op->name;
if (tensor->op->num_outputs() != 1) {
os << ".v" << tensor->value_index;
}
os << "." << scope;
std::unordered_map<Tensor, Tensor> vsub;
Stage s = operator[](tensor->op);
Tensor sugar_tensor = s->op.output(tensor->value_index);
Tensor cache = compute(
sugar_tensor->shape,
[&sugar_tensor](const Array<Var>& i) {
return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
},
os.str());
vsub[sugar_tensor] = cache;
std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap;
for (Operation op : readers) {
Stage s = operator[](op);
Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
CHECK(!repl_op.same_as(s->op)) << "Cannot find " << tensor << " in the inputs of " << s->op;
vmap[s->op.output(0)] = repl_op.output(0);
rvmap[repl_op.output(0)] = s->op.output(0);
s->op = repl_op;
}
ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
Array<Stage>& stages = (*this)->stages;
Stage op_stage = operator[](tensor->op);
size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage);
Stage cache_stage = Stage(cache->op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages.size());
stages.insert(stages.begin() + pos + 1, cache_stage);
(*this)->stage_map.Set(cache->op, cache_stage);
// Update group
cache_stage->group = op_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
}
return cache;
}
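// Prepare the axis mapping from the original stage to the cache stage:
// collect the reduction axes, create a ".c"-suffixed IterVar for every
// data-parallel leaf axis, and fill in the variable substitutions and
// bound-check predicates needed to rewrite the compute body.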
template <typename OpType>
void PrepareAxisMapping(Stage orig_stage, OpType* op, std::unordered_set<IterVar>* p_red_axis,
Array<IterVar>* p_new_axis, std::unordered_map<IterVar, Range>* p_dom_map,
std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
std::vector<PrimExpr>* p_predicates) {
auto& red_axis = *p_red_axis;
auto& new_axis = *p_new_axis;
auto& dom_map = *p_dom_map;
auto& vsub = *p_vsub;
auto& vsub2newvar = *p_vsub2newvar;
auto& predicates = *p_predicates;
arith::Analyzer analyzer;
for (IterVar iv : op->reduce_axis) {
red_axis.insert(iv);
}
for (IterVar iv : op->axis) {
dom_map[iv] = iv->dom;
analyzer.Bind(iv->var, iv->dom);
}
te::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
{
// The source->cache
std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : orig_stage->leaf_iter_vars) {
if (red_axis.count(iv)) continue;
CHECK_EQ(iv->iter_type, kDataPar) << "Can only relayout within data-parallel dimensions";
Range dom = dom_map.at(iv);
IterVar new_iv = IterVar(dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
new_axis.push_back(new_iv);
if (is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
value_map[iv] = iv->var;
vsub2newvar[iv->var.get()] = new_iv->var;
}
}
// skip reduction iteration.
std::unordered_set<IterVar> skip_bound_check;
for (IterVar iv : op->reduce_axis) {
skip_bound_check.insert(iv);
}
PassUpIndex(orig_stage, dom_map, &value_map, true);
predicates = MakeBoundCheck(orig_stage, dom_map, value_map, true, skip_bound_check);
// The root axis
for (IterVar iv : op->axis) {
if (value_map.count(iv)) {
vsub[iv->var.get()] = value_map.at(iv);
} // to handle tensor axis
}
}
}
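// Splice the cache stage into the schedule right before the original stage,
// point the original stage at orig_new_op (which now simply reads the cache),
// and rewrite the downstream dataflow accordingly.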
Array<Tensor> ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::string& scope,
Operation cache_op, Operation orig_new_op, size_t tensor_size) {
Array<Tensor> cache_tensor_list;
for (size_t i = 0; i < tensor_size; i++) {
Tensor cache_tensor = cache_op.output(i);
cache_tensor_list.push_back(cache_tensor);
}
// Replace the dataflow.
std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap;
for (size_t i = 0; i < tensor_size; i++) {
vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
}
ReplaceDataFlow(sch->stages, &vmap, &rvmap);
// mutate orig stage
orig_stage->op = orig_new_op;
orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
orig_stage->relations = Array<IterVarRelation>();
// create schedule for new cached stage.
Array<Stage>& stages = sch->stages;
size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage);
Stage cache_stage = Stage(cache_op);
cache_stage.set_scope(scope);
CHECK_LT(pos, stages.size());
stages.insert(stages.begin() + pos, cache_stage);
sch->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
if (cache_stage->group.defined()) {
++cache_stage->group->num_child_stages;
}
return cache_tensor_list;
}
// Cache write and relayout the data according to the loop pattern.
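// Steps: (1) rewrite the compute body onto the cache axes and inject the bound
// checks, (2) build the cache ComputeOp, (3) rebuild the original op as a plain
// copy from the cache, (4) splice the cache stage in via ReplaceOriginalOp.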
Array<Tensor> CacheWriteWithReLayout(Schedule sch, const Array<Tensor>& tensor_array,
const std::string& scope) {
size_t tensor_size = tensor_array.size();
sch->InvalidateCache();
Tensor tensor = tensor_array[0];
Stage orig_stage = sch[tensor->op];
const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
std::unordered_set<IterVar> red_axis;
Array<IterVar> new_axis;
std::unordered_map<IterVar, Range> dom_map;
std::unordered_map<const VarNode*, PrimExpr> vsub;
std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
std::vector<PrimExpr> predicates;
PrepareAxisMapping(orig_stage, compute, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar,
&predicates);
PrimExpr body;
Array<PrimExpr> body_list;
const tir::ReduceNode* first_reduce = nullptr;
for (auto cbody : compute->body) {
body = VarReplacer(vsub)(cbody);
body = InjectPredicate(predicates, body);
body = VarReplacer(vsub2newvar)(body);
// All Reduce nodes in ONE ComputeOp must be identical except for value_index.
// This holds only if the original body already guarantees it.
if (body->IsInstance<tir::ReduceNode>()) {
const tir::ReduceNode* reduce_body = body.as<tir::ReduceNode>();
if (first_reduce != nullptr) {
CHECK(ReduceEqual(reduce_body, first_reduce));
body = tir::Reduce(first_reduce->combiner, first_reduce->source, first_reduce->axis,
first_reduce->condition, reduce_body->value_index, reduce_body->init);
} else {
first_reduce = reduce_body;
}
} else {
CHECK(first_reduce == nullptr) << "cannot mix reduce and other nodes in ONE compute body";
}
body_list.push_back(body);
}
// The reader args
Array<PrimExpr> args;
{
// cache->compute
std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : compute->axis) {
value_map[iv] = iv->var;
}
te::PassDownIndex(orig_stage, dom_map, &value_map, true);
for (IterVar iv : orig_stage->leaf_iter_vars) {
if (red_axis.count(iv)) continue;
args.push_back(value_map.at(iv));
}
}
Operation cache_op =
ComputeOp(compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list);
Array<PrimExpr> cache_expr_list;
for (size_t i = 0; i < tensor_size; i++) {
Tensor cache_tensor = cache_op.output(i);
cache_expr_list.push_back(cache_tensor(args));
}
Operation orig_new_op =
ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list);
return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
}
// Cache write and relayout for tensor compute op (TensorComputeOp).
Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch, const Array<Tensor>& tensor_array,
const std::string& scope) {
size_t tensor_size = tensor_array.size();
sch->InvalidateCache();
Tensor tensor = tensor_array[0];
Stage orig_stage = sch[tensor->op];
const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
CHECK_EQ(tensor_op->num_outputs(), 1)
<< "cache write only support single output tensor_compute_op";
std::unordered_set<IterVar> red_axis;
Array<IterVar> new_axis;
std::unordered_map<IterVar, Range> dom_map;
std::unordered_map<const VarNode*, PrimExpr> vsub;
std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
std::vector<PrimExpr> predicates;
PrepareAxisMapping(orig_stage, tensor_op, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar,
&predicates);
for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
IterVar iv = tensor_op->axis[i];
IterVar new_iv = IterVar(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
new_axis.push_back(new_iv);
}
Array<Region> new_regions;
for (Region old_region : tensor_op->input_regions) {
Region region;
for (Range r : old_region) {
PrimExpr min = VarReplacer(vsub2newvar)(r->min);
PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
region.push_back(Range::FromMinExtent(min, extent));
}
new_regions.push_back(region);
}
Array<PrimExpr> new_scalar_inputs;
for (PrimExpr old_input : tensor_op->scalar_inputs) {
new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
}
Operation cache_op =
TensorComputeOp(tensor_op->name + "." + scope, tensor_op->tag, new_axis,
tensor_op->reduce_axis, tensor_op->schedulable_ndim, tensor_op->intrin,
tensor_op->inputs, new_regions, new_scalar_inputs);
// axis will be used in generating compute op
Array<IterVar> compute_axis = tensor_op->axis;
for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
IterVar iv = tensor_op->axis[i];
IterVar aiv = IterVar(iv->dom, iv->var, kDataPar);
compute_axis.Set(i, aiv);
}
// The reader args
Array<PrimExpr> args;
{
// cache->compute
std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : compute_axis) {
value_map[iv] = iv->var;
}
PassDownIndex(orig_stage, dom_map, &value_map, true);
for (IterVar iv : orig_stage->leaf_iter_vars) {
if (red_axis.count(iv)) continue;
args.push_back(value_map.at(iv));
}
// tensorized region axis
for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
IterVar iv = compute_axis[i];
args.push_back(value_map.at(iv));
}
}
Array<PrimExpr> cache_expr_list;
for (size_t i = 0; i < tensor_size; i++) {
Tensor cache_tensor = cache_op.output(i);
cache_expr_list.push_back(cache_tensor(args));
}
Operation orig_new_op =
ComputeOp(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list);
return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
}
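// Multi-output cache_write: every tensor in tensor_array must be produced by
// the same ComputeOp stage; a single cache stage is created for all of them.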
Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array, const std::string& scope) {
(*this)->InvalidateCache();
CHECK(tensor_array.size() > 0) << "size of tensor_array must be greater than 0";
Tensor tensor = tensor_array[0];
Stage orig_stage = operator[](tensor->op);
const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
<< "size of input tensor list must be same as number of stage outputs";
for (size_t i = 1; i < tensor_array.size(); i++) {
Stage tmp_stage = operator[](tensor_array[i]->op);
CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp";
}
return CacheWriteWithReLayout(*this, tensor_array, scope);
}
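// Illustrative usage (tensor names are hypothetical):
//   Tensor CC = s.cache_write(C, "local");  // heavy compute moves into CC, C copies from CC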
Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) {
// Supports both plain ComputeOp and TensorComputeOp producers.
(*this)->InvalidateCache();
if (tensor->op.as<ComputeOpNode>()) {
return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
} else if (tensor->op.as<TensorComputeOpNode>()) {
return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
} else {
LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers";
return Tensor();
}
}
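// Rebase every unbound leaf iteration variable onto a fresh IterVar so that
// later lowering can assume loops start at min = 0; thread-bound axes are left
// untouched, and attach points are remapped to the rebased IterVars.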
void RebaseNonZeroMinLoop(ScheduleNode* sch) {
std::unordered_map<IterVar, IterVar> rebase_map;
for (Stage s : sch->stages) {
if (s->attach_type == kInlinedAlready) continue;
auto root_iter_vars = s->op->root_iter_vars();
ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
for (IterVar iv : root_iter_vars) {
size_t idx = FindNodeRef(leaf_vars, iv);
auto it = s->iter_var_attrs.find(iv);
// No need to rebase axes that are bound to threads.
if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
continue;
}
if (idx < leaf_vars->size()) {
// insert rebase
IterVar rebased = IterVar(Range(), iv->var.copy_with_suffix(""), iv->iter_type);
s->relations.push_back(te::Rebase(iv, rebased));
if (s->iter_var_attrs.count(iv)) {
s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
}
leaf_vars->SetItem(idx, rebased);
rebase_map[iv] = rebased;
}
}
}
// remap the parent relation
for (Stage s : sch->stages) {
if (s->attach_type != kScope) continue;
if (rebase_map.count(s->attach_ivar)) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
for (Stage s : sch->groups) {
if (s->attach_type != kScope) continue;
if (rebase_map.count(s->attach_ivar)) {
s->attach_ivar = rebase_map.at(s->attach_ivar);
}
}
}
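// Inline every stage marked kInline into its consumers (both ComputeOp and
// HybridOp bodies), handling multi-output reduction bodies specially, then
// rewrite the dataflow of the remaining stages to consume the updated ops.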
void InjectInline(ScheduleNode* sch) {
sch->InvalidateCache();
std::vector<Array<PrimExpr> > new_body(sch->stages.size());
std::vector<bool> changed(sch->stages.size(), false);
std::vector<Stmt> new_hybrid_body(sch->stages.size());
std::vector<bool> hybrid_changed(sch->stages.size(), false);
// inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage stage = sch->stages[i - 1];
if (stage->attach_type == kInline) {
stage->attach_type = kInlinedAlready;
Array<Var> args;
PrimExpr body;
{
// setup args
const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
CHECK(compute) << "can only inline compute op";
for (auto iv : compute->axis) {
args.push_back(iv->var);
}
CHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output";
body = compute->body[0];
}
for (size_t j = i; j < sch->stages.size(); ++j) {
Stage s = sch->stages[j];
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
if (compute) {
if (!new_body[j].size()) {
new_body[j] = compute->body;
}
if (new_body[j][0]->IsInstance<tir::ReduceNode>()) {
// Specially handle reduction inline for multiple reductions.
const tir::ReduceNode* reduce = new_body[j][0].as<tir::ReduceNode>();
for (size_t k = 1; k < new_body[j].size(); ++k) {
const tir::ReduceNode* reduce_ = new_body[j][k].as<tir::ReduceNode>();
CHECK(reduce_);
CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
PrimExpr new_value = Inline(tir::Evaluate(new_body[j][0]), stage->op, args, body)
.as<tir::EvaluateNode>()
->value;
if (!new_value.same_as(new_body[j][0])) {
changed[j] = true;
const tir::ReduceNode* r = new_value.as<tir::ReduceNode>();
CHECK(r != nullptr);
CHECK_EQ(new_body[j].size(), r->source.size());
for (size_t k = 0; k < new_body[j].size(); ++k) {
auto n = make_object<tir::ReduceNode>(*r);
n->value_index = static_cast<int>(k);
n->dtype = r->source[k].dtype();
new_body[j].Set(k, PrimExpr(n));
}
}
} else {
for (size_t k = 0; k < new_body[j].size(); ++k) {
PrimExpr new_value = Inline(tir::Evaluate(new_body[j][k]), stage->op, args, body)
.as<tir::EvaluateNode>()
->value;
if (!new_value.same_as(new_body[j][k])) {
new_body[j].Set(k, new_value);
changed[j] = true;
}
}
}
} else if (hybrid) {
if (!new_hybrid_body[j].defined()) {
new_hybrid_body[j] = hybrid->body;
}
Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body);
if (!new_stmt.same_as(new_hybrid_body[j])) {
new_hybrid_body[j] = new_stmt;
hybrid_changed[j] = true;
}
}
}
}
}
std::unordered_map<Tensor, Tensor> repl;
// rewrite dataflow
for (size_t i = 0; i < sch->stages.size(); ++i) {
Stage s = sch->stages[i];
if (s->attach_type == kInlinedAlready) continue;
if (new_body[i].size()) {
// Logic from ReplaceDataFlow
const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
CHECK(compute);
Operation op = s->op;
if (changed[i]) {
op = ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, new_body[i]);
}
op = op->ReplaceInputs(op, repl);
if (!op.same_as(s->op)) {
for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
repl[s->op.output(idx)] = op.output(idx);
}
s->op = op;
}
} else if (hybrid_changed[i]) {
const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
CHECK(hybrid);
Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
hybrid->outputs, new_hybrid_body[i]);
op = op->ReplaceInputs(op, repl);
for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
repl[s->op.output(idx)] = op.output(idx);
}
s->op = op;
} else {
Operation op = s->op->ReplaceInputs(s->op, repl);
if (!op.same_as(s->op)) {
for (int j = 0; j < op->num_outputs(); ++j) {
repl[s->op.output(j)] = op.output(j);
}
s->op = op;
}
}
}
}
void LegalizeInvalidAttach(ScheduleNode* sch) {
// Legalize the compute_at location if the target iterator of compute_at is split or fused.
// Case 1: If the target of compute_at is split,
// we will move the compute_at location to the inner iterator.
// Case 2: If the target of compute_at is fused,
// we will move the compute_at location to the newly fused iterator.
// Note that case 2 can only happen if the target of compute_at
// is the innermost operand of fuse operation.
// Map an old invalid attach point to its new valid attach point
std::unordered_map<IterVar, IterVar> replace_map;
for (Stage stage : sch->stages) {
for (Stage s = stage; s.defined();) {
// The following logic is similar to the `CreateAttachPath` in `src/te/schedule/graph.h`,
// because we follow the validation check in that function to legalize the attach.
Stage spec = s.GetAttachSpec();
if (spec->attach_type != kScope) {
break;
}
bool start_attach = false;
IterVar attach_ivar = spec->attach_ivar;
s = spec->attach_stage;
CHECK(attach_ivar.defined());
CHECK(s.defined());
for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
IterVar iv = s->leaf_iter_vars[i - 1];
if (!start_attach && iv.same_as(attach_ivar)) {
start_attach = true;
break;
}
}
if (!start_attach) {
IterVar new_attach_ivar = attach_ivar;
bool updated = true;
// recursively update the relations
while (updated) {
updated = false;
for (const auto& rel : s->relations) {
if (const FuseNode* r = rel.as<FuseNode>()) {
if (new_attach_ivar.same_as(r->inner)) {
new_attach_ivar = r->fused;
updated = true;
}
} else if (const SplitNode* r = rel.as<SplitNode>()) {
if (new_attach_ivar.same_as(r->parent)) {
new_attach_ivar = r->inner;
updated = true;
}
}
}
replace_map[attach_ivar] = new_attach_ivar;
}
}
}
}
// remap the parent relation
for (Stage s : sch->stages) {
if (s->attach_type != kScope) continue;
if (replace_map.count(s->attach_ivar)) {
s->attach_ivar = replace_map.at(s->attach_ivar);
}
}
for (Stage s : sch->groups) {
if (s->attach_type != kScope) continue;
if (replace_map.count(s->attach_ivar)) {
s->attach_ivar = replace_map.at(s->attach_ivar);
}
}
}
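// Normalization pipeline: inline the stages marked for inlining, rebase loops
// so they start at zero, and legalize compute_at attach points invalidated by
// split/fuse.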
Schedule Schedule::normalize() {
Schedule sn = copy();
InjectInline(sn.operator->());
RebaseNonZeroMinLoop(sn.operator->());
LegalizeInvalidAttach(sn.operator->());
return sn;
}
// Handle reduction factor.
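// Illustrative usage (variable names are hypothetical):
//   IterVar ko, ki;
//   s[B].split(k, 16, &ko, &ki);
//   Array<Tensor> BF = s.rfactor(B, ki);
//   // BF computes partial reductions with ki as a data-parallel axis;
//   // B is rewritten to reduce BF over ki's domain.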
Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) {
(*this)->InvalidateCache();
using tir::ReduceNode;
CHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis";
Stage reduce_stage = operator[](tensor->op);
const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
CHECK(compute_op) << "Can only factor ComputeOp";
ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
{
size_t axis_pos = FindNodeRef(leaf_vars, axis);
CHECK_NE(axis_pos, leaf_vars->size()) << "Cannot find IterVar " << axis << " in leaf iter vars";
}
// Find touched reduction axis.
std::unordered_map<IterVar, int> touch_map;
touch_map[axis] = 1;
te::PassUpBitMaskOr(reduce_stage, &touch_map, true);
te::PassDownBitMaskOr(reduce_stage, &touch_map, true);
// skip reduction iteration.
std::unordered_set<IterVar> skip_bound_check;
// Verify that normal axes are not touched.
for (IterVar iv : compute_op->axis) {
CHECK(!touch_map.count(iv)) << "Factor axis touches normal axis.";
skip_bound_check.insert(iv);
}
// get analyzer.
arith::Analyzer analyzer;
// Build the replacement index maps.
std::unordered_map<IterVar, Range> dom_map;
std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : compute_op->reduce_axis) {
if (touch_map.count(iv)) {
dom_map[iv] = iv->dom;
} else {
skip_bound_check.insert(iv);
}
analyzer.Bind(iv->var, iv->dom);
}
te::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv)) {
Range dom = dom_map.at(iv);
if (is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
value_map[iv] = iv->var;
}
}
}
te::PassUpIndex(reduce_stage, dom_map, &value_map, true);
std::vector<PrimExpr> predicates =
MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check);
// Get the factored op node.
const int factor_axis_pos =
factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
CHECK_LE(factor_axis_pos, compute_op->axis.size());
auto n = make_object<ComputeOpNode>();
n->name = compute_op->name + ".rf";
{
// axis replacement.
auto iv_node = make_object<IterVarNode>();
iv_node->dom = dom_map.at(axis);
CHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0";
iv_node->var = axis->var;
iv_node->iter_type = kDataPar;
const int size = compute_op->axis.size();
for (int idx = 0; idx < size; ++idx) {
if (factor_axis_pos == idx) {
n->axis.push_back(IterVar(iv_node));
}
n->axis.push_back(compute_op->axis[idx]);
}
if (factor_axis_pos == size) {
n->axis.push_back(IterVar(iv_node));
}
}
// Predicate generation; copy the untouched reduction axes.
int idx = tensor->value_index;
const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
predicates.push_back(reduce->condition);
auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
PrimExpr predicate = likely(foldl(fand, const_true(1), predicates));
std::unordered_map<const VarNode*, PrimExpr> vsub;
for (IterVar iv : compute_op->reduce_axis) {
if (!touch_map.count(iv)) {
n->reduce_axis.push_back(iv);
} else {
CHECK(value_map.count(iv));
PrimExpr index = value_map.at(iv);
vsub[iv->var.get()] = index;
}
}
// Copy touched axis.
for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv) && !iv.same_as(axis)) {
CHECK_EQ(iv->iter_type, kCommReduce);
auto ncpy = make_object<IterVarNode>(*iv.operator->());
ncpy->dom = dom_map.at(iv);
n->reduce_axis.push_back(IterVar(ncpy));
}
}
VarReplacer replacer(vsub);
Array<PrimExpr> new_source =
tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); });
PrimExpr new_pred = replacer(predicate);
std::vector<PrimExpr> body;
for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx, {}));
}
n->body = Array<PrimExpr>(body);
// Refresh the relations, keeping the untouched ones.
Array<IterVarRelation> rels;
for (IterVarRelation rel : reduce_stage->relations) {
bool touched = false;
if (const SplitNode* r = rel.as<SplitNode>()) {
if (touch_map.count(r->parent)) touched = true;
} else if (const FuseNode* r = rel.as<FuseNode>()) {
if (touch_map.count(r->fused)) touched = true;
} else if (const RebaseNode* r = rel.as<RebaseNode>()) {
if (touch_map.count(r->parent)) touched = true;
} else {
LOG(FATAL) << "unknown relation type";
}
if (!touched) {
rels.push_back(rel);
}
}
// initialize the factored stage.
Operation factor_op(n);
Array<Stage>& stages = (*this)->stages;
size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage);
Stage factor_stage = Stage(factor_op);
factor_stage->relations = rels;
CHECK_LT(stage_pos, stages.size());
stages.insert(stages.begin() + stage_pos, factor_stage);
(*this)->stage_map.Set(factor_op, factor_stage);
factor_stage->group = reduce_stage->group;
if (factor_stage->group.defined()) {
++factor_stage->group->num_child_stages;
}
// Replace the old reduction.
IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v");
Array<Tensor> factor_tensors;
Array<Tensor> old_tensors;
int size = factor_op->num_outputs();
for (int idx = 0; idx < size; ++idx) {
factor_tensors.push_back(factor_op.output(idx));
old_tensors.push_back(reduce_stage->op.output(idx));
}
Array<Tensor> repl_tensors = compute(
old_tensors[0]->shape,
[&](const Array<Var>& i) {
Array<PrimExpr> indices;
const int idx_size = static_cast<int>(i.size());
for (int idx = 0; idx < idx_size; ++idx) {
if (factor_axis_pos == idx) {
indices.push_back(repl_red_axis->var);
}
indices.push_back(i[idx]);
}
Array<PrimExpr> new_init = reduce->init;
if (!reduce->init.empty()) {
std::unordered_map<const VarNode*, PrimExpr> init_vsub;
for (const auto& init : reduce->init) {
if (init->IsInstance<ProducerLoadNode>()) {
CHECK_EQ(compute_op->axis.size(), idx_size)
<< "'init' should have the number of dimensions as output when using with "
"rfactor";
for (int idx = 0; idx < idx_size; idx++) {
init_vsub[compute_op->axis[idx]->var.get()] = i[idx];
}
}
}
VarReplacer init_replacer(init_vsub);
new_init = tir::UpdateArray(
reduce->init, [&init_replacer](const PrimExpr& e) { return init_replacer(e); });
}
if (factor_axis_pos == idx_size) {
indices.push_back(repl_red_axis->var);
}
Array<PrimExpr> factor_exprs;
for (int idx = 0; idx < size; ++idx) {
factor_exprs.push_back(factor_tensors[idx](indices));
}
Array<PrimExpr> reductions;
Array<IterVar> axis = {repl_red_axis};
PrimExpr cond = const_true();
for (int idx = 0; idx < size; ++idx) {
reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx, new_init));
}
return reductions;
},
reduce_stage->op->name + ".repl");
std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap;
for (int idx = 0; idx < size; ++idx) {
vmap[old_tensors[idx]] = repl_tensors[idx];
rvmap[repl_tensors[idx]] = old_tensors[idx];
}
ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
// revamp the reduction stage.
reduce_stage->op = repl_tensors[0]->op;
reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
reduce_stage->relations = Array<IterVarRelation>();
return factor_tensors;
}
} // namespace te
} // namespace tvm