| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| /*! |
| * \file stmt_functor.cc |
| */ |
| #include <tvm/ffi/function.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/ir/module.h> |
| #include <tvm/tir/data_type_rewriter.h> |
| #include <tvm/tir/function.h> |
| #include <tvm/tir/stmt_functor.h> |
| |
| #include <functional> |
| |
| #include "functor_common.h" |
| |
| namespace tvm { |
| namespace tir { |
| |
| void StmtVisitor::VisitStmt_(const LetStmtNode* op) { |
| this->VisitExpr(op->value); |
| this->VisitStmt(op->body); |
| } |
| |
| void StmtVisitor::VisitStmt_(const AttrStmtNode* op) { |
| this->VisitExpr(op->value); |
| this->VisitStmt(op->body); |
| } |
| |
| void StmtVisitor::VisitStmt_(const ForNode* op) { |
| this->VisitExpr(op->min); |
| this->VisitExpr(op->extent); |
| this->VisitStmt(op->body); |
| } |
| |
| void StmtVisitor::VisitStmt_(const WhileNode* op) { |
| this->VisitExpr(op->condition); |
| this->VisitStmt(op->body); |
| } |
| |
| void StmtVisitor::VisitStmt_(const AllocateNode* op) { |
| VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
| this->VisitStmt(op->body); |
| this->VisitExpr(op->condition); |
| } |
| |
| void StmtVisitor::VisitStmt_(const AllocateConstNode* op) { |
| VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
| this->VisitStmt(op->body); |
| } |
| |
| void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { this->VisitStmt(op->body); } |
| |
| void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { |
| this->VisitExpr(op->value); |
| VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
| } |
| |
| void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) { |
| VisitArray(op->bounds, [this](const Range& r) { |
| this->VisitExpr(r->min); |
| this->VisitExpr(r->extent); |
| }); |
| this->VisitExpr(op->condition); |
| this->VisitStmt(op->body); |
| } |
| |
| void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { |
| this->VisitExpr(op->condition); |
| this->VisitStmt(op->then_case); |
| if (op->else_case) { |
| this->VisitStmt(op->else_case.value()); |
| } |
| } |
| |
| void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { |
| this->VisitExpr(op->condition); |
| this->VisitExpr(op->message); |
| this->VisitStmt(op->body); |
| } |
| |
| void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { |
| VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); |
| } |
| |
| void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); } |
| |
| void StmtVisitor::VisitStmt_(const BlockNode* op) { |
| auto fvisit_buffer_region = [this](const BufferRegion& s) { |
| for (const auto& range : s->region) { |
| this->VisitExpr(range->min); |
| this->VisitExpr(range->extent); |
| } |
| }; |
| VisitArray(op->iter_vars, [this](const IterVar& iter_var) { |
| this->VisitExpr(iter_var->dom->min); |
| this->VisitExpr(iter_var->dom->extent); |
| }); |
| VisitArray(op->reads, fvisit_buffer_region); |
| VisitArray(op->writes, fvisit_buffer_region); |
| VisitArray(op->match_buffers, |
| [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { |
| fvisit_buffer_region(match_buffer_region->source); |
| }); |
| if (op->init.defined()) { |
| this->VisitStmt(op->init.value()); |
| } |
| this->VisitStmt(op->body); |
| } |
| |
| void StmtVisitor::VisitStmt_(const BlockRealizeNode* op) { |
| VisitArray(op->iter_values, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
| this->VisitExpr(op->predicate); |
| this->VisitStmt(op->block); |
| } |
| |
| class StmtMutator::Internal { |
| public: |
| /*! |
| * \brief Mutate array's element by fmutate function. |
| * |
| * \note Use extra care for copy on write setting. |
| * |
| * In particular, consider the following case of two reference chains: |
| * - strongref0 -> loop0 -> loop1 -> loop2 |
| * - strongref1 -> loop3 -> loop1 -> loop2 |
| * |
| * Think of the case of calling MutateArray on loop1->loop2(as const reference). |
| * When both strongref0 and strongref1 exists, the context does not allow copy |
| * on write, even though loop1 uniquely refers to loop2. |
| * |
| * \param self The pointer to the mutator. |
| * \param arr Array to be mutated, const reference is used to allow copy on write |
| * mutation in a recursive visitor. |
| * \param fmutate The mutator function. |
| * \return The mutated array, a new copy can be created. |
| */ |
| template <typename T, typename F> |
| static ffi::Array<T> MutateArray(StmtMutator* self, const ffi::Array<T>& arr, F fmutate) { |
| if (self->allow_copy_on_write_ && arr.unique()) { |
| // if we allow copy on write, we can directly |
| // call the inplace mutate function. |
| const_cast<ffi::Array<T>&>(arr).MutateByApply(fmutate); |
| return arr; |
| } else { |
| bool allow_cow = false; |
| std::swap(allow_cow, self->allow_copy_on_write_); |
| ffi::Array<T> copy = arr.Map(fmutate); |
| std::swap(allow_cow, self->allow_copy_on_write_); |
| return copy; |
| } |
| } |
| |
| static ffi::Array<IterVar> Mutate(StmtMutator* self, const ffi::Array<IterVar>& arr) { |
| auto fmutate = [self](const IterVar& iter_var) { |
| PrimExpr min = self->VisitExpr(iter_var->dom->min); |
| PrimExpr extent = self->VisitExpr(iter_var->dom->extent); |
| if (min.same_as(iter_var->dom->min) && extent.same_as(iter_var->dom->extent)) { |
| return iter_var; |
| } else { |
| return IterVar(Range(min, extent), iter_var->var, iter_var->iter_type, |
| iter_var->thread_tag); |
| } |
| }; |
| return MutateArray(self, arr, fmutate); |
| } |
| |
| static ffi::Array<PrimExpr> Mutate(StmtMutator* self, const ffi::Array<PrimExpr>& arr) { |
| auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); }; |
| return MutateArray(self, arr, fmutate); |
| } |
| |
| static ffi::Array<Stmt> Mutate(StmtMutator* self, const ffi::Array<Stmt>& arr) { |
| auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); }; |
| return MutateArray(self, arr, fmutate); |
| } |
| |
| static ffi::Array<Range> Mutate(StmtMutator* self, const ffi::Array<Range>& arr) { |
| auto fmutate = [self](const Range& r) { |
| PrimExpr min = self->VisitExpr(r->min); |
| PrimExpr extent = self->VisitExpr(r->extent); |
| if (min.same_as(r->min) && extent.same_as(r->extent)) { |
| return r; |
| } else { |
| return Range::FromMinExtent(min, extent); |
| } |
| }; |
| return MutateArray(self, arr, fmutate); |
| } |
| |
| static ffi::Array<BufferRegion> Mutate(StmtMutator* self, const ffi::Array<BufferRegion>& arr) { |
| auto fmutate = [self](const BufferRegion& buffer_region) { |
| ffi::Array<Range> region = Mutate(self, buffer_region->region); |
| if (region.same_as(buffer_region->region)) { |
| return buffer_region; |
| } else { |
| return BufferRegion(buffer_region->buffer, region); |
| } |
| }; |
| return MutateArray(self, arr, fmutate); |
| } |
| |
| static ffi::Array<MatchBufferRegion> Mutate(StmtMutator* self, |
| const ffi::Array<MatchBufferRegion>& arr) { |
| auto fmutate = [self](const MatchBufferRegion& match_buffer_region) { |
| ffi::Array<Range> region = Mutate(self, match_buffer_region->source->region); |
| if (region.same_as(match_buffer_region->source->region)) { |
| return match_buffer_region; |
| } else { |
| return MatchBufferRegion(match_buffer_region->buffer, |
| BufferRegion(match_buffer_region->source->buffer, region)); |
| } |
| }; |
| return MutateArray(self, arr, fmutate); |
| } |
| }; |
| |
| Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { |
| PrimExpr value = this->VisitExpr(op->value); |
| Stmt body = this->VisitStmt(op->body); |
| if (value.same_as(op->value) && body.same_as(op->body)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->value = std::move(value); |
| n->body = std::move(body); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { |
| PrimExpr value = this->VisitExpr(op->value); |
| Stmt body = this->VisitStmt(op->body); |
| if (value.same_as(op->value) && body.same_as(op->body)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->value = std::move(value); |
| n->body = std::move(body); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const ForNode* op) { |
| PrimExpr min = this->VisitExpr(op->min); |
| PrimExpr extent = this->VisitExpr(op->extent); |
| Stmt body = this->VisitStmt(op->body); |
| if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->min = std::move(min); |
| n->extent = std::move(extent); |
| n->body = std::move(body); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const WhileNode* op) { |
| PrimExpr condition = this->VisitExpr(op->condition); |
| Stmt body = this->VisitStmt(op->body); |
| if (condition.same_as(op->condition) && body.same_as(op->body)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->condition = std::move(condition); |
| n->body = std::move(body); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { |
| ffi::Array<PrimExpr> extents = Internal::Mutate(this, op->extents); |
| Stmt body = this->VisitStmt(op->body); |
| PrimExpr condition = this->VisitExpr(op->condition); |
| |
| if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->extents = std::move(extents); |
| n->body = std::move(body); |
| n->condition = std::move(condition); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) { |
| ffi::Array<PrimExpr> extents = Internal::Mutate(this, op->extents); |
| Stmt body = this->VisitStmt(op->body); |
| |
| if (extents.same_as(op->extents) && body.same_as(op->body)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->extents = std::move(extents); |
| n->body = std::move(body); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { |
| Stmt body = this->VisitStmt(op->body); |
| |
| if (body.same_as(op->body)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->body = std::move(body); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { |
| PrimExpr condition = this->VisitExpr(op->condition); |
| Stmt then_case = this->VisitStmt(op->then_case); |
| ffi::Optional<Stmt> else_case = std::nullopt; |
| if (op->else_case) { |
| else_case = this->VisitStmt(op->else_case.value()); |
| } |
| if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && |
| else_case.same_as(op->else_case)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->condition = std::move(condition); |
| n->then_case = std::move(then_case); |
| n->else_case = std::move(else_case); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { |
| PrimExpr value = this->VisitExpr(op->value); |
| ffi::Array<PrimExpr> indices = Internal::Mutate(this, op->indices); |
| |
| if (value.same_as(op->value) && indices.same_as(op->indices)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->value = std::move(value); |
| n->indices = std::move(indices); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { |
| Region bounds = Internal::Mutate(this, op->bounds); |
| PrimExpr condition = this->VisitExpr(op->condition); |
| Stmt body = this->VisitStmt(op->body); |
| |
| if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->bounds = std::move(bounds); |
| n->condition = std::move(condition); |
| n->body = std::move(body); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { |
| ffi::Array<Stmt> seq = Internal::Mutate(this, op->seq); |
| if (seq.same_as(op->seq)) { |
| return SeqStmt::Flatten(ffi::GetRef<Stmt>(op)); |
| } else { |
| auto node = CopyOnWrite(op); |
| node->seq = std::move(seq); |
| return SeqStmt::Flatten(SeqStmt(node)); |
| } |
| } |
| |
| // advanced visit function for seqstmt. |
| Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit, |
| std::function<Stmt(const Stmt&)> fmutate) { |
| if (flatten_before_visit) { |
| // Pass 1, check if we need to flatten. |
| bool need_flatten = false; |
| for (size_t i = 0; i < op->seq.size(); ++i) { |
| Stmt tmp = (*op)[i]; |
| if (tmp.as<SeqStmtNode>()) need_flatten = true; |
| } |
| flatten_before_visit = need_flatten; |
| } |
| // function to run the visit. |
| auto frunvisit = [&](const SeqStmtNode* op) { |
| ffi::Array<Stmt> seq = fmutate != nullptr ? Internal::MutateArray(this, op->seq, fmutate) |
| : Internal::Mutate(this, op->seq); |
| if (seq.same_as(op->seq)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->seq = std::move(seq); |
| return Stmt(n); |
| } |
| }; |
| if (flatten_before_visit) { |
| ffi::Array<Stmt> seq; |
| SeqStmt::Flattener flattener(&seq); |
| flattener(0, op->seq); |
| // NOTE: If copy on write is allowed |
| // the assignment to seq below will |
| // destruct the original seq. |
| // |
| // Such destruction removes duplicated reference |
| // count to children and still enables COW for |
| // child Stmt. |
| ObjectPtr<SeqStmtNode> n = CopyOnWrite(op); |
| n->seq = std::move(seq); |
| return frunvisit(n.operator->()); |
| } else { |
| return frunvisit(op); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { |
| PrimExpr condition = this->VisitExpr(op->condition); |
| PrimExpr message = this->VisitExpr(op->message); |
| Stmt body = this->VisitStmt(op->body); |
| |
| if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->condition = std::move(condition); |
| n->message = std::move(message); |
| n->body = std::move(body); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { |
| PrimExpr value = this->VisitExpr(op->value); |
| if (value.same_as(op->value)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->value = std::move(value); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const BlockNode* op) { |
| ffi::Array<IterVar> iter_vars = Internal::Mutate(this, op->iter_vars); |
| ffi::Array<BufferRegion> reads = Internal::Mutate(this, op->reads); |
| ffi::Array<BufferRegion> writes = Internal::Mutate(this, op->writes); |
| ffi::Array<MatchBufferRegion> match_buffers = Internal::Mutate(this, op->match_buffers); |
| ffi::Optional<Stmt> init = std::nullopt; |
| if (op->init.defined()) { |
| init = VisitStmt(op->init.value()); |
| } |
| Stmt body = VisitStmt(op->body); |
| if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) && writes.same_as(op->writes) && |
| body.same_as(op->body) && init.same_as(op->init) && |
| match_buffers.same_as(op->match_buffers)) { |
| return ffi::GetRef<Block>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->iter_vars = std::move(iter_vars); |
| n->reads = std::move(reads); |
| n->writes = std::move(writes); |
| n->body = std::move(body); |
| n->init = std::move(init); |
| n->match_buffers = std::move(match_buffers); |
| return Stmt(n); |
| } |
| } |
| |
| Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) { |
| ffi::Array<PrimExpr> v = Internal::Mutate(this, op->iter_values); |
| PrimExpr pred = this->VisitExpr(op->predicate); |
| Stmt block = this->VisitStmt(op->block); |
| if (v.same_as(op->iter_values) && pred.same_as(op->predicate) && block.same_as(op->block)) { |
| return ffi::GetRef<Stmt>(op); |
| } else { |
| auto n = CopyOnWrite(op); |
| n->iter_values = std::move(v); |
| n->predicate = std::move(pred); |
| n->block = Downcast<Block>(block); |
| return Stmt(n); |
| } |
| } |
| |
| // Implementations of IRTransform, PostOrderVisit and Substitute |
| class IRApplyVisit : public StmtExprVisitor { |
| public: |
| explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {} |
| |
| void VisitExpr(const PrimExpr& node) final { |
| if (visited_.count(node.get()) != 0) return; |
| visited_.insert(node.get()); |
| ExprVisitor::VisitExpr(node); |
| f_(node); |
| } |
| |
| void VisitStmt(const Stmt& node) final { |
| if (visited_.count(node.get()) != 0) return; |
| visited_.insert(node.get()); |
| StmtVisitor::VisitStmt(node); |
| f_(node); |
| } |
| |
| private: |
| std::function<void(const ObjectRef&)> f_; |
| std::unordered_set<const Object*> visited_; |
| }; |
| |
| void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit) { |
| if (node.as<StmtNode>()) { |
| IRApplyVisit visitor(fvisit); |
| visitor(Downcast<Stmt>(node)); |
| } else { |
| IRApplyVisit visitor(fvisit); |
| visitor(Downcast<PrimExpr>(node)); |
| } |
| } |
| |
| class IRTransformer final : public StmtExprMutator { |
| public: |
| IRTransformer(const ffi::Function& f_preorder, const ffi::Function& f_postorder, |
| const std::unordered_set<uint32_t>& only_enable) |
| : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {} |
| |
| Stmt VisitStmt(const Stmt& stmt) final { |
| return MutateInternal<Stmt>(stmt, [this](const Stmt& s) { return this->BaseVisitStmt(s); }); |
| } |
| PrimExpr VisitExpr(const PrimExpr& expr) final { |
| return MutateInternal<PrimExpr>(expr, |
| [this](const PrimExpr& e) { return this->BaseVisitExpr(e); }); |
| } |
| |
| private: |
| // NOTE: redirect to parent's call |
| // This is used to get around limitation of gcc-4.8 |
| Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); } |
| PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); } |
| |
| template <typename T, typename F> |
| T MutateInternal(const T& node, F fmutate) { |
| if (only_enable_.size() && !only_enable_.count(node->type_index())) { |
| return fmutate(node); |
| } |
| if (f_preorder_ != nullptr) { |
| T pre = f_preorder_(node).template cast<T>(); |
| if (pre.defined()) return pre; |
| } |
| T new_node = fmutate(node); |
| if (f_postorder_ != nullptr) { |
| T post = f_postorder_(new_node).template cast<T>(); |
| if (post.defined()) return post; |
| } |
| return new_node; |
| } |
| // The functions |
| const ffi::Function& f_preorder_; |
| const ffi::Function& f_postorder_; |
| // type indices enabled. |
| const std::unordered_set<uint32_t>& only_enable_; |
| }; |
| |
| Stmt IRTransform(Stmt ir_node, const ffi::Function& f_preorder, const ffi::Function& f_postorder, |
| ffi::Optional<ffi::Array<ffi::String>> only_enable) { |
| std::unordered_set<uint32_t> only_type_index; |
| if (only_enable.defined()) { |
| for (auto s : only_enable.value()) { |
| only_type_index.insert(ffi::TypeKeyToIndex(s.c_str())); |
| } |
| } |
| IRTransformer transform(f_preorder, f_postorder, only_type_index); |
| return transform(std::move(ir_node)); |
| } |
| |
| class IRSubstitute : public StmtExprMutator { |
| public: |
| explicit IRSubstitute(std::function<ffi::Optional<PrimExpr>(const Var&)> vmap) : vmap_(vmap) {} |
| |
| PrimExpr VisitExpr_(const VarNode* op) final { |
| Var var = ffi::GetRef<Var>(op); |
| auto ret = vmap_(var); |
| if (ret.defined()) { |
| // Allow substitution of void variables with any expression. The TVM script parser |
| // uses void variables for lambda parameters (since exact types are not known yet). |
| if (!var.dtype().is_void()) { |
| PrimExpr ret_ex = Downcast<PrimExpr>(ret.value()); |
| ICHECK(ret_ex.dtype() == var.dtype()) << "substituting " << var << ":" << var.dtype() |
| << " -> " << ret_ex << ":" << ret_ex.dtype(); |
| } |
| return ret.value(); |
| } |
| return var; |
| } |
| |
| PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
| auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
| return VisitBufferAccess(std::move(node)); |
| } |
| |
| Stmt VisitStmt_(const BufferStoreNode* op) final { |
| auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
| return VisitBufferAccess(std::move(node)); |
| } |
| |
| Stmt VisitStmt_(const DeclBufferNode* op) final { |
| auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op)); |
| return VisitBufferAccess(std::move(node)); |
| } |
| |
| template <typename Node> |
| Node VisitBufferAccess(Node node) { |
| Buffer new_buf = GetRemappedBuffer(node->buffer); |
| |
| if (!new_buf.same_as(node->buffer)) { |
| auto writer = node.CopyOnWrite(); |
| writer->buffer = new_buf; |
| } |
| |
| return node; |
| } |
| |
| Buffer GetRemappedBuffer(Buffer buf) { |
| auto key = buf.get(); |
| auto it = buf_remap_.find(key); |
| if (it != buf_remap_.end()) { |
| return it->second; |
| } |
| |
| PrimExpr new_buffer_var_expr = VisitExpr(buf->data); |
| CHECK(new_buffer_var_expr->IsInstance<VarNode>()) |
| << "Buffer " << buf << " uses backing allocation " << buf->data |
| << ", which was substituted into the expression " << new_buffer_var_expr << ". " |
| << "However, this expression is of type " << new_buffer_var_expr->GetTypeKey() |
| << " and the backing allocation must be a tir::Var"; |
| |
| Var buffer_var = Downcast<Var>(new_buffer_var_expr); |
| auto elem_offset = VisitExpr(buf->elem_offset); |
| auto shape = buf->shape.Map([this](const auto& expr) { return VisitExpr(expr); }); |
| auto strides = buf->strides.Map([this](const auto& expr) { return VisitExpr(expr); }); |
| |
| if (!buffer_var.same_as(buf->data) || !elem_offset.same_as(buf->elem_offset) || |
| !shape.same_as(buf->shape) || !strides.same_as(buf->strides)) { |
| auto writer = buf.CopyOnWrite(); |
| writer->data = buffer_var; |
| writer->elem_offset = elem_offset; |
| writer->shape = shape; |
| writer->strides = strides; |
| } |
| |
| buf_remap_[key] = buf; |
| return buf; |
| } |
| |
| Stmt VisitStmt_(const AttrStmtNode* op) final { |
| Stmt ret = StmtExprMutator::VisitStmt_(op); |
| op = ret.as<AttrStmtNode>(); |
| // remap var node in attr |
| if (auto var_node = op->node.as<Var>()) { |
| if (auto mapped_var = vmap_(var_node.value())) { |
| return AttrStmt(mapped_var, op->attr_key, op->value, op->body); |
| } |
| } |
| return ret; |
| } |
| |
| private: |
| // Caller provided function that defines the variables to be remapped. |
| std::function<ffi::Optional<PrimExpr>(const Var&)> vmap_; |
| |
| /* \brief Generated map to track buffers being remapped. |
| * |
| * If a `Var BufferNode::data` is remapped, then all buffers |
| * containing that data pointer should also be remapped. This map |
| * is used to track buffer modifications, and ensure all instances |
| * of a buffer are replaced by the same modified buffer object. |
| */ |
| std::unordered_map<const BufferNode*, Buffer> buf_remap_; |
| }; |
| |
| Stmt Substitute(Stmt stmt, std::function<ffi::Optional<PrimExpr>(const Var&)> vmap) { |
| return IRSubstitute(vmap)(std::move(stmt)); |
| } |
| |
| PrimExpr Substitute(PrimExpr expr, std::function<ffi::Optional<PrimExpr>(const Var&)> vmap) { |
| return IRSubstitute(vmap)(std::move(expr)); |
| } |
| |
| void PreOrderVisit(const ObjectRef& stmt_or_expr, |
| const std::function<bool(const ObjectRef&)>& fvisit) { |
| class PreOrderVisitor : public StmtExprVisitor { |
| public: |
| explicit PreOrderVisitor(const std::function<bool(const ObjectRef&)>& f) : f_(f) {} |
| |
| private: |
| void VisitExpr(const PrimExpr& expr) final { |
| const PrimExprNode* p_expr = expr.get(); |
| if (visited_.count(p_expr) == 0) { |
| visited_.insert(p_expr); |
| if (f_(expr)) { |
| ExprVisitor::VisitExpr(expr); |
| } |
| } |
| } |
| |
| void VisitStmt(const Stmt& stmt) final { |
| const StmtNode* p_stmt = stmt.get(); |
| if (visited_.count(p_stmt) == 0) { |
| visited_.insert(p_stmt); |
| if (f_(stmt)) { |
| StmtVisitor::VisitStmt(stmt); |
| } |
| } |
| } |
| |
| const std::function<bool(const ObjectRef&)>& f_; |
| std::unordered_set<const Object*> visited_; |
| }; |
| |
| PreOrderVisitor visitor(fvisit); |
| if (auto stmt = stmt_or_expr.as<Stmt>()) { |
| visitor(stmt.value()); |
| } else if (auto expr = stmt_or_expr.as<PrimExpr>()) { |
| visitor(expr.value()); |
| } else { |
| LOG(FATAL) << "InternalError: PreOrderVisit does not accept object with type: " |
| << stmt_or_expr->GetTypeKey(); |
| } |
| } |
| |
| class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { |
| public: |
| explicit IRSubstituteWithDataTypeLegalization( |
| std::function<ffi::Optional<PrimExpr>(const Var&)> vmap) |
| : vmap_(vmap) {} |
| |
| using DataTypeLegalizer::VisitExpr_; |
| using DataTypeLegalizer::VisitStmt_; |
| |
| PrimExpr VisitExpr_(const VarNode* op) final { |
| Var var = ffi::GetRef<Var>(op); |
| auto ret = vmap_(var); |
| if (ret.defined()) { |
| return ret.value(); |
| } |
| return var; |
| } |
| |
| PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
| auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
| return VisitBufferAccess(std::move(node)); |
| } |
| |
| Stmt VisitStmt_(const BufferStoreNode* op) final { |
| auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
| return VisitBufferAccess(std::move(node)); |
| } |
| |
| template <typename Node> |
| Node VisitBufferAccess(Node node) { |
| Buffer new_buf = GetRemappedBuffer(node->buffer); |
| |
| if (!new_buf.same_as(node->buffer)) { |
| auto writer = node.CopyOnWrite(); |
| writer->buffer = new_buf; |
| } |
| |
| return node; |
| } |
| |
| Buffer GetRemappedBuffer(Buffer buf) { |
| auto key = buf.get(); |
| auto it = buf_remap_.find(key); |
| if (it != buf_remap_.end()) { |
| return it->second; |
| } |
| |
| auto new_buffer_var = vmap_(buf->data); |
| if (new_buffer_var.defined() && !new_buffer_var.value().same_as(buf->data)) { |
| auto writer = buf.CopyOnWrite(); |
| writer->data = Downcast<Var>(new_buffer_var); |
| } |
| |
| buf_remap_[key] = buf; |
| return buf; |
| } |
| |
| Stmt VisitStmt_(const AttrStmtNode* op) final { |
| Stmt ret = StmtExprMutator::VisitStmt_(op); |
| op = ret.as<AttrStmtNode>(); |
| // remap var node in attr |
| if (auto var_node = op->node.as<Var>()) { |
| if (auto mapped_var = vmap_(var_node.value())) { |
| return AttrStmt(mapped_var, op->attr_key, op->value, op->body); |
| } |
| } |
| return ret; |
| } |
| |
| private: |
| // Caller provided function that defines the variables to be remapped. |
| std::function<ffi::Optional<PrimExpr>(const Var&)> vmap_; |
| |
| /* \brief Generated map to track buffers being remapped. |
| * |
| * If a `Var BufferNode::data` is remapped, then all buffers |
| * containing that data pointer should also be remapped. This map |
| * is used to track buffer modifications, and ensure all instances |
| * of a buffer are replaced by the same modified buffer object. |
| */ |
| std::unordered_map<const BufferNode*, Buffer> buf_remap_; |
| }; |
| |
| Stmt SubstituteWithDataTypeLegalization(Stmt stmt, |
| std::function<ffi::Optional<PrimExpr>(const Var&)> vmap) { |
| return IRSubstituteWithDataTypeLegalization(vmap)(std::move(stmt)); |
| } |
| |
| PrimExpr SubstituteWithDataTypeLegalization( |
| PrimExpr expr, std::function<ffi::Optional<PrimExpr>(const Var&)> vmap) { |
| return IRSubstituteWithDataTypeLegalization(vmap)(std::move(expr)); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("tir.IRTransform", IRTransform) |
| .def("tir.PostOrderVisit", |
| [](ObjectRef node, ffi::Function f) { |
| tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); |
| }) |
| .def("tir.PreOrderVisit", |
| [](ObjectRef node, ffi::Function f) { |
| tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n).cast<bool>(); }); |
| }) |
| .def("tir.Substitute", [](ObjectRef node, ffi::Map<Var, PrimExpr> vmap) -> ObjectRef { |
| if (node->IsInstance<StmtNode>()) { |
| return Substitute(Downcast<Stmt>(node), vmap); |
| } else { |
| return Substitute(Downcast<PrimExpr>(node), vmap); |
| } |
| }); |
| } |
| |
| } // namespace tir |
| } // namespace tvm |