| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file auto_scheduler/transform_step.cc |
| * \brief Transformation steps. These steps are used to manipulate the LoopState. |
| * They are similar to the schedule primitives in te::Stage. |
| */ |
| |
| #include <tvm/auto_scheduler/compute_dag.h> |
| #include <tvm/auto_scheduler/loop_state.h> |
| #include <tvm/auto_scheduler/transform_step.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/te/operation.h> |
| |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "utils.h" |
| |
| namespace dmlc { |
| namespace json { |
| |
| template <> |
| struct Handler<::tvm::Array<::tvm::Integer>> { |
| inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::Integer>& array) { |
| writer->BeginArray(false); |
| for (const auto& i : array) { |
| CHECK(i.defined()); |
| writer->WriteArrayItem(i->value); |
| } |
| writer->EndArray(); |
| } |
| inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::Integer>* array) { |
| array->clear(); |
| reader->BeginArray(); |
| while (reader->NextArrayItem()) { |
| int value; |
| Handler<int>::Read(reader, &value); |
| array->push_back(value); |
| } |
| } |
| }; |
| |
| template <> |
| struct Handler<::tvm::Array<::tvm::Optional<::tvm::Integer>>> { |
| inline static void Write(dmlc::JSONWriter* writer, |
| const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& array) { |
| writer->BeginArray(false); |
| for (const auto& i : array) { |
| CHECK(i); |
| writer->WriteArrayItem(i.value()->value); |
| } |
| writer->EndArray(); |
| } |
| inline static void Read(dmlc::JSONReader* reader, |
| ::tvm::Array<::tvm::Optional<::tvm::Integer>>* array) { |
| array->clear(); |
| reader->BeginArray(); |
| while (reader->NextArrayItem()) { |
| int value; |
| Handler<int>::Read(reader, &value); |
| array->push_back(::tvm::Integer(value)); |
| } |
| } |
| }; |
| |
| } // namespace json |
| } // namespace dmlc |
| |
| namespace tvm { |
| namespace auto_scheduler { |
| |
| // Update the te::stage to tir::IterVar axis mapping |
| void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) { |
| if (auto pop = stage->op.as<te::ComputeOpNode>()) { |
| Array<IterVar> axes; |
| for (const auto& axis : pop->axis) { |
| axes.push_back(axis); |
| } |
| for (const auto& axis : pop->reduce_axis) { |
| axes.push_back(axis); |
| } |
| stage_to_axes->Set(stage, std::move(axes)); |
| } else if (stage->op->IsInstance<te::PlaceholderOpNode>()) { |
| {} // do nothing on Placeholder |
| } else { |
| LOG(FATAL) << "Invalid op " << stage->op; |
| } |
| } |
| |
/*!
 * \brief Printable name of each iterator annotation, indexed by the annotation's integer value.
 *  The trailing comment on each entry names the enumerator the slot is meant to line up with.
 */
const char* IteratorAnnotationString[] = {
    "for",          //  0: kNone
    "unroll",       //  1: kUnroll
    "vectorize",    //  2: kVectorize
    "parallel",     //  3: kParallel
    "vthread",      //  4: kVThread
    "blockIdx.x",   //  5: kBlockX
    "threadIdx.x",  //  6: kThreadX
    "blockIdx.y",   //  7: kBlockY
    "threadIdx.y",  //  8: kThreadY
    "blockIdx.z",   //  9: kBlockZ
    "threadIdx.z",  // 10: kThreadZ
    "tensorize"     // 11: kTensorized
};
| |
| Step StepReadFromRecord(dmlc::JSONReader* reader) { |
| std::string name; |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&name); |
| if (name == AnnotationStepNode::record_prefix_str) { |
| return AnnotationStep(reader); |
| } else if (name == FuseStepNode::record_prefix_str) { |
| return FuseStep(reader); |
| } else if (name == PragmaStepNode::record_prefix_str) { |
| return PragmaStep(reader); |
| } else if (name == ReorderStepNode::record_prefix_str) { |
| return ReorderStep(reader); |
| } else if (name == SplitStepNode::record_prefix_str) { |
| return SplitStep(reader); |
| } else if (name == FollowSplitStepNode::record_prefix_str) { |
| return FollowSplitStep(reader); |
| } else if (name == FollowFusedSplitStepNode::record_prefix_str) { |
| return FollowFusedSplitStep(reader); |
| } else if (name == StorageAlignStepNode::record_prefix_str) { |
| return StorageAlignStep(reader); |
| } else if (name == ComputeAtStepNode::record_prefix_str) { |
| return ComputeAtStep(reader); |
| } else if (name == ComputeInlineStepNode::record_prefix_str) { |
| return ComputeInlineStep(reader); |
| } else if (name == ComputeRootStepNode::record_prefix_str) { |
| return ComputeRootStep(reader); |
| } else if (name == CacheReadStepNode::record_prefix_str) { |
| return CacheReadStep(reader); |
| } else if (name == CacheWriteStepNode::record_prefix_str) { |
| return CacheWriteStep(reader); |
| } else if (name == RfactorStepNode::record_prefix_str) { |
| return RfactorStep(reader); |
| } else { |
| LOG(FATAL) << "Invalid step format: " << name; |
| } |
| return Step(); |
| } |
| |
| void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { |
| // We need this runtime dispatcher because different steps have different function signatures |
| if (auto ps = step.as<AnnotationStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<FuseStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<PragmaStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<ReorderStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<SplitStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<FollowSplitStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<FollowFusedSplitStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<StorageAlignStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<ComputeAtStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<ComputeInlineStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<ComputeRootStepNode>()) { |
| ps->ApplyToState(state); |
| } else if (auto ps = step.as<CacheReadStepNode>()) { |
| ps->ApplyToState(state, dag); |
| } else if (auto ps = step.as<CacheWriteStepNode>()) { |
| ps->ApplyToState(state, dag); |
| } else if (auto ps = step.as<RfactorStepNode>()) { |
| ps->ApplyToState(state, dag); |
| } else { |
| LOG(FATAL) << "Invalid step: " << step; |
| } |
| } |
| |
| void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
| te::Schedule* schedule, const Array<Step>& transform_steps) { |
| if (auto ps = step.as<AnnotationStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes); |
| } else if (auto ps = step.as<FuseStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes); |
| } else if (auto ps = step.as<PragmaStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes); |
| } else if (auto ps = step.as<ReorderStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes); |
| } else if (auto ps = step.as<SplitStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes); |
| } else if (auto ps = step.as<FollowSplitStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); |
| } else if (auto ps = step.as<FollowFusedSplitStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); |
| } else if (auto ps = step.as<StorageAlignStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes); |
| } else if (auto ps = step.as<ComputeAtStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes); |
| } else if (auto ps = step.as<ComputeInlineStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes); |
| } else if (auto ps = step.as<ComputeRootStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes); |
| } else if (auto ps = step.as<CacheReadStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes, schedule); |
| } else if (auto ps = step.as<CacheWriteStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes, schedule); |
| } else if (auto ps = step.as<RfactorStepNode>()) { |
| ps->ApplyToSchedule(stages, stage_to_axes, schedule); |
| } else { |
| LOG(FATAL) << "Invalid Step: " << step; |
| } |
| } |
| |
| String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes, te::Schedule* schedule, |
| const Array<Step>& transform_steps) { |
| if (auto ps = step.as<AnnotationStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes); |
| } else if (auto ps = step.as<FuseStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes); |
| } else if (auto ps = step.as<PragmaStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes); |
| } else if (auto ps = step.as<ReorderStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes); |
| } else if (auto ps = step.as<SplitStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes); |
| } else if (auto ps = step.as<FollowSplitStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); |
| } else if (auto ps = step.as<FollowFusedSplitStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); |
| } else if (auto ps = step.as<StorageAlignStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes); |
| } else if (auto ps = step.as<ComputeAtStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes); |
| } else if (auto ps = step.as<ComputeInlineStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes); |
| } else if (auto ps = step.as<ComputeRootStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes); |
| } else if (auto ps = step.as<CacheReadStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); |
| } else if (auto ps = step.as<CacheWriteStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); |
| } else if (auto ps = step.as<RfactorStepNode>()) { |
| return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); |
| } else { |
| LOG(FATAL) << "Invalid Step: " << step; |
| } |
| return ""; |
| } |
| |
| /********** Steps working on single stage **********/ |
| |
| /********** Annotation **********/ |
| AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) { |
| auto node = make_object<AnnotationStepNode>(); |
| node->stage_id = stage_id; |
| node->iter_id = iter_id; |
| node->annotation = ann; |
| data_ = std::move(node); |
| } |
| |
| AnnotationStep::AnnotationStep(dmlc::JSONReader* reader) { |
| auto node = make_object<AnnotationStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->iter_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| int int_val; |
| reader->Read(&int_val); |
| node->annotation = IteratorAnnotation(int_val); |
| data_ = std::move(node); |
| } |
| |
| void AnnotationStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArrayItem(iter_id); |
| writer->WriteArrayItem(static_cast<int>(annotation)); |
| } |
| |
| Iterator AnnotationStepNode::ApplyToState(State* state) const { |
| const Stage& stage = (*state)->stages[stage_id]; |
| Iterator it = stage->iters[iter_id]; |
| |
| CHECK(it->annotation == IteratorAnnotation::kNone); |
| Iterator new_it = Iterator(it->name, it->range, it->iter_kind, annotation, &it->orig_iters); |
| Stage new_stage = stage; |
| new_stage.CopyOnWrite()->iters.Set(iter_id, new_it); |
| state->CopyOnWrite()->stages.Set(stage_id, std::move(new_stage)); |
| return new_it; |
| } |
| |
| void AnnotationStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| te::Stage stage = (*stages)[stage_id]; |
| const Array<IterVar>& axes = (*stage_to_axes)[stage]; |
| |
| switch (annotation) { |
| case IteratorAnnotation::kUnroll: |
| stage.unroll(axes[iter_id]); |
| break; |
| case IteratorAnnotation::kVectorize: |
| stage.vectorize(axes[iter_id]); |
| break; |
| case IteratorAnnotation::kParallel: |
| stage.parallel(axes[iter_id]); |
| break; |
| case IteratorAnnotation::kVThread: |
| case IteratorAnnotation::kBlockX: |
| case IteratorAnnotation::kBlockY: |
| case IteratorAnnotation::kBlockZ: |
| case IteratorAnnotation::kThreadX: |
| case IteratorAnnotation::kThreadY: |
| case IteratorAnnotation::kThreadZ: |
| stage.bind(axes[iter_id], |
| te::thread_axis(Range(), IteratorAnnotationString[static_cast<int>(annotation)])); |
| break; |
| case IteratorAnnotation::kNone: |
| break; |
| default: |
| LOG(FATAL) << "Invalid Annotation " << static_cast<int>(annotation); |
| break; |
| } |
| |
| stages->Set(stage_id, std::move(stage)); |
| } |
| |
| String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| std::stringstream ss; |
| const auto& stage = (*stages)[stage_id]; |
| const auto& iter = (*stage_to_axes)[stage][iter_id]; |
| const auto& op_name = CleanName(stage->op->name); |
| |
| ss << "s[" << op_name << "]."; |
| switch (annotation) { |
| case IteratorAnnotation::kUnroll: |
| ss << "unroll("; |
| break; |
| case IteratorAnnotation::kVectorize: |
| ss << "vectorize("; |
| break; |
| case IteratorAnnotation::kParallel: |
| ss << "parallel("; |
| break; |
| case IteratorAnnotation::kVThread: |
| case IteratorAnnotation::kBlockX: |
| case IteratorAnnotation::kBlockY: |
| case IteratorAnnotation::kBlockZ: |
| case IteratorAnnotation::kThreadX: |
| case IteratorAnnotation::kThreadY: |
| case IteratorAnnotation::kThreadZ: |
| ss << "bind("; |
| break; |
| case IteratorAnnotation::kNone: |
| break; |
| default: |
| LOG(FATAL) << "Invalid annotation " << static_cast<int>(annotation); |
| break; |
| } |
| ss << CleanName(iter->var->name_hint, op_name); |
| switch (annotation) { |
| case IteratorAnnotation::kVThread: |
| case IteratorAnnotation::kBlockX: |
| case IteratorAnnotation::kBlockY: |
| case IteratorAnnotation::kBlockZ: |
| case IteratorAnnotation::kThreadX: |
| case IteratorAnnotation::kThreadY: |
| case IteratorAnnotation::kThreadZ: |
| ss << ", te.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)] |
| << "\")"; |
| break; |
| default: |
| break; |
| } |
| ss << ")\n"; |
| |
| ApplyToSchedule(stages, stage_to_axes); |
| return ss.str(); |
| } |
| |
| /********** Fuse **********/ |
| FuseStep::FuseStep(int stage_id, const Array<Integer>& fused_ids) { |
| auto node = make_object<FuseStepNode>(); |
| node->stage_id = stage_id; |
| for (const auto& x : fused_ids) { |
| CHECK(x->IsInstance<IntImmNode>()); |
| } |
| node->fused_ids = fused_ids; |
| data_ = std::move(node); |
| } |
| |
| FuseStep::FuseStep(dmlc::JSONReader* reader) { |
| auto node = make_object<FuseStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->fused_ids); |
| data_ = std::move(node); |
| } |
| |
| void FuseStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArrayItem(fused_ids); |
| } |
| |
| Iterator FuseStepNode::ApplyToState(State* state) const { |
| const Stage& stage = (*state)->stages[stage_id]; |
| size_t old_iter_size = static_cast<int>(stage->iters.size()); |
| |
| String new_name; |
| PrimExpr new_extent = 1; |
| IteratorKind new_iter_kind = IteratorKind::kSpecial; |
| std::vector<Iterator> orig_iters; |
| |
| for (size_t i = 0; i < fused_ids.size(); ++i) { |
| if (i > 0) { |
| CHECK_EQ(fused_ids[i]->value, fused_ids[i - 1]->value + 1); |
| } |
| |
| if (i != fused_ids.size() - 1) { |
| const auto& iter_to_attached_stage = (*state)->attach_map->iter_to_attached_stages; |
| if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i])) != |
| iter_to_attached_stage.end()) { |
| LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some " |
| << "stages. State before fusion:\n" |
| << (*state); |
| } |
| } |
| |
| const Iterator& it = stage->iters[fused_ids[i]]; |
| orig_iters.push_back(it); |
| new_name = new_name + it->name + "@"; |
| |
| if (it->range.defined() && new_extent.defined()) { |
| new_extent = new_extent * it->range->extent; |
| } else { |
| new_extent = PrimExpr(); |
| } |
| |
| if (i == 0) { |
| new_iter_kind = it->iter_kind; |
| } else { |
| if (new_iter_kind != it->iter_kind) { |
| new_iter_kind = IteratorKind::kMixed; |
| } |
| } |
| } |
| |
| Range range; |
| if (new_extent.defined()) { |
| range = Range::FromMinExtent(0, new_extent); |
| } |
| Iterator new_it = |
| Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone, &orig_iters); |
| Array<Iterator> new_iters; |
| new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + fused_ids.front()); |
| new_iters.push_back(new_it); |
| new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() + 1, |
| stage->iters.end()); |
| |
| StateNode* pstate = state->CopyOnWrite(); |
| pstate->stages.Set(stage_id, |
| Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); |
| |
| // Two vectors are used to represent the iterator relation before and after fuse |
| // The original iterators in AttachMap will be updated with the new iterators |
| std::vector<IterKey> from_iters; |
| std::vector<IterKey> to_iters; |
| const size_t begin_id = fused_ids.front(), end_id = fused_ids.back(); |
| for (size_t i = 0; i < old_iter_size; ++i) { |
| if (i <= begin_id) { |
| continue; |
| } else if (i > end_id) { |
| // move forward |
| from_iters.emplace_back(stage_id, i); |
| to_iters.emplace_back(stage_id, i - end_id + begin_id); |
| } else { |
| // move to the fused id |
| from_iters.emplace_back(stage_id, i); |
| to_iters.emplace_back(stage_id, begin_id); |
| } |
| } |
| pstate->attach_map.UpdateIters(from_iters, to_iters); |
| |
| return new_it; |
| } |
| |
| IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| auto stage = (*stages)[stage_id]; |
| const Array<IterVar>& axes = stage_to_axes->at(stage); |
| |
| Array<IterVar> to_fuse; |
| for (const auto& i : fused_ids) { |
| to_fuse.push_back(axes[i]); |
| } |
| IterVar fused_axis; |
| stage.fuse(to_fuse, &fused_axis); |
| |
| Array<IterVar> new_axes; |
| new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front()); |
| new_axes.push_back(fused_axis); |
| new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end()); |
| |
| stage_to_axes->Set(stage, std::move(new_axes)); |
| stages->Set(stage_id, std::move(stage)); |
| return fused_axis; |
| } |
| |
| String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| const auto& stage = (*stages)[stage_id]; |
| const auto& op_name = CleanName(stage->op->name); |
| std::stringstream to_fuse; |
| |
| for (size_t i = 0; i < fused_ids.size(); ++i) { |
| to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint, op_name); |
| if (i != fused_ids.size() - 1) { |
| to_fuse << ", "; |
| } |
| } |
| |
| std::stringstream ss; |
| const auto& fused = ApplyToSchedule(stages, stage_to_axes); |
| |
| ss << CleanName(fused->var->name_hint, op_name) << " = s[" << op_name << "].fuse(" |
| << to_fuse.str() << ")\n"; |
| |
| return ss.str(); |
| } |
| |
| /********** Pragma **********/ |
| PragmaStep::PragmaStep(int stage_id, int iter_id, String pragma_type) { |
| auto node = make_object<PragmaStepNode>(); |
| node->stage_id = stage_id; |
| node->iter_id = iter_id; |
| node->pragma_type = std::move(pragma_type); |
| data_ = std::move(node); |
| } |
| |
| PragmaStep::PragmaStep(dmlc::JSONReader* reader) { |
| auto node = make_object<PragmaStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->iter_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| std::string string_value; |
| reader->Read(&string_value); |
| node->pragma_type = std::move(string_value); |
| data_ = std::move(node); |
| } |
| |
| void PragmaStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArrayItem(iter_id); |
| writer->WriteArraySeperator(); |
| writer->WriteString(pragma_type); |
| } |
| |
| void PragmaStepNode::ApplyToState(State* state) const { |
| if (pragma_type == "debug_skip_region") { |
| StateNode* pstate = state->CopyOnWrite(); |
| pstate->attach_map.DeleteStage(stage_id); |
| } else if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { |
| StateNode* pstate = state->CopyOnWrite(); |
| Stage stage = pstate->stages[stage_id]; |
| size_t pos = 0; |
| for (; pos < pragma_type.size(); ++pos) { |
| if ((*(pragma_type.c_str() + pos)) == '$') { |
| break; |
| } |
| } |
| CHECK_LT(pos, pragma_type.size()) << "max step value not found."; |
| stage.CopyOnWrite()->attrs.auto_unroll_max_step = atoi(pragma_type.c_str() + pos + 1); |
| pstate->stages.Set(stage_id, std::move(stage)); |
| } else { |
| LOG(FATAL) << "Unsupported pragma: " << pragma_type; |
| } |
| } |
| |
| void PragmaStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| te::Stage stage = (*stages)[stage_id]; |
| const Array<IterVar>& axes = (*stage_to_axes)[stage]; |
| if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { |
| size_t pos = 0; |
| for (; pos < pragma_type.size(); ++pos) { |
| if ((*(pragma_type.c_str() + pos)) == '$') { |
| break; |
| } |
| } |
| CHECK_LT(pos, pragma_type.size()) << "max step value not found."; |
| int value = atoi(pragma_type.c_str() + pos + 1); |
| stage.pragma(axes[iter_id], "auto_unroll_max_step", value); |
| stage.pragma(axes[iter_id], "unroll_explicit", true); |
| } else { |
| stage.pragma(axes[iter_id], pragma_type); |
| } |
| stages->Set(stage_id, std::move(stage)); |
| } |
| |
| String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| std::stringstream ss; |
| const auto& stage = (*stages)[stage_id]; |
| const auto& op_name = CleanName(stage->op->name); |
| |
| if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { |
| size_t pos = 0; |
| for (; pos < pragma_type.size(); ++pos) { |
| if ((*(pragma_type.c_str() + pos)) == '$') { |
| break; |
| } |
| } |
| CHECK_LT(pos, pragma_type.size()) << "max step value not found."; |
| int value = atoi(pragma_type.c_str() + pos + 1); |
| ss << "s[" << op_name << "].pragma(" |
| << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) |
| << ", \"auto_unroll_max_step\", " << value << ")\n"; |
| ss << "s[" << op_name << "].pragma(" |
| << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) |
| << ", \"unroll_explicit\", True)\n"; |
| } else { |
| ss << "s[" << op_name << "].pragma(" |
| << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", \"" |
| << pragma_type << "\")\n"; |
| } |
| |
| ApplyToSchedule(stages, stage_to_axes); |
| return ss.str(); |
| } |
| |
| /********** Reorder **********/ |
| ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) { |
| auto node = make_object<ReorderStepNode>(); |
| node->stage_id = stage_id; |
| for (const auto& x : after_ids) { |
| CHECK(x->IsInstance<IntImmNode>()); |
| } |
| node->after_ids = after_ids; |
| data_ = std::move(node); |
| } |
| |
| ReorderStep::ReorderStep(dmlc::JSONReader* reader) { |
| auto node = make_object<ReorderStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->after_ids); |
| data_ = std::move(node); |
| } |
| |
| void ReorderStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArrayItem(after_ids); |
| } |
| |
| void ReorderStepNode::ApplyToState(State* state) const { |
| const Stage& stage = (*state)->stages[stage_id]; |
| Array<Iterator> iters; |
| for (auto x : after_ids) { |
| iters.push_back(stage->iters[x]); |
| } |
| state->CopyOnWrite()->stages.Set( |
| stage_id, Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs)); |
| } |
| |
| void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| auto stage = (*stages)[stage_id]; |
| const Array<IterVar>& axes = stage_to_axes->at(stage); |
| CHECK_EQ(after_ids.size(), axes.size()); |
| |
| Array<IterVar> new_axes; |
| new_axes.reserve(axes.size()); |
| for (auto i : after_ids) { |
| new_axes.push_back(axes[i]); |
| } |
| stage.reorder(new_axes); |
| |
| stage_to_axes->Set(stage, std::move(new_axes)); |
| stages->Set(stage_id, std::move(stage)); |
| } |
| |
| String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| const auto& stage = (*stages)[stage_id]; |
| const auto& op_name = CleanName(stage->op->name); |
| std::stringstream ss; |
| |
| ss << "s[" << op_name << "].reorder("; |
| for (size_t i = 0; i < after_ids.size(); ++i) { |
| ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint, op_name); |
| if (i != after_ids.size() - 1) { |
| ss << ", "; |
| } |
| } |
| ss << ")\n"; |
| |
| ApplyToSchedule(stages, stage_to_axes); |
| return ss.str(); |
| } |
| |
| /********** Split **********/ |
| // common part for SplitStep, FollowSplitStep, and FollowFusedSplitStep |
| Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id, |
| const Array<Optional<Integer>>& lengths, bool inner_to_outer) { |
| const Stage& stage = (*state)->stages[stage_id]; |
| const Iterator& it = stage->iters[iter_id]; |
| size_t old_iter_size = stage->iters.size(); |
| bool concrete = true; |
| |
| Optional<PrimExpr> tosplit_min, tosplit_extent; |
| if (it->range.defined()) { |
| tosplit_min = it->range->min; |
| tosplit_extent = it->range->extent; |
| } else { |
| tosplit_min = NullOpt; |
| tosplit_extent = NullOpt; |
| } |
| |
| Array<Iterator> outs; |
| for (size_t i = 0; i < lengths.size(); ++i) { |
| Optional<Integer> l; |
| String name; |
| if (inner_to_outer) { |
| l = lengths[lengths.size() - i - 1]; |
| name = it->name + "." + std::to_string(lengths.size() - i); |
| } else { |
| l = lengths[i]; |
| name = it->name + "." + std::to_string(i); |
| } |
| Iterator res; |
| if (l && tosplit_min && tosplit_extent) { |
| res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind, |
| IteratorAnnotation::kNone); |
| tosplit_min = Integer(0); |
| tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value()); |
| } else { |
| res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone); |
| tosplit_min = NullOpt; |
| tosplit_extent = NullOpt; |
| if (!l.defined()) { |
| concrete = false; |
| } |
| } |
| outs.push_back(std::move(res)); |
| } |
| |
| Range range; |
| if (tosplit_min && tosplit_extent) { |
| range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value()); |
| } |
| if (inner_to_outer) { |
| outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone)); |
| // Reverse the Iterator array |
| Array<Iterator> temp(outs.rbegin(), outs.rend()); |
| outs = std::move(temp); |
| } else { |
| outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind, |
| IteratorAnnotation::kNone)); |
| } |
| |
| Array<Iterator> new_iters; |
| new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); |
| new_iters.insert(new_iters.end(), outs.begin(), outs.end()); |
| new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); |
| |
| StateNode* pstate = state->CopyOnWrite(); |
| pstate->stages.Set(stage_id, |
| Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); |
| pstate->concrete &= concrete; |
| |
| // Two vectors are used to represent the iterator relation before and after split |
| // The original iterators in AttachMap will be updated with the new iterators |
| std::vector<IterKey> from_iters; |
| std::vector<IterKey> to_iters; |
| for (size_t i = iter_id; i < old_iter_size; ++i) { |
| from_iters.emplace_back(stage_id, i); |
| to_iters.emplace_back(stage_id, i + lengths.size()); |
| } |
| pstate->attach_map.UpdateIters(from_iters, to_iters); |
| |
| return outs; |
| } |
| |
| Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
| int stage_id, int iter_id, |
| const Array<Optional<Integer>>& lengths, bool inner_to_outer) { |
| auto stage = (*stages)[stage_id]; |
| const Array<IterVar>& axes = stage_to_axes->at(stage); |
| |
| Array<IterVar> outs; |
| if (inner_to_outer) { |
| IterVar outer = axes[iter_id], inner; |
| for (int i = static_cast<int>(lengths.size()) - 1; i >= 0; i--) { |
| IterVar to_split = outer; |
| stage.split(to_split, lengths[i].value(), &outer, &inner); |
| outs.push_back(inner); |
| } |
| outs.push_back(outer); |
| } else { |
| IterVar outer, inner = axes[iter_id]; |
| for (size_t i = 0; i < lengths.size(); i++) { |
| IterVar to_split = inner; |
| stage.split_by_nparts(to_split, lengths[i].value(), &outer, &inner); |
| outs.push_back(outer); |
| } |
| outs.push_back(inner); |
| } |
| |
| Array<IterVar> new_axes; |
| new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); |
| if (inner_to_outer) { |
| for (auto x = outs.rbegin(); x != outs.rend(); ++x) { |
| new_axes.push_back((*x)); |
| } |
| } else { |
| for (const auto& x : outs) { |
| new_axes.push_back(x); |
| } |
| } |
| new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); |
| |
| stage_to_axes->Set(stage, std::move(new_axes)); |
| stages->Set(stage_id, std::move(stage)); |
| return outs; |
| } |
| |
| String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, int stage_id, |
| int iter_id, const Array<Optional<Integer>>& lengths, |
| bool inner_to_outer) { |
| const auto& stage = (*stages)[stage_id]; |
| auto to_split = stage_to_axes->at(stage)[iter_id]; |
| const auto& func_name = CleanName(stage->op->name); |
| const auto& outs = |
| ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); |
| CHECK_EQ(outs.size(), lengths.size() + 1); |
| |
| std::stringstream ss; |
| int size = static_cast<int>(lengths.size()); |
| if (inner_to_outer) { |
| for (int i = size - 1; i >= 0; i--) { |
| ss << CleanName(outs[size - i]->var->name_hint, func_name) << ", " |
| << CleanName(outs[size - i - 1]->var->name_hint, func_name) << " = s[" << func_name |
| << "].split(" << CleanName(to_split->var->name_hint, func_name) |
| << ", factor=" << lengths[i] << ")\n"; |
| to_split = outs[size - i]; |
| } |
| } else { |
| for (int i = 0; i < size; i++) { |
| ss << CleanName(outs[i]->var->name_hint, func_name) << ", " |
| << CleanName(outs[i + 1]->var->name_hint, func_name) << " = s[" << func_name << "].split(" |
| << CleanName(to_split->var->name_hint, func_name) << ", nparts=" << lengths[i] << ")\n"; |
| to_split = outs[i + 1]; |
| } |
| } |
| |
| return ss.str(); |
| } |
| |
| SplitStep::SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent, |
| const Array<Optional<Integer>>& lengths, bool inner_to_outer) { |
| auto node = make_object<SplitStepNode>(); |
| node->stage_id = stage_id; |
| // Extent can be a irreducible expression in some special cases |
| if (extent && extent.value()->IsInstance<IntImmNode>()) { |
| node->extent = tvm::Downcast<Integer>(extent.value()); |
| } |
| node->iter_id = iter_id; |
| node->lengths = lengths; |
| node->inner_to_outer = inner_to_outer; |
| data_ = std::move(node); |
| } |
| |
| SplitStep::SplitStep(dmlc::JSONReader* reader) { |
| auto node = make_object<SplitStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->iter_id); |
| int int_val; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&int_val); |
| if (int_val) { |
| node->extent = Integer(int_val); |
| } |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->lengths); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->inner_to_outer); |
| data_ = std::move(node); |
| } |
| |
| void SplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArrayItem(iter_id); |
| writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0); |
| writer->WriteArrayItem(lengths); |
| writer->WriteArrayItem(static_cast<int>(inner_to_outer)); |
| } |
| |
| Array<Iterator> SplitStepNode::ApplyToState(State* state) const { |
| return ApplySplitToState(state, stage_id, iter_id, lengths, inner_to_outer); |
| } |
| |
| Array<IterVar> SplitStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); |
| } |
| |
| String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); |
| } |
| |
| /********** Follow Split **********/ |
| FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) { |
| auto node = make_object<FollowSplitStepNode>(); |
| node->stage_id = stage_id; |
| node->iter_id = iter_id; |
| node->src_step_id = src_step_id; |
| node->n_split = n_split; |
| data_ = std::move(node); |
| } |
| |
| void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArrayItem(iter_id); |
| writer->WriteArrayItem(src_step_id); |
| writer->WriteArrayItem(n_split); |
| } |
| |
| Array<Optional<Integer>> FollowSplitStepNode::ExtractSplitLengths( |
| const Array<Step>& transform_steps) const { |
| // Make sure src_step_id is within the range of transform_steps. |
| CHECK_LT(src_step_id, transform_steps.size()); |
| auto ps = transform_steps[src_step_id].as<SplitStepNode>(); |
| CHECK(ps != nullptr); |
| |
| // Make sure the size of ps->lengths is not smaller than n_split-1. |
| // Note that the number of actual splitting factors of src_step is ps->lengths.size()+1. |
| CHECK_LE(n_split, ps->lengths.size() + 1); |
| CHECK(ps != nullptr); |
| |
| Array<Optional<Integer>> lengths; |
| lengths.reserve(n_split); |
| int j = 0; |
| // Get the first (n_split-1) split factors of followed src_step. |
| for (; j < n_split - 1; ++j) { |
| lengths.push_back(ps->lengths[j]); |
| } |
| |
| // Get the last split factor of src_step for splitting level if n_split is smaller than |
| // ps->lengths.size()+1. |
| PrimExpr last_factor = 1; |
| for (; j < static_cast<int>(ps->lengths.size()); ++j) { |
| if (ps->lengths[j]) { |
| last_factor *= ps->lengths[j].value(); |
| } else { |
| last_factor = PrimExpr(); |
| break; |
| } |
| } |
| if (last_factor.defined()) { |
| lengths.push_back(Downcast<Integer>(last_factor)); |
| } else { |
| lengths.push_back(NullOpt); |
| } |
| |
| return lengths; |
| } |
| |
| FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) { |
| auto node = make_object<FollowSplitStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->iter_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->src_step_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->n_split); |
| data_ = std::move(node); |
| } |
| |
| Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const { |
| return ApplySplitToState(state, stage_id, iter_id, ExtractSplitLengths((*state)->transform_steps), |
| true); |
| } |
| |
| Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes, |
| const Array<Step>& transform_steps) const { |
| return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, |
| ExtractSplitLengths(transform_steps), true); |
| } |
| |
| String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes, |
| const Array<Step>& transform_steps) const { |
| return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, |
| ExtractSplitLengths(transform_steps), true); |
| } |
| |
| /********** Follow Fused Split **********/ |
| FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, |
| const Array<Integer>& src_step_ids, int level, |
| bool factor_or_nparts) { |
| auto node = make_object<FollowFusedSplitStepNode>(); |
| node->stage_id = stage_id; |
| node->iter_id = iter_id; |
| node->src_step_ids = src_step_ids; |
| node->level = level; |
| node->factor_or_nparts = factor_or_nparts; |
| data_ = std::move(node); |
| } |
| |
| FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) { |
| auto node = make_object<FollowFusedSplitStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->iter_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->src_step_ids); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->level); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->factor_or_nparts); |
| data_ = std::move(node); |
| } |
| |
| void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArrayItem(iter_id); |
| writer->WriteArrayItem(src_step_ids); |
| writer->WriteArrayItem(level); |
| writer->WriteArrayItem(static_cast<int>(factor_or_nparts)); |
| } |
| |
| Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength( |
| const Array<Step>& transform_steps) const { |
| PrimExpr ret(1); |
| |
| for (int src_step_id : src_step_ids) { |
| // Make sure the src_step_id is within the range of transform_steps. |
| CHECK_LT(src_step_id, transform_steps.size()); |
| auto ps = transform_steps[src_step_id].as<SplitStepNode>(); |
| CHECK(ps != nullptr); |
| // Multiple the splitting factor on corresponding splitting level of src_steps. |
| if (ps->lengths[level] && ret.defined()) { |
| ret *= ps->lengths[level].value(); |
| } else { |
| return NullOpt; |
| } |
| } |
| return Downcast<Integer>(ret); |
| } |
| |
| Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const { |
| return ApplySplitToState(state, stage_id, iter_id, |
| {ExtractSplitLength((*state)->transform_steps)}, factor_or_nparts); |
| } |
| |
| Array<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes, |
| const Array<Step>& transform_steps) const { |
| return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, |
| {ExtractSplitLength(transform_steps)}, factor_or_nparts); |
| } |
| |
| String FollowFusedSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes, |
| const Array<Step>& transform_steps) const { |
| return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, |
| {ExtractSplitLength(transform_steps)}, factor_or_nparts); |
| } |
| |
| /********** Storage Align **********/ |
| StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, int factor, int offset) { |
| auto node = make_object<StorageAlignStepNode>(); |
| node->stage_id = stage_id; |
| node->iter_id = iter_id; |
| node->factor = factor; |
| node->offset = offset; |
| data_ = std::move(node); |
| } |
| |
| StorageAlignStep::StorageAlignStep(dmlc::JSONReader* reader) { |
| auto node = make_object<StorageAlignStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->iter_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->factor); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->offset); |
| data_ = std::move(node); |
| } |
| |
| void StorageAlignStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArrayItem(iter_id); |
| writer->WriteArrayItem(factor); |
| writer->WriteArrayItem(offset); |
| } |
| |
| void StorageAlignStepNode::ApplyToState(State* state) const { |
| StateNode* pstate = state->CopyOnWrite(); |
| Stage stage = pstate->stages[stage_id]; |
| stage.CopyOnWrite()->attrs.storage_offset = offset; |
| pstate->stages.Set(stage_id, std::move(stage)); |
| } |
| |
| void StorageAlignStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| te::Stage stage = (*stages)[stage_id]; |
| const Array<IterVar>& axes = (*stage_to_axes)[stage]; |
| stage.storage_align(axes[iter_id], factor, offset); |
| stages->Set(stage_id, std::move(stage)); |
| } |
| |
| String StorageAlignStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| std::stringstream ss; |
| const auto& stage = (*stages)[stage_id]; |
| const auto& op_name = CleanName(stage->op->name); |
| ss << "s[" << op_name << "].storage_align(" |
| << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", " << factor |
| << ", " << offset << ")\n"; |
| |
| ApplyToSchedule(stages, stage_to_axes); |
| return ss.str(); |
| } |
| |
| /********** Steps working on multiple stages **********/ |
| |
| /********** Compute At **********/ |
| ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) { |
| auto node = make_object<ComputeAtStepNode>(); |
| node->stage_id = stage_id; |
| node->target_stage_id = target_stage_id; |
| node->target_iter_id = target_iter_id; |
| data_ = std::move(node); |
| } |
| |
| ComputeAtStep::ComputeAtStep(dmlc::JSONReader* reader) { |
| auto node = make_object<ComputeAtStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->target_stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->target_iter_id); |
| data_ = std::move(node); |
| } |
| |
| void ComputeAtStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArrayItem(target_stage_id); |
| writer->WriteArrayItem(target_iter_id); |
| } |
| void ComputeAtStepNode::ApplyToState(State* state) const { |
| const Stage& stage = (*state)->stages[stage_id]; |
| |
| // Remove the bound information of each iterator since they may not be accurate after |
| // compute at |
| Array<Iterator> new_iters; |
| for (const Iterator& it : stage->iters) { |
| new_iters.push_back( |
| Iterator(it->name, Range(), it->iter_kind, it->annotation, &it->orig_iters)); |
| } |
| |
| StateNode* pstate = state->CopyOnWrite(); |
| pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), |
| ComputeAtKind::kIter, stage->attrs)); |
| // Update attach map |
| pstate->attach_map.SetComputeAtIter(stage_id, target_stage_id, target_iter_id); |
| } |
| |
| void ComputeAtStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| te::Stage stage = (*stages)[stage_id]; |
| const auto& target_stage = (*stages)[target_stage_id]; |
| const auto& target_axis = (*stage_to_axes)[target_stage][target_iter_id]; |
| stage.compute_at(target_stage, target_axis); |
| |
| stages->Set(stage_id, std::move(stage)); |
| } |
| |
| String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| std::stringstream ss; |
| const auto& stage = (*stages)[stage_id]; |
| const auto& target_stage = (*stages)[target_stage_id]; |
| const auto& op_name = CleanName(stage->op->name); |
| const auto& target_op_name = CleanName(target_stage->op->name); |
| ss << "s[" << op_name << "].compute_at(s[" << target_op_name << "], " |
| << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint, target_op_name) |
| << ")\n"; |
| ApplyToSchedule(stages, stage_to_axes); |
| return ss.str(); |
| } |
| |
| /********** Compute Inline **********/ |
| ComputeInlineStep::ComputeInlineStep(int stage_id) { |
| auto node = make_object<ComputeInlineStepNode>(); |
| node->stage_id = stage_id; |
| data_ = std::move(node); |
| } |
| |
| ComputeInlineStep::ComputeInlineStep(dmlc::JSONReader* reader) { |
| auto node = make_object<ComputeInlineStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| data_ = std::move(node); |
| } |
| |
| void ComputeInlineStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| } |
| |
| void ComputeInlineStepNode::ApplyToState(State* state) const { |
| const Stage& stage = (*state)->stages[stage_id]; |
| |
| // Check the validity of compute_inline |
| for (size_t i = 0; i < stage->iters.size(); ++i) { |
| CHECK_EQ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, i)), 0) |
| << "Invalid compute_inline: There are some other stages that are attached to the " |
| << "target stage"; |
| } |
| |
| StateNode* pstate = state->CopyOnWrite(); |
| auto new_stage = pstate->stages[stage_id]; |
| new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined; |
| pstate->stages.Set(stage_id, std::move(new_stage)); |
| // Update attach map |
| pstate->attach_map.DeleteStage(stage_id); |
| } |
| |
| void ComputeInlineStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| auto stage = (*stages)[stage_id]; |
| stage.compute_inline(); |
| stages->Set(stage_id, std::move(stage)); |
| } |
| |
| String ComputeInlineStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| std::stringstream ss; |
| const auto& stage = (*stages)[stage_id]; |
| ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n"; |
| ApplyToSchedule(stages, stage_to_axes); |
| return ss.str(); |
| } |
| |
| /********** Compute Root **********/ |
| ComputeRootStep::ComputeRootStep(int stage_id) { |
| auto node = make_object<ComputeRootStepNode>(); |
| node->stage_id = stage_id; |
| data_ = std::move(node); |
| } |
| |
| ComputeRootStep::ComputeRootStep(dmlc::JSONReader* reader) { |
| auto node = make_object<ComputeRootStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| data_ = std::move(node); |
| } |
| |
| void ComputeRootStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| } |
| |
| void ComputeRootStepNode::ApplyToState(State* state) const { |
| const Stage& stage = (*state)->stages[stage_id]; |
| |
| // Remove the bound information of each iterator since they may not be accurate after |
| // compute root |
| Array<Iterator> new_iters; |
| for (const Iterator& it : stage->iters) { |
| new_iters.push_back( |
| Iterator(it->name, Range(), it->iter_kind, it->annotation, &it->orig_iters)); |
| } |
| |
| StateNode* pstate = state->CopyOnWrite(); |
| pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), |
| ComputeAtKind::kRoot, stage->attrs)); |
| // Update attach map |
| pstate->attach_map.DeleteStage(stage_id); |
| } |
| |
| void ComputeRootStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| auto stage = (*stages)[stage_id]; |
| stage.compute_root(); |
| stages->Set(stage_id, std::move(stage)); |
| } |
| |
| String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes) const { |
| std::stringstream ss; |
| const auto& stage = (*stages)[stage_id]; |
| ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n"; |
| ApplyToSchedule(stages, stage_to_axes); |
| return ss.str(); |
| } |
| |
| /********** Steps adding new stages **********/ |
| |
| /*! |
| * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep, |
| * RfactorStep). This will return all steps that can change the number of stages in a ComputeDAG, |
| * and stop by the current step. |
| */ |
| Array<Step> GetFormerStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) { |
| Array<Step> ret_steps; |
| for (size_t i = 0; i < transform_steps.size(); ++i) { |
| const Step& step = transform_steps[i]; |
| if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) { |
| ret_steps.push_back(step); |
| } else if (step->IsInstance<RfactorStepNode>()) { |
| // add FuseStepNode required by rfactor |
| if (i >= 2 && transform_steps[i - 2]->IsInstance<FuseStepNode>()) { |
| const Step& fuse_step = transform_steps[i - 2]; |
| if (fuse_step->stage_id == step->stage_id) { |
| ret_steps.push_back(fuse_step); |
| } |
| } |
| // add SplitStepNode required by rfactor |
| CHECK_GE(i, 1); |
| CHECK(transform_steps[i - 1]->IsInstance<SplitStepNode>()); |
| const Step& split_step = transform_steps[i - 1]; |
| CHECK_EQ(split_step->stage_id, step->stage_id); |
| ret_steps.push_back(split_step); |
| // add RfactorStepNode |
| ret_steps.push_back(step); |
| } |
| // A state may have multiple stage modifiable steps, stop by the current step to avoid |
| // replaying excess steps |
| if (step.same_as(current_step)) { |
| break; |
| } |
| } |
| return ret_steps; |
| } |
| |
| /********** Cache Read **********/ |
| CacheReadStep::CacheReadStep(int stage_id, String scope_name, |
| const Array<Integer>& reader_stage_ids) { |
| auto node = make_object<CacheReadStepNode>(); |
| node->stage_id = stage_id; |
| node->scope_name = std::move(scope_name); |
| node->reader_stage_ids = reader_stage_ids; |
| data_ = std::move(node); |
| } |
| |
| CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) { |
| auto node = make_object<CacheReadStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| std::string string_value; |
| reader->Read(&string_value); |
| node->scope_name = std::move(string_value); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->reader_stage_ids); |
| data_ = std::move(node); |
| } |
| |
| void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArraySeperator(); |
| writer->WriteString(scope_name); |
| writer->WriteArrayItem(reader_stage_ids); |
| } |
| |
| int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { |
| StateNode* pstate = state->CopyOnWrite(); |
| const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG( |
| GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps)); |
| |
| // target_stage -> target_stage + target_store |
| // Update the op of the target stage, insert a new cache read stage behind, update the op of |
| // later stages, then update the stage_id mapping in AttachMap |
| int added_stage_id = stage_id + 1; |
| Stage tmp_stage = pstate->stages[stage_id]; |
| tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id]; |
| pstate->stages.Set(stage_id, std::move(tmp_stage)); |
| pstate->stages.insert(pstate->stages.begin() + added_stage_id, |
| Stage(current_compute_dag->ops[added_stage_id])); |
| for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) { |
| tmp_stage = pstate->stages[i]; |
| tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i]; |
| pstate->stages.Set(i, std::move(tmp_stage)); |
| } |
| pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id); |
| pstate->current_compute_dag = std::move(current_compute_dag); |
| |
| return added_stage_id; |
| } |
| |
| te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes, |
| te::Schedule* schedule) const { |
| const te::Stage& stage = (*stages)[stage_id]; |
| Array<te::Operation> readers; |
| for (const auto& i : reader_stage_ids) { |
| readers.push_back((*stages)[i]->origin_op); |
| } |
| auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers); |
| |
| const auto& new_stage = (*schedule)[out->op]; |
| UpdateStageToAxesMap(new_stage, stage_to_axes); |
| stages->insert(stages->begin() + stage_id + 1, new_stage); |
| |
| return out; |
| } |
| |
| String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
| te::Schedule* schedule) const { |
| std::stringstream ss; |
| // Since the original stage will be changed after schedule apply, keep a copy here |
| // These information will be used to print Python API string later |
| auto stage = (*stages)[stage_id]; |
| Array<te::Stage> reader_stages; |
| for (size_t i = 0; i < reader_stage_ids.size(); ++i) { |
| reader_stages.push_back((*stages)[reader_stage_ids[i]]); |
| } |
| auto out = ApplyToSchedule(stages, stage_to_axes, schedule); |
| |
| const auto& op_name = CleanName(out->op->name); |
| ss << op_name << " = " |
| << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", [" |
| << CleanName(reader_stages[0]->op->name); |
| for (size_t i = 1; i < reader_stage_ids.size(); ++i) { |
| ss << ", " << CleanName(reader_stages[i]->op->name); |
| } |
| ss << "])\n"; |
| |
| // Print the iterators of the new added stage |
| const auto& iters = out->op->root_iter_vars(); |
| for (size_t i = 0; i < iters.size(); ++i) { |
| ss << CleanName(iters[i]->var->name_hint, op_name); |
| if (i != iters.size() - 1) { |
| ss << ", "; |
| } |
| } |
| ss << " = " |
| << "tuple(" << op_name << ".op.axis)\n"; |
| |
| return ss.str(); |
| } |
| |
| /********** Cache Write **********/ |
| CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) { |
| auto node = make_object<CacheWriteStepNode>(); |
| node->stage_id = stage_id; |
| node->scope_name = std::move(scope_name); |
| data_ = std::move(node); |
| } |
| |
| CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) { |
| auto node = make_object<CacheWriteStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| std::string string_value; |
| reader->Read(&string_value); |
| node->scope_name = std::move(string_value); |
| data_ = std::move(node); |
| } |
| |
| void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArraySeperator(); |
| writer->WriteString(scope_name); |
| } |
| |
| int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { |
| StateNode* pstate = state->CopyOnWrite(); |
| int last_dag_op_size = pstate->current_compute_dag |
| ? pstate->current_compute_dag.value().as<ComputeDAGNode>()->ops.size() |
| : dag->ops.size(); |
| const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG( |
| GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps)); |
| int added_ops = current_compute_dag->ops.size() - last_dag_op_size; |
| // TODO(jcf94): Update this check to equal after fixing the cache write bug in TVM |
| CHECK_GE(added_ops, 1); |
| |
| // target_stage -> cache_write_stage + target_stage |
| // Assume no step has been applied to the target stage before cache write. |
| // Insert a new cache write stage ahead, update the op of the target stage and later stages, then |
| // update the stage_id mapping in AttachMap |
| pstate->stages.insert(pstate->stages.begin() + stage_id, |
| Stage(current_compute_dag->ops[stage_id])); |
| pstate->stages.Set(stage_id + 1, Stage(current_compute_dag->ops[stage_id + 1])); |
| int next_stage_id = stage_id + 2; |
| // TODO(jc94): Fix the cache write bug in TVM and remove added_op == 2 support. |
| // TVM's cache_write has a bug with multi outputs. See |
| // `tests/python/unittest/test_auto_scheduler_loop_state.py::test_cache_read_write` test |
| // for more details |
| if (added_ops == 2) { |
| pstate->stages.insert(pstate->stages.begin() + next_stage_id, |
| Stage(current_compute_dag->ops[next_stage_id])); |
| next_stage_id++; |
| } else if (added_ops > 2) { |
| LOG(ERROR) << "Unexpected behavior of CacheWrite."; |
| } |
| for (size_t i = next_stage_id; i < current_compute_dag->ops.size(); ++i) { |
| Stage tmp_stage = pstate->stages[i]; |
| tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i]; |
| pstate->stages.Set(i, std::move(tmp_stage)); |
| } |
| pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id, added_ops); |
| pstate->current_compute_dag = std::move(current_compute_dag); |
| |
| return stage_id; |
| } |
| |
| Array<te::Tensor> CacheWriteStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes, |
| te::Schedule* schedule) const { |
| const te::Stage& stage = (*stages)[stage_id]; |
| Array<te::Tensor> tensor_array; |
| // If the target stage has multi outputs, TVM requires to cache_write |
| // all of them or schedule.cache_write will raise an error |
| for (auto i = 0; i < stage->op->num_outputs(); ++i) { |
| tensor_array.push_back(stage->origin_op.output(i)); |
| } |
| auto outs = schedule->cache_write(tensor_array, scope_name); |
| |
| UpdateStageToAxesMap(stage, stage_to_axes); |
| // Even if there is multi outputs, TVM schedule only generate one |
| // new stage |
| const auto& new_stage = (*schedule)[outs[0]->op]; |
| UpdateStageToAxesMap(new_stage, stage_to_axes); |
| stages->insert(stages->begin() + stage_id, new_stage); |
| |
| return outs; |
| } |
| |
| String CacheWriteStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
| te::Schedule* schedule) const { |
| std::stringstream ss; |
| // Since the original stage will be changed after schedule apply, keep a copy here |
| // These information will be used to print Python API string later |
| te::Stage stage = (*stages)[stage_id]; |
| auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); |
| |
| for (size_t i = 0; i < outs.size(); ++i) { |
| ss << CleanName(outs[i]->op->name) << ", "; |
| } |
| ss << "= " |
| << "s.cache_write([" << CleanName(stage->op.output(0)->op->name); |
| for (auto i = 1; i < stage->op->num_outputs(); ++i) { |
| ss << ", " << CleanName(stage->op.output(i)->op->name); |
| } |
| ss << "], \"" << scope_name << "\")\n"; |
| |
| // Print the iterators of the new added stage |
| for (const auto& out : outs) { |
| const auto& iters = out->op->root_iter_vars(); |
| const auto& op_name = CleanName(out->op->name); |
| for (size_t i = 0; i < iters.size(); ++i) { |
| ss << CleanName(iters[i]->var->name_hint, op_name); |
| if (i != iters.size() - 1) { |
| ss << ", "; |
| } |
| } |
| ss << " = " |
| << "tuple(" << op_name << ".op.axis)" |
| << " + " |
| << "tuple(" << op_name << ".op.reduce_axis)\n"; |
| } |
| |
| return ss.str(); |
| } |
| |
| /********** Rfactor **********/ |
| RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) { |
| auto node = make_object<RfactorStepNode>(); |
| node->stage_id = stage_id; |
| node->iter_id = iter_id; |
| node->factor_iter_id = factor_iter_id; |
| data_ = std::move(node); |
| } |
| |
| RfactorStep::RfactorStep(dmlc::JSONReader* reader) { |
| auto node = make_object<RfactorStepNode>(); |
| bool s; |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->stage_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->iter_id); |
| s = reader->NextArrayItem(); |
| CHECK(s); |
| reader->Read(&node->factor_iter_id); |
| data_ = std::move(node); |
| } |
| |
| void RfactorStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
| writer->WriteArraySeperator(); |
| writer->WriteString(record_prefix_str); |
| writer->WriteArrayItem(stage_id); |
| writer->WriteArrayItem(iter_id); |
| writer->WriteArrayItem(factor_iter_id); |
| } |
| |
| int RfactorStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { |
| StateNode* pstate = state->CopyOnWrite(); |
| const auto& compute_at_type = pstate->stages[stage_id]->compute_at; |
| const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG( |
| GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps)); |
| |
| // target_stage -> rfactor_compute + target_stage |
| // Insert a new compute stage, update the target stage and later stage, then update the stage_id |
| // mapping in AttachMap |
| pstate->stages.insert(pstate->stages.begin() + stage_id, |
| Stage(current_compute_dag->ops[stage_id])); |
| // Maintain the compute_at type of the target stage |
| Stage target_stage = Stage(current_compute_dag->ops[stage_id + 1]); |
| target_stage.CopyOnWrite()->compute_at = compute_at_type; |
| pstate->stages.Set(stage_id + 1, std::move(target_stage)); |
| for (size_t i = stage_id + 2; i < pstate->stages.size(); ++i) { |
| Stage stage = pstate->stages[i]; |
| stage.CopyOnWrite()->op = current_compute_dag->ops[i]; |
| pstate->stages.Set(i, std::move(stage)); |
| } |
| pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id); |
| pstate->current_compute_dag = std::move(current_compute_dag); |
| |
| return stage_id; |
| } |
| |
| Array<te::Tensor> RfactorStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
| StageToAxesMap* stage_to_axes, |
| te::Schedule* schedule) const { |
| const auto& stage = (*stages)[stage_id]; |
| const Array<IterVar>& axes = (*stage_to_axes)[stage]; |
| |
| const te::Tensor& tensor = stage->origin_op.output(0); |
| const IterVar& axis = axes[iter_id]; |
| auto outs = schedule->rfactor(tensor, axis, factor_iter_id); |
| |
| UpdateStageToAxesMap(stage, stage_to_axes); |
| const auto& new_stage = (*schedule)[outs[0]->op]; |
| UpdateStageToAxesMap(new_stage, stage_to_axes); |
| stages->insert(stages->begin() + stage_id, new_stage); |
| |
| return outs; |
| } |
| |
| String RfactorStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
| te::Schedule* schedule) const { |
| std::stringstream ss; |
| const auto& stage = (*stages)[stage_id]; |
| |
| const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name); |
| const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); |
| |
| const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); |
| |
| for (size_t i = 0; i < outs.size(); ++i) { |
| ss << CleanName(outs[i]->op->name); |
| if (i != outs.size() - 1) { |
| ss << ", "; |
| } |
| } |
| ss << " = " |
| << "s.rfactor(" << tensor_name << ", " << axis_name << ", " << factor_iter_id << ")\n"; |
| |
| for (const auto& out : outs) { |
| const auto& iters = out->op->root_iter_vars(); |
| const auto& op_name = CleanName(out->op->name); |
| for (size_t i = 0; i < iters.size(); ++i) { |
| ss << CleanName(iters[i]->var->name_hint, op_name); |
| if (i != iters.size() - 1) { |
| ss << ", "; |
| } |
| } |
| ss << " = " |
| << "tuple(" << op_name << ".op.axis)" |
| << " + " |
| << "tuple(" << op_name << ".op.reduce_axis)\n"; |
| } |
| |
| const auto& output = (*stages)[stage_id + 1]->op.output(0); |
| const auto& iters = output->op->root_iter_vars(); |
| const auto& op_name = CleanName(output->op->name); |
| for (size_t i = 0; i < iters.size(); ++i) { |
| ss << CleanName(iters[i]->var->name_hint, op_name); |
| if (i != iters.size() - 1) { |
| ss << ", "; |
| } |
| } |
| ss << " = " |
| << "tuple(s[" << op_name << "].op.axis)" |
| << " + " |
| << "tuple(s[" << op_name << "].op.reduce_axis)\n"; |
| |
| return ss.str(); |
| } |
| |
| } // namespace auto_scheduler |
| } // namespace tvm |