| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file schedule_lang.cc |
| */ |
| #include <dmlc/thread_local.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/te/operation.h> |
| #include <tvm/te/schedule.h> |
| |
| #include <stack> |
| #include <unordered_set> |
| |
| #include "graph.h" |
| |
| namespace tvm { |
| namespace te { |
| |
| // find first occurance location in leaf |
| template <typename T> |
| size_t FindNodeRef(ArrayNode* array_node, const T& v) { |
| const Object* n = v.get(); |
| for (size_t i = 0; i < array_node->size(); ++i) { |
| if (array_node->at(i).get() == n) return i; |
| } |
| return array_node->size(); |
| } |
| |
| size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) { |
| size_t pos = FindNodeRef(leaf_vars, v); |
| if (pos < leaf_vars->size()) return pos; |
| |
| if (FindNodeRef(all_vars, v) < all_vars->size()) { |
| LOG(FATAL) << "Operate on iter var " << v << "that has already been split"; |
| } else { |
| LOG(FATAL) << "Operate on iter var " << v << "that is not part of the schedule"; |
| } |
| return 0; |
| } |
| |
| DataType MatchDataType(std::vector<DataType> dtypes) { |
| int max_bits = -1; |
| for (const auto& dtype : dtypes) { |
| CHECK(dtype.is_int()); |
| CHECK(dtype.is_scalar()); |
| max_bits = std::max(max_bits, dtype.bits()); |
| } |
| return DataType::Int(max_bits); |
| } |
| |
| void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, |
| IterVar* p_outer, IterVar* p_inner) { |
| // Check if split is valid. |
| CHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce || |
| parent->iter_type == kOrdered) |
| << "Cannot split on " << IterVarType2String(parent->iter_type); |
| IterVar outer = IterVar(Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type); |
| IterVar inner = IterVar(Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type); |
| *p_outer = outer; |
| *p_inner = inner; |
| // The splits |
| Array<IterVar>& all_vars = self->all_iter_vars; |
| Array<IterVar>& leaf_vars = self->leaf_iter_vars; |
| size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent); |
| self->relations.push_back(Split(parent, outer, inner, factor, nparts)); |
| // add vars to all vars |
| all_vars.push_back(outer); |
| all_vars.push_back(inner); |
| // replace the position. |
| leaf_vars.erase(leaf_vars.begin() + pos); |
| leaf_vars.insert(leaf_vars.begin() + pos, inner); |
| leaf_vars.insert(leaf_vars.begin() + pos, outer); |
| } |
| |
| Stage::Stage(Operation op) { |
| auto n = make_object<StageNode>(); |
| n->op = op; |
| n->origin_op = op; |
| n->all_iter_vars = op->root_iter_vars(); |
| // remove opaque var from leaf. |
| Array<IterVar> clean; |
| for (IterVar iv : n->all_iter_vars) { |
| if (iv->iter_type != kOpaque) clean.push_back(iv); |
| } |
| if (clean.size() == n->all_iter_vars.size()) { |
| n->leaf_iter_vars = n->all_iter_vars; |
| } else { |
| n->leaf_iter_vars = clean; |
| } |
| data_ = std::move(n); |
| } |
| |
| bool Stage::is_scheduled() const { |
| const StageNode* n = operator->(); |
| return !(n->relations.empty() && n->attach_type == kGroupRoot && |
| n->all_iter_vars.same_as(n->leaf_iter_vars)); |
| } |
| |
| Stage Stage::GetAttachSpec() const { |
| Stage attach_spec = *this; |
| while (attach_spec->attach_type == kGroupRoot && attach_spec->group.defined()) { |
| attach_spec = attach_spec->group; |
| } |
| return attach_spec; |
| } |
| |
| Stage& Stage::set_scope(std::string scope) { // NOLINT(*) |
| (*this)->scope = scope; |
| return *this; |
| } |
| |
| Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) |
| CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; |
| // Group constraint checking. |
| Stage group = (*this)->group; |
| if (group.defined()) { |
| Stage pg = parent->group; |
| while (pg.defined() && !pg.same_as(group)) { |
| pg = pg->group; |
| } |
| CHECK(pg.same_as(group)) << "Can only assign compute_at to stages within the same group"; |
| } |
| |
| (*this)->attach_type = kScope; |
| (*this)->attach_ivar = scope; |
| (*this)->attach_stage = parent; |
| bool found = false; |
| for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) { |
| if (scope == parent->leaf_iter_vars[i]) { |
| found = true; |
| break; |
| } |
| } |
| CHECK(found) << "Cannot find the axis " << scope << " in parent's leaf_iter_vars" |
| << " parent=" << parent; |
| return *this; |
| } |
| |
| Stage& Stage::compute_inline() { // NOLINT(*) |
| CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; |
| (*this)->attach_type = kInline; |
| return *this; |
| } |
| |
| Stage& Stage::compute_root() { // NOLINT(*) |
| CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; |
| (*this)->attach_type = kGroupRoot; |
| return *this; |
| } |
| |
| Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) |
| StageNode* self = operator->(); |
| CHECK(ivar->iter_type == kDataPar || ivar->iter_type == kCommReduce) |
| << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread"; |
| CHECK(thread_ivar->iter_type == kThreadIndex) |
| << "Cannot rebase by " << IterVarType2String(ivar->iter_type) |
| << ", only thread axis is allowed so far"; |
| ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); |
| ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); |
| FindLeafVar(all_vars, leaf_vars, ivar); |
| |
| auto it = self->iter_var_attrs.find(ivar); |
| ObjectPtr<IterVarAttrNode> n; |
| if (it != self->iter_var_attrs.end()) { |
| n = make_object<IterVarAttrNode>(*(*it).second.operator->()); |
| if (n->bind_thread.defined() && !n->bind_thread.same_as(thread_ivar)) { |
| LOG(WARNING) << "Axis " << ivar << " is already bind to another thread " << n->bind_thread; |
| } |
| } else { |
| n = make_object<IterVarAttrNode>(); |
| } |
| n->bind_thread = thread_ivar; |
| self->iter_var_attrs.Set(ivar, IterVarAttr(n)); |
| return *this; |
| } |
| |
| Stage& Stage::env_threads(Array<IterVar> threads) { |
| StageNode* self = operator->(); |
| CHECK(self->op.defined() && self->op.as<ScanOpNode>()) |
| << "env_threads is only valid for composite ops such as ScanOp"; |
| CHECK_EQ(self->env_threads.size(), 0U) << "Already set env_threads"; |
| Array<IterVar>& leaf_vars = self->leaf_iter_vars; |
| Array<IterVar>& all_vars = self->all_iter_vars; |
| std::vector<ObjectRef> temp; |
| for (IterVar iv : threads) { |
| temp.push_back(iv); |
| } |
| leaf_vars.insert(leaf_vars.begin(), temp.begin(), temp.end()); |
| all_vars.insert(all_vars.end(), temp.begin(), temp.end()); |
| self->env_threads = threads; |
| return *this; |
| } |
| |
| Stage& Stage::set_store_predicate(PrimExpr predicate) { |
| StageNode* self = operator->(); |
| self->store_predicate = predicate; |
| return *this; |
| } |
| |
| Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, |
| IterVar* p_inner) { // NOLINT(*) |
| SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); |
| return *this; |
| } |
| |
| Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, |
| IterVar* p_inner) { // NOLINT(*) |
| SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); |
| return *this; |
| } |
| |
| Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT(*) |
| StageNode* self = operator->(); |
| CHECK(outer->iter_type == kDataPar || outer->iter_type == kCommReduce || |
| outer->iter_type == kOrdered) |
| << "Cannot fuse " << IterVarType2String(outer->iter_type); |
| CHECK(inner->iter_type == kDataPar || inner->iter_type == kCommReduce || |
| inner->iter_type == kOrdered) |
| << "Cannot fuse " << IterVarType2String(inner->iter_type); |
| |
| IterVarType iter_type = outer->iter_type; |
| if (inner->iter_type > iter_type) iter_type = inner->iter_type; |
| std::string fused_name = outer->var->name_hint + "." + inner->var->name_hint + ".fused"; |
| DataType iter_dtype = MatchDataType({inner->var.dtype(), outer->var.dtype()}); |
| |
| IterVar fused = IterVar(Range(), Var(fused_name, iter_dtype), iter_type); |
| |
| Array<IterVar>& all_vars = self->all_iter_vars; |
| Array<IterVar>& leaf_vars = self->leaf_iter_vars; |
| |
| size_t pos_inner = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), inner); |
| size_t pos_outer = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), outer); |
| if (pos_inner + 1 == pos_outer) { |
| std::swap(outer, inner); |
| std::swap(pos_inner, pos_outer); |
| } |
| CHECK_EQ(pos_inner, pos_outer + 1) |
| << "Can only fuse iterations that are consecutive between each other"; |
| self->relations.push_back(Fuse(outer, inner, fused)); |
| all_vars.push_back(fused); |
| leaf_vars.erase(leaf_vars.begin() + pos_outer, leaf_vars.begin() + pos_inner + 1); |
| leaf_vars.insert(leaf_vars.begin() + pos_outer, fused); |
| *p_target = fused; |
| return *this; |
| } |
| |
| Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) { // NOLINT(*) |
| if (axes.size() != 0) { |
| IterVar fused = axes[0]; |
| for (size_t i = 1; i < axes.size(); ++i) { |
| this->fuse(fused, axes[i], &fused); |
| } |
| *p_target = std::move(fused); |
| } else { |
| StageNode* self = operator->(); |
| // special handle fuse empty array. |
| // insert at the outer most loop |
| IterVar singleton = |
| IterVar(Range::FromMinExtent(0, 1), Var("singleton", DataType::Int(32)), kDataPar); |
| self->relations.push_back(Singleton(singleton)); |
| Array<IterVar>& all_vars = self->all_iter_vars; |
| Array<IterVar>& leaf_vars = self->leaf_iter_vars; |
| all_vars.push_back(singleton); |
| leaf_vars.insert(leaf_vars.begin(), singleton); |
| *p_target = singleton; |
| } |
| return *this; |
| } |
| |
| Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*) |
| std::unordered_set<IterVar> seen_var; |
| StageNode* self = operator->(); |
| for (IterVar iv : order) { |
| CHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce || |
| iv->iter_type == kThreadIndex) |
| << "Cannot reorder IterVar(" << IterVarType2String(iv->iter_type) << ")"; |
| |
| CHECK_EQ(seen_var.count(iv), 0) << "Same axis can not appear more than once " << iv; |
| seen_var.insert(iv); |
| } |
| ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); |
| ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); |
| std::vector<size_t> pos; |
| |
| for (size_t i = 0; i < order.size(); ++i) { |
| pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i])); |
| } |
| std::vector<ObjectRef> temp; |
| for (size_t i = 0; i < pos.size(); ++i) { |
| temp.emplace_back(leaf_vars->at(pos[i])); |
| } |
| std::sort(pos.begin(), pos.end()); |
| for (size_t i = 0; i < pos.size(); ++i) { |
| leaf_vars->SetItem(pos[i], temp[i]); |
| } |
| return *this; |
| } |
| |
| Stage& Stage::tile(IterVar x_parent, IterVar y_parent, PrimExpr x_factor, PrimExpr y_factor, |
| IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner) { |
| split(x_parent, x_factor, p_x_outer, p_x_inner); |
| split(y_parent, y_factor, p_y_outer, p_y_inner); |
| reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner})); |
| return *this; |
| } |
| |
| template <typename FUpdate> |
| inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate, |
| bool need_leaf = true) { |
| if (need_leaf) { |
| ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); |
| ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); |
| FindLeafVar(all_vars, leaf_vars, var); |
| } |
| auto it = self->iter_var_attrs.find(var); |
| ObjectPtr<IterVarAttrNode> n; |
| if (it != self->iter_var_attrs.end()) { |
| n = make_object<IterVarAttrNode>(*(*it).second.operator->()); |
| } else { |
| n = make_object<IterVarAttrNode>(); |
| } |
| fupdate(n.get()); |
| self->iter_var_attrs.Set(var, IterVarAttr(n)); |
| } |
| |
| inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) { |
| UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) { n->iter_type = iter_type; }); |
| } |
| |
| Stage& Stage::vectorize(IterVar var) { // NOLINT(*) |
| CHECK(var->iter_type == kDataPar || var->iter_type == kOpaque || var->iter_type == kUnrolled || |
| var->iter_type == kVectorized || var->iter_type == kTensorized || |
| var->iter_type == kParallelized) |
| << "Cannot vectorize on " << IterVarType2String(var->iter_type); |
| SetAttrIterType(operator->(), var, kVectorized); |
| return *this; |
| } |
| |
| Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) |
| UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) { |
| n->iter_type = kTensorized; |
| n->tensor_intrin = f; |
| }); |
| return *this; |
| } |
| |
| Stage& Stage::unroll(IterVar var) { // NOLINT(*) |
| SetAttrIterType(operator->(), var, kUnrolled); |
| return *this; |
| } |
| |
| Stage& Stage::parallel(IterVar var) { // NOLINT(*) |
| SetAttrIterType(operator->(), var, kParallelized); |
| return *this; |
| } |
| |
| Stage& Stage::pragma(IterVar var, const std::string& pragma_type, |
| const PrimExpr& pragma_value) { // NOLINT(*) |
| if (pragma_type == "unroll") { |
| this->unroll(var); |
| } else if (pragma_type == "vectorize") { |
| this->vectorize(var); |
| } else { |
| UpdateIterVarAttr(operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { |
| n->pragma_keys.push_back(tir::StringImm(pragma_type)); |
| n->pragma_values.push_back(pragma_value); |
| }); |
| } |
| return *this; |
| } |
| |
| Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) { |
| StageNode* self = operator->(); |
| ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); |
| ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); |
| FindLeafVar(all_vars, leaf_vars, var); |
| auto it = self->iter_var_attrs.find(var); |
| ObjectPtr<IterVarAttrNode> n; |
| if (it != self->iter_var_attrs.end()) { |
| n = make_object<IterVarAttrNode>(*(*it).second.operator->()); |
| } else { |
| n = make_object<IterVarAttrNode>(); |
| } |
| n->prefetch_data.push_back(tensor); |
| n->prefetch_offset.push_back(offset); |
| self->iter_var_attrs.Set(var, IterVarAttr(n)); |
| return *this; |
| } |
| |
| Stage& Stage::storage_align(IterVar axis, int factor, int offset) { |
| StageNode* self = operator->(); |
| UpdateIterVarAttr( |
| self, axis, |
| [factor, offset](IterVarAttrNode* n) { |
| n->dim_align_factor = factor; |
| n->dim_align_offset = offset; |
| }, |
| false); |
| return *this; |
| } |
| |
| Stage& Stage::double_buffer() { |
| StageNode* self = operator->(); |
| CHECK(!self->is_output) << "Cannot apply double buffer on output"; |
| self->double_buffer = true; |
| return *this; |
| } |
| |
| Stage CopyStage(const Stage& s) { |
| ObjectPtr<StageNode> n = make_object<StageNode>(*s.operator->()); |
| return Stage(n); |
| } |
| |
| Schedule Schedule::copy() const { |
| // map of stages. |
| const ScheduleNode* self = operator->(); |
| std::unordered_map<Stage, Stage, ObjectPtrHash, ObjectPtrEqual> smap; |
| ObjectPtr<ScheduleNode> n = make_object<ScheduleNode>(); |
| n->outputs = self->outputs; |
| // Copy the stages. |
| for (Stage s : self->stages) { |
| Stage scopy = CopyStage(s); |
| smap[s] = scopy; |
| n->stages.push_back(scopy); |
| } |
| for (Stage g : self->groups) { |
| Stage gcopy = CopyStage(g); |
| smap[g] = gcopy; |
| n->groups.push_back(gcopy); |
| } |
| // Remaps the reference relations. |
| for (auto kv : self->stage_map) { |
| n->stage_map.Set(kv.first, smap.at(kv.second)); |
| } |
| for (Stage s : n->stages) { |
| if (s->attach_stage.defined()) { |
| CHECK(smap.find(s->attach_stage) != smap.end()) |
| << s->attach_stage << " not found in " << (*this); |
| s->attach_stage = smap.at(s->attach_stage); |
| } |
| if (s->group.defined()) { |
| CHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this); |
| s->group = smap.at(s->group); |
| } |
| } |
| for (Stage s : n->groups) { |
| if (s->attach_stage.defined()) { |
| CHECK(smap.find(s->attach_stage) != smap.end()) |
| << s->attach_stage << " not found in " << (*this); |
| s->attach_stage = smap.at(s->attach_stage); |
| } |
| if (s->group.defined()) { |
| CHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this); |
| s->group = smap.at(s->group); |
| } |
| } |
| return Schedule(n); |
| } |
| |
| Stage Schedule::operator[](const Operation& op) { |
| auto it = (*this)->stage_map.find(op); |
| CHECK(it != (*this)->stage_map.end()) |
| << "Cannot find Stage for operator " << op << " in the schedule"; |
| return (*it).second; |
| } |
| |
| Stage LeastCommonAncestor(Stage g1, Stage g2) { |
| if (!g1.defined()) return g1; |
| if (!g2.defined()) return g2; |
| if (g1.same_as(g2)) return g1; |
| Stage g = g1; |
| while (g.defined()) { |
| if (g.same_as(g2)) return g2; |
| g = g->group; |
| } |
| g = g2; |
| while (g.defined()) { |
| if (g.same_as(g1)) return g1; |
| g = g->group; |
| } |
| return g; |
| } |
| |
| Array<Tensor> RemapTensor(ScheduleNode* self, const Array<Tensor>& arr) { |
| self->InitCache(); |
| const auto& op2stage_cache = self->op2stage_cache_; |
| Array<Tensor> ret; |
| for (Tensor t : arr) { |
| if (!op2stage_cache.count(t->op.get())) { |
| CHECK(self->stage_map.count(t->op)) << "Given tensor is not in the schedule plan"; |
| t = self->stage_map[t->op]->op.output(t->value_index); |
| } |
| ret.push_back(t); |
| } |
| return ret; |
| } |
| |
| // Group the schedule stages. |
| Stage Schedule::create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs, |
| bool include_inputs) { |
| ScheduleNode* self = operator->(); |
| self->InitCache(); |
| const auto& op2stage_cache = self->op2stage_cache_; |
| // Get the ops. |
| Array<Operation> ops = |
| te::GetSubGraph(RemapTensor(self, outputs), RemapTensor(self, inputs), include_inputs); |
| // local counter entry |
| // Automatically initialize to 0 during creation. |
| struct Entry { |
| int count{0}; |
| }; |
| // Map of group->touched counter |
| std::unordered_map<Stage, Entry, ObjectPtrHash, ObjectPtrEqual> counter; |
| // The parent group; |
| Stage parent_group; |
| // Detect common parent and child. |
| for (size_t i = 0; i < ops.size(); ++i) { |
| Operation op = ops[i]; |
| auto it = op2stage_cache.find(op.get()); |
| CHECK(it != op2stage_cache.end()); |
| Stage op_group = it->second->group; |
| if (i == 0) { |
| parent_group = op_group; |
| } else { |
| parent_group = LeastCommonAncestor(parent_group, op_group); |
| } |
| if (op_group.defined()) { |
| ++counter[op_group].count; |
| } |
| } |
| // Create the new group stage. |
| Stage gstage(make_object<StageNode>()); |
| gstage->group = parent_group; |
| if (parent_group.defined()) { |
| ++parent_group->num_child_stages; |
| } |
| // Propagate the counter statistics from by checking if subgroup |
| // Is full and propagate. |
| std::vector<Stage> stack; |
| for (auto& kv : counter) { |
| if (!kv.first.same_as(parent_group)) { |
| if (kv.first->num_child_stages == kv.second.count) { |
| stack.push_back(kv.first); |
| } |
| } |
| } |
| while (!stack.empty()) { |
| Stage g = stack.back(); |
| stack.pop_back(); |
| if (g->group.defined() && !g->group.same_as(parent_group)) { |
| Entry& e = counter[g->group]; |
| ++e.count; |
| if (e.count == g->group->num_child_stages) { |
| stack.push_back(g->group); |
| } |
| } |
| } |
| // Verification and remappig the subgroups. |
| for (auto& kv : counter) { |
| if (kv.first.same_as(parent_group)) continue; |
| CHECK_EQ(kv.first->num_child_stages, kv.second.count) |
| << "Trying to group region that intersect with an already existed group"; |
| if (kv.first->group.same_as(parent_group)) { |
| Stage s = kv.first; |
| s->group = gstage; |
| ++gstage->num_child_stages; |
| if (parent_group.defined()) { |
| --parent_group->num_child_stages; |
| } |
| } |
| } |
| // Remap the group of op stages. |
| for (Operation op : ops) { |
| auto it = op2stage_cache.find(op.get()); |
| CHECK(it != op2stage_cache.end()); |
| Stage s = it->second; |
| if (s->group.same_as(parent_group)) { |
| s->group = gstage; |
| ++gstage->num_child_stages; |
| if (parent_group.defined()) { |
| --parent_group->num_child_stages; |
| } |
| } |
| } |
| // Correct the attach to keep everything in group. |
| for (Operation op : ops) { |
| auto it = op2stage_cache.find(op.get()); |
| CHECK(it != op2stage_cache.end()); |
| Stage s = it->second; |
| if (s->attach_type == kScope) { |
| Stage cg = LeastCommonAncestor(s->attach_stage->group, gstage); |
| if (!cg.same_as(gstage)) { |
| LOG(WARNING) << "group invalidates some previous compute_at relation " |
| << " and keeps things to be computed inside the group"; |
| s.compute_root(); |
| } |
| } |
| } |
| |
| self->groups.push_back(gstage); |
| return gstage; |
| } |
| |
| void ScheduleNode::InvalidateCache() { op2stage_cache_.clear(); } |
| |
| void ScheduleNode::InitCache() { |
| if (op2stage_cache_.size() == stages.size()) return; |
| InvalidateCache(); |
| for (Stage s : stages) { |
| if (s->op.defined()) { |
| op2stage_cache_[s->op.get()] = s; |
| } |
| } |
| CHECK_EQ(op2stage_cache_.size(), stages.size()); |
| } |
| |
| bool ScheduleNode::Contain(const Operation& op) const { |
| return stage_map.find(op) != stage_map.end(); |
| } |
| |
| Schedule::Schedule(Array<Operation> ops) { |
| auto n = make_object<ScheduleNode>(); |
| data_ = n; |
| n->outputs = ops; |
| auto g = te::CreateReadGraph(n->outputs); |
| Array<Operation> post_order = te::PostDFSOrder(n->outputs, g); |
| // output set. |
| std::unordered_set<Operation> output_set; |
| for (Operation x : ops) { |
| output_set.insert(x); |
| } |
| for (Operation op : post_order) { |
| Stage stage(op); |
| stage->is_output = output_set.count(op) != 0; |
| n->stages.push_back(stage); |
| n->stage_map.Set(op, stage); |
| // mark scan updates. |
| if (const ScanOpNode* scan = op.as<ScanOpNode>()) { |
| Array<Tensor> inputs; |
| for (Tensor t : scan->state_placeholder) { |
| inputs.push_back(t); |
| } |
| for (Tensor t : scan->inputs) { |
| inputs.push_back(t); |
| } |
| // Create the scan group. |
| Stage scan_group = this->create_group(scan->update, inputs, false); |
| scan_group->attach_type = kScanUpdate; |
| scan_group->attach_stage = stage; |
| |
| for (size_t i = 0; i < scan->update.size(); ++i) { |
| Stage s = n->stage_map[scan->update[i]->op]; |
| CHECK(scan_group.same_as(s->group)); |
| } |
| } |
| } |
| } |
| |
| Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) { |
| auto n = make_object<SplitNode>(); |
| n->parent = parent; |
| n->outer = outer; |
| n->inner = inner; |
| n->factor = factor; |
| n->nparts = nparts; |
| data_ = std::move(n); |
| } |
| |
| Fuse::Fuse(IterVar outer, IterVar inner, IterVar fused) { |
| auto n = make_object<FuseNode>(); |
| n->outer = outer; |
| n->inner = inner; |
| n->fused = fused; |
| data_ = std::move(n); |
| } |
| |
| Rebase::Rebase(IterVar parent, IterVar rebased) { |
| auto n = make_object<RebaseNode>(); |
| n->parent = parent; |
| n->rebased = rebased; |
| data_ = std::move(n); |
| } |
| |
| Singleton::Singleton(IterVar iter) { |
| auto n = make_object<SingletonNode>(); |
| n->iter = iter; |
| data_ = std::move(n); |
| } |
| |
| SpecializedCondition::SpecializedCondition(Array<PrimExpr> conditions) { |
| ObjectPtr<SpecializedConditionNode> n = make_object<SpecializedConditionNode>(); |
| n->clauses = std::move(conditions); |
| data_ = std::move(n); |
| } |
| |
| /*! \brief Entry to hold the SpecializedCondition context stack. */ |
| struct TVMSpecializationThreadLocalEntry { |
| /*! \brief The current specialized condition */ |
| std::stack<SpecializedCondition> condition_stack; |
| }; |
| |
| /*! \brief Thread local store to hold the Target context stack. */ |
| typedef dmlc::ThreadLocalStore<TVMSpecializationThreadLocalEntry> TVMSpecializationThreadLocalStore; |
| |
| void SpecializedCondition::EnterWithScope() { |
| TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); |
| entry->condition_stack.push(*this); |
| } |
| |
| void SpecializedCondition::ExitWithScope() { |
| TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); |
| CHECK(!entry->condition_stack.empty()); |
| CHECK(entry->condition_stack.top().same_as(*this)); |
| entry->condition_stack.pop(); |
| } |
| |
| SpecializedCondition SpecializedCondition::Current() { |
| TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); |
| SpecializedCondition cond; |
| if (entry->condition_stack.size() > 0) { |
| cond = entry->condition_stack.top(); |
| } |
| return cond; |
| } |
| |
| class SpecializedCondition::Internal { |
| public: |
| static void EnterScope(SpecializedCondition cond) { cond.EnterWithScope(); } |
| |
| static void ExitScope(SpecializedCondition cond) { cond.ExitWithScope(); } |
| }; |
| |
| TVM_REGISTER_NODE_TYPE(StageNode); |
| TVM_REGISTER_NODE_TYPE(IterVarAttrNode); |
| TVM_REGISTER_NODE_TYPE(SplitNode); |
| TVM_REGISTER_NODE_TYPE(FuseNode); |
| TVM_REGISTER_NODE_TYPE(RebaseNode); |
| TVM_REGISTER_NODE_TYPE(SingletonNode); |
| TVM_REGISTER_NODE_TYPE(ScheduleNode); |
| TVM_REGISTER_NODE_TYPE(SpecializedConditionNode); |
| |
| // Printer |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<StageNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const StageNode*>(node.get()); |
| if (op->op.defined()) { |
| p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; |
| } else { |
| p->stream << "group-stage(" << op << ")"; |
| } |
| }) |
| .set_dispatch<IterVarAttrNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const IterVarAttrNode*>(node.get()); |
| p->stream << IterVarType2String(op->iter_type); |
| }) |
| .set_dispatch<SplitNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const SplitNode*>(node.get()); |
| p->stream << "split(parent="; |
| p->Print(op->parent); |
| p->stream << ", outer="; |
| p->Print(op->outer); |
| p->stream << ", inner="; |
| p->Print(op->inner); |
| p->stream << ')'; |
| }) |
| .set_dispatch<FuseNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const FuseNode*>(node.get()); |
| p->stream << "split("; |
| p->stream << "outer="; |
| p->Print(op->outer); |
| p->stream << ", inner="; |
| p->Print(op->inner); |
| p->stream << ", fused="; |
| p->Print(op->fused); |
| p->stream << ')'; |
| }) |
| .set_dispatch<RebaseNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const RebaseNode*>(node.get()); |
| p->stream << "rebase("; |
| p->stream << "parent="; |
| p->Print(op->parent); |
| p->stream << ", rebased="; |
| p->Print(op->rebased); |
| p->stream << ')'; |
| }) |
| .set_dispatch<SingletonNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const SingletonNode*>(node.get()); |
| p->stream << "singleton("; |
| p->Print(op->iter); |
| p->stream << ')'; |
| }) |
| .set_dispatch<ScheduleNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const ScheduleNode*>(node.get()); |
| p->stream << "schedule(" << op << ")"; |
| }) |
| .set_dispatch<SpecializedConditionNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const SpecializedConditionNode*>(node.get()); |
| p->stream << "specialized_condition("; |
| p->Print(op->clauses); |
| p->stream << ')'; |
| }); |
| |
| TVM_REGISTER_GLOBAL("te.CreateSchedule").set_body_typed(create_schedule); |
| |
| TVM_REGISTER_GLOBAL("te.StageSetScope").set_body_method(&Stage::set_scope); |
| |
| TVM_REGISTER_GLOBAL("te.StageBind").set_body_method(&Stage::bind); |
| |
| TVM_REGISTER_GLOBAL("te.StageSplitByFactor") |
| .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { |
| IterVar outer, inner; |
| stage.split(parent, factor, &outer, &inner); |
| return Array<IterVar>({outer, inner}); |
| }); |
| |
| TVM_REGISTER_GLOBAL("te.StageSplitByNParts") |
| .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { |
| IterVar outer, inner; |
| stage.split_by_nparts(parent, nparts, &outer, &inner); |
| return Array<IterVar>({outer, inner}); |
| }); |
| |
| TVM_REGISTER_GLOBAL("te.StageFuse").set_body_typed([](Stage stage, Array<IterVar> axes) { |
| IterVar fused; |
| stage.fuse(axes, &fused); |
| return fused; |
| }); |
| |
| TVM_REGISTER_GLOBAL("te.StageComputeAt").set_body_method(&Stage::compute_at); |
| |
| TVM_REGISTER_GLOBAL("te.StageComputeInline").set_body_method(&Stage::compute_inline); |
| |
| TVM_REGISTER_GLOBAL("te.StageComputeRoot").set_body_method(&Stage::compute_root); |
| |
| TVM_REGISTER_GLOBAL("te.StageReorder").set_body_method(&Stage::reorder); |
| |
| TVM_REGISTER_GLOBAL("te.StageTile") |
| .set_body_typed([](Stage stage, IterVar x_parent, IterVar y_parent, PrimExpr x_factor, |
| PrimExpr y_factor) { |
| IterVar x_outer, y_outer, x_inner, y_inner; |
| stage.tile(x_parent, y_parent, x_factor, y_factor, &x_outer, &y_outer, &x_inner, &y_inner); |
| return Array<IterVar>({x_outer, y_outer, x_inner, y_inner}); |
| }); |
| |
| TVM_REGISTER_GLOBAL("te.StageEnvThreads").set_body_method(&Stage::env_threads); |
| |
| TVM_REGISTER_GLOBAL("te.StageSetStorePredicate").set_body_method(&Stage::set_store_predicate); |
| |
| TVM_REGISTER_GLOBAL("te.StageUnroll").set_body_method(&Stage::unroll); |
| |
| TVM_REGISTER_GLOBAL("te.StageVectorize").set_body_method(&Stage::vectorize); |
| |
| TVM_REGISTER_GLOBAL("te.StageTensorize").set_body_method(&Stage::tensorize); |
| |
| TVM_REGISTER_GLOBAL("te.StageParallel").set_body_method(&Stage::parallel); |
| |
| TVM_REGISTER_GLOBAL("te.StagePragma").set_body_method(&Stage::pragma); |
| |
| TVM_REGISTER_GLOBAL("te.StagePrefetch").set_body_method(&Stage::prefetch); |
| |
| TVM_REGISTER_GLOBAL("te.StageStorageAlign").set_body_method(&Stage::storage_align); |
| |
| TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffer); |
| |
| TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize); |
| |
| TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group); |
| |
| TVM_REGISTER_GLOBAL("te.ScheduleCacheRead").set_body_method(&Schedule::cache_read); |
| |
| TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite").set_body([](TVMArgs args, TVMRetValue* ret) { |
| if (args[1].IsObjectRef<Tensor>()) { |
| *ret = args[0].operator Schedule().cache_write(args[1].operator Tensor(), args[2]); |
| } else { |
| *ret = args[0].operator Schedule().cache_write(args[1].operator Array<Tensor>(), args[2]); |
| } |
| }); |
| |
| TVM_REGISTER_GLOBAL("te.ScheduleRFactor").set_body_method(&Schedule::rfactor); |
| |
| TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition").set_body_typed([](Array<PrimExpr> condition) { |
| return SpecializedCondition(condition); |
| }); |
| |
| TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization").set_body([](TVMArgs args, TVMRetValue* ret) { |
| *ret = SpecializedCondition::Current(); |
| }); |
| |
| TVM_REGISTER_GLOBAL("te.EnterSpecializationScope") |
| .set_body_typed(SpecializedCondition::Internal::EnterScope); |
| |
| TVM_REGISTER_GLOBAL("te.ExitSpecializationScope") |
| .set_body_typed(SpecializedCondition::Internal::ExitScope); |
| |
| } // namespace te |
| } // namespace tvm |